# extensions.py (1.8 KB)

import importlib
import traceback

import gradio as gr

import extensions
import modules.shared as shared
  4. state = {}
  5. available_extensions = []
  6. def load_extensions():
  7. global state
  8. for i, name in enumerate(shared.args.extensions):
  9. if name in available_extensions:
  10. print(f'Loading the extension "{name}"... ', end='')
  11. try:
  12. exec(f"import extensions.{name}.script")
  13. state[name] = [True, i]
  14. print('Ok.')
  15. except:
  16. print('Fail.')
  17. # This iterator returns the extensions in the order specified in the command-line
  18. def iterator():
  19. for name in sorted(state, key=lambda x : state[x][1]):
  20. if state[name][0] == True:
  21. yield eval(f"extensions.{name}.script"), name
  22. # Extension functions that map string -> string
  23. def apply_extensions(text, typ):
  24. for extension, _ in iterator():
  25. if typ == "input" and hasattr(extension, "input_modifier"):
  26. text = extension.input_modifier(text)
  27. elif typ == "output" and hasattr(extension, "output_modifier"):
  28. text = extension.output_modifier(text)
  29. elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
  30. text = extension.bot_prefix_modifier(text)
  31. return text
  32. def create_extensions_block():
  33. # Updating the default values
  34. for extension, name in iterator():
  35. if hasattr(extension, 'params'):
  36. for param in extension.params:
  37. _id = f"{name}-{param}"
  38. if _id in shared.settings:
  39. extension.params[param] = shared.settings[_id]
  40. # Creating the extension ui elements
  41. if len(state) > 0:
  42. with gr.Box(elem_id="extensions"):
  43. gr.Markdown("Extensions")
  44. for extension, name in iterator():
  45. if hasattr(extension, "ui"):
  46. extension.ui()