oobabooga 3 лет назад
Родитель
Сommit
fd220f827f
1 измененных файлов с 11 добавлено и 4 удалено
  1. 11 4
      server.py

+ 11 - 4
server.py

@@ -9,6 +9,7 @@ import gradio as gr
 import transformers
 from html_generator import *
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import warnings
 
 
 parser = argparse.ArgumentParser()
@@ -20,12 +21,15 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.')
 args = parser.parse_args()
+
 loaded_preset = None
 available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
 available_models = [item for item in available_models if not item.endswith('.txt')]
 available_models = sorted(available_models, key=str.lower)
 available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
 
+transformers.logging.set_verbosity_error()
+
 def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
@@ -188,10 +192,15 @@ if args.notebook:
 elif args.chat:
     history = []
 
-    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+    # This gets the new line characters right.
+    def chat_response_cleaner(text):
         text = text.replace('\n', '\n\n')
         text = re.sub(r"\n{3,}", "\n\n", text)
         text = text.strip()
+        return text
+
+    def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
+        text = chat_response_cleaner(text)
 
         question = context+'\n\n'
         for i in range(len(history)):
@@ -209,9 +218,7 @@ elif args.chat:
             idx = reply.find(f"\n{name1}:")
             if idx != -1:
                 reply = reply[:idx]
-            reply = reply.replace('\n', '\n\n')
-            reply = re.sub(r"\n{3,}", "\n\n", reply)
-            reply = reply.strip()
+            reply = chat_response_cleaner(response)
 
         history.append((text, reply))
         return history