extensions.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import gradio as gr
  2. import extensions
  3. import modules.shared as shared
  4. state = {}
  5. available_extensions = []
  6. setup_called = False
  7. def load_extensions():
  8. global state
  9. for i, name in enumerate(shared.args.extensions):
  10. if name in available_extensions:
  11. print(f'Loading the extension "{name}"... ', end='')
  12. try:
  13. exec(f"import extensions.{name}.script")
  14. state[name] = [True, i]
  15. print('Ok.')
  16. except:
  17. print('Fail.')
  18. # This iterator returns the extensions in the order specified in the command-line
  19. def iterator():
  20. for name in sorted(state, key=lambda x : state[x][1]):
  21. if state[name][0] == True:
  22. yield eval(f"extensions.{name}.script"), name
  23. # Extension functions that map string -> string
  24. def apply_extensions(text, typ):
  25. for extension, _ in iterator():
  26. if typ == "input" and hasattr(extension, "input_modifier"):
  27. text = extension.input_modifier(text)
  28. elif typ == "output" and hasattr(extension, "output_modifier"):
  29. text = extension.output_modifier(text)
  30. elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
  31. text = extension.bot_prefix_modifier(text)
  32. return text
  33. def create_extensions_block():
  34. # Updating the default values
  35. for extension, name in iterator():
  36. if hasattr(extension, 'params'):
  37. for param in extension.params:
  38. _id = f"{name}-{param}"
  39. if _id in shared.settings:
  40. extension.params[param] = shared.settings[_id]
  41. # Running setup function
  42. if not setup_called:
  43. for extension, name in iterator():
  44. if hasattr(extension, "setup"):
  45. extension.setup()
  46. setup_called = True
  47. # Creating the extension ui elements
  48. if len(state) > 0:
  49. with gr.Box(elem_id="extensions"):
  50. gr.Markdown("Extensions")
  51. for extension, name in iterator():
  52. if hasattr(extension, "ui"):
  53. extension.ui()