oobabooga 2 anni fa
parent
commit
ce7feb3641
4 ha cambiato i file con 21 aggiunte e 22 eliminazioni
  1. 7 4
      modules/chat.py
  2. 0 2
      modules/shared.py
  3. 2 1
      modules/text_generation.py
  4. 12 15
      server.py

+ 7 - 4
modules/chat.py

@@ -6,10 +6,13 @@ from pathlib import Path
 
 import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import *
-from modules.prompt import encode
-from modules.prompt import generate_reply
-from modules.prompt import get_max_prompt_length
+from modules.html_generator import generate_chat_html
+from modules.text_generation import encode
+from modules.text_generation import generate_reply
+from modules.text_generation import get_max_prompt_length
+
+if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
+    import modules.bot_picture as bot_picture
 
 history = {'internal': [], 'visible': []}
 character = None

+ 0 - 2
modules/shared.py

@@ -1,7 +1,5 @@
 import argparse
 
-global tokenizer
-
 model = None
 tokenizer = None
 model_name = ""

+ 2 - 1
modules/prompt.py → modules/text_generation.py

@@ -4,7 +4,8 @@ import modules.shared as shared
 import torch
 import transformers
 from modules.extensions import apply_extensions
-from modules.html_generator import *
+from modules.html_generator import generate_4chan_html
+from modules.html_generator import generate_basic_html
 from modules.stopping_criteria import _SentinelTokenStoppingCriteria
 from tqdm import tqdm
 

+ 12 - 15
server.py

@@ -20,12 +20,12 @@ from transformers import AutoTokenizer
 import modules.chat as chat
 import modules.extensions as extensions_module
 import modules.shared as shared
+import modules.ui as ui
 from modules.extensions import extension_state
 from modules.extensions import load_extensions
 from modules.extensions import update_extensions_parameters
-from modules.html_generator import *
-from modules.prompt import generate_reply
-from modules.ui import *
+from modules.html_generator import generate_chat_html
+from modules.text_generation import generate_reply
 
 transformers.logging.set_verbosity_error()
 
@@ -74,9 +74,6 @@ if shared.args.deepspeed:
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
 
-if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
-    import modules.bot_picture as bot_picture
-
 def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
@@ -288,11 +285,11 @@ def create_settings_menus():
         with gr.Column():
             with gr.Row():
                 model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
-                create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
+                ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
         with gr.Column():
             with gr.Row():
                 preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
-                create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
+                ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
 
     with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
         with gr.Row():
@@ -320,7 +317,7 @@ def create_settings_menus():
     with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
         with gr.Row():
             softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
-            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+            ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
 
         gr.Markdown('Upload a soft prompt (.zip format):')
         with gr.Row():
@@ -336,8 +333,9 @@ def create_settings_menus():
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
-extensions_module.available_extensions = get_available_extensions()
 available_softprompts = get_available_softprompts()
+
+extensions_module.available_extensions = get_available_extensions()
 if shared.args.extensions is not None:
     load_extensions()
 
@@ -359,7 +357,6 @@ else:
         print()
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
-loaded_preset = None
 
 # UI settings
 if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
@@ -379,7 +376,7 @@ if shared.args.chat or shared.args.cai_chat:
     if Path(f'logs/persistent.json').exists():
         chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])
 
-    with gr.Blocks(css=css+chat_css, analytics_enabled=False) as interface:
+    with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
         if shared.args.cai_chat:
             display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character))
         else:
@@ -406,7 +403,7 @@ if shared.args.chat or shared.args.cai_chat:
             context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
             with gr.Row():
                 character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
-                create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
+                ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
 
             with gr.Row():
                 check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
@@ -489,7 +486,7 @@ if shared.args.chat or shared.args.cai_chat:
             upload_img_me.upload(lambda : chat.history['visible'], [], [display])
 
 elif shared.args.notebook:
-    with gr.Blocks(css=css, analytics_enabled=False) as interface:
+    with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
         gr.Markdown(description)
         with gr.Tab('Raw'):
             textbox = gr.Textbox(value=default_text, lines=23)
@@ -513,7 +510,7 @@ elif shared.args.notebook:
         buttons["Stop"].click(None, None, None, cancels=gen_events)
 
 else:
-    with gr.Blocks(css=css, analytics_enabled=False) as interface:
+    with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
         gr.Markdown(description)
         with gr.Row():
             with gr.Column():