extensions.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import traceback
  2. import gradio as gr
  3. import extensions
  4. import modules.shared as shared
  5. state = {}
  6. available_extensions = []
  7. setup_called = set()
  8. def load_extensions():
  9. global state, setup_called
  10. for i, name in enumerate(shared.args.extensions):
  11. if name in available_extensions:
  12. print(f'Loading the extension "{name}"... ', end='')
  13. try:
  14. exec(f"import extensions.{name}.script")
  15. extension = eval(f"extensions.{name}.script")
  16. if extension not in setup_called and hasattr(extension, "setup"):
  17. setup_called.add(extension)
  18. extension.setup()
  19. state[name] = [True, i]
  20. print('Ok.')
  21. except:
  22. print('Fail.')
  23. traceback.print_exc()
  24. # This iterator returns the extensions in the order specified in the command-line
  25. def iterator():
  26. for name in sorted(state, key=lambda x: state[x][1]):
  27. if state[name][0]:
  28. yield eval(f"extensions.{name}.script"), name
  29. # Extension functions that map string -> string
  30. def apply_extensions(text, typ):
  31. for extension, _ in iterator():
  32. if typ == "input" and hasattr(extension, "input_modifier"):
  33. text = extension.input_modifier(text)
  34. elif typ == "output" and hasattr(extension, "output_modifier"):
  35. text = extension.output_modifier(text)
  36. elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
  37. text = extension.bot_prefix_modifier(text)
  38. return text
  39. def create_extensions_block():
  40. global setup_called
  41. # Updating the default values
  42. for extension, name in iterator():
  43. if hasattr(extension, 'params'):
  44. for param in extension.params:
  45. _id = f"{name}-{param}"
  46. if _id in shared.settings:
  47. extension.params[param] = shared.settings[_id]
  48. should_display_ui = False
  49. for extension, name in iterator():
  50. if hasattr(extension, "ui"):
  51. should_display_ui = True
  52. # Creating the extension ui elements
  53. if should_display_ui:
  54. with gr.Column(elem_id="extensions"):
  55. for extension, name in iterator():
  56. gr.Markdown(f"\n### {name}")
  57. if hasattr(extension, "ui"):
  58. extension.ui()