oobabooga 2 سال پیش
والد
کامیت
98dcfb8e12
3 فایل‌های تغییر یافته به همراه 15 افزوده شده و 12 حذف شده
  1. 3 3
      modules/chat.py
  2. 2 0
      modules/text_generation.py
  3. 10 9
      server.py

+ 3 - 3
modules/chat.py

@@ -168,7 +168,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
                     shared.history['visible'].append(['', ''])
 
             if _continue:
-                sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
+                sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
                 shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
                 shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
             else:
@@ -278,7 +278,7 @@ def clear_chat_log(name1, name2, greeting, mode):
     if greeting != '':
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
-    
+
     # Save cleared logs
     save_history(mode)
 
@@ -446,7 +446,7 @@ def load_character(character, name1, name2, mode):
             if greeting != "":
                 shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
                 shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
-            
+
             # Create .json log files since they don't already exist
             save_history(mode)
 

+ 2 - 0
modules/text_generation.py

@@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
     # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
     return inputs_embeds, filler_input_ids
 
+
 # Removes empty replies from gpt4chan outputs
 def fix_gpt4chan(s):
     for i in range(10):
@@ -77,6 +78,7 @@ def fix_gpt4chan(s):
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s
 
+
 # Fix the LaTeX equations in galactica
 def fix_galactica(s):
     s = s.replace(r'\[', r'$')

+ 10 - 9
server.py

@@ -184,22 +184,22 @@ def download_model_wrapper(repo_id):
         branch = "main"
         check = False
 
-        yield("Cleaning up the model/branch names")
+        yield ("Cleaning up the model/branch names")
         model, branch = downloader.sanitize_model_and_branch_names(model, branch)
 
-        yield("Getting the download links from Hugging Face")
+        yield ("Getting the download links from Hugging Face")
         links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
 
-        yield("Getting the output folder")
+        yield ("Getting the output folder")
         output_folder = downloader.get_output_folder(model, branch, is_lora)
 
         if check:
-            yield("Checking previously downloaded files")
+            yield ("Checking previously downloaded files")
             downloader.check_model_files(model, branch, links, sha256, output_folder)
         else:
-            yield(f"Downloading files to {output_folder}")
+            yield (f"Downloading files to {output_folder}")
             downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
-            yield("Done!")
+            yield ("Done!")
     except:
         yield traceback.format_exc()
 
@@ -377,11 +377,12 @@ def create_interface():
         extensions_module.load_extensions()
 
     with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
-        shared.input_elements = list_interface_input_elements(chat=True)
-        shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
-
         if shared.is_chat():
+
+            shared.input_elements = list_interface_input_elements(chat=True)
+            shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
             shared.gradio['Chat input'] = gr.State()
+
             with gr.Tab("Text generation", elem_id="main"):
                 shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
                 shared.gradio['textbox'] = gr.Textbox(label='Input')