
Better way of finding the generated reply in the output string

oobabooga committed 3 years ago
Commit 849e4c7f90
1 changed file with 17 additions and 13 deletions
server.py  +17 −13

@@ -136,14 +136,17 @@ def decode(output_ids):
     return reply

 def formatted_outputs(reply, model_name):
-    if model_name.lower().startswith('galactica'):
-        reply = fix_galactica(reply)
-        return reply, reply, generate_basic_html(reply)
-    elif model_name.lower().startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-        return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+    if not (args.chat or args.cai_chat):
+        if model_name.lower().startswith('galactica'):
+            reply = fix_galactica(reply)
+            return reply, reply, generate_basic_html(reply)
+        elif model_name.lower().startswith('gpt4chan'):
+            reply = fix_gpt4chan(reply)
+            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
+        else:
+            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
     else:
-        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
+        return reply

 def generate_reply(question, tokens, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
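In chat and cai-chat modes the conversation is rendered by the chat interface itself, so formatted_outputs now returns the raw reply string; the other modes keep the old (text, GALACTICA text, HTML) triple for the three Gradio outputs. A minimal runnable sketch of the new control flow, with the model-specific branches omitted and hypothetical stand-ins for the args object and the HTML helper (the real ones are defined elsewhere in server.py):

from types import SimpleNamespace

# Hypothetical stand-ins so the sketch runs on its own; in server.py,
# args comes from argparse and generate_basic_html builds real HTML.
args = SimpleNamespace(chat=False, cai_chat=False)

def generate_basic_html(reply):
    return f"<p>{reply}</p>"

def formatted_outputs(reply, model_name):
    if not (args.chat or args.cai_chat):
        # Notebook/default modes: one value per Gradio output widget.
        return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        # Chat modes: the caller consumes the raw string directly.
        return reply

print(formatted_outputs("Hello!", "opt"))  # ('Hello!', 'Only applicable for GALACTICA models.', '<p>Hello!</p>')
args.chat = True
print(formatted_outputs("Hello!", "opt"))  # 'Hello!'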
@@ -245,16 +248,17 @@ if args.chat or args.cai_chat:
         question = generate_chat_prompt(text, tokens, name1, name2, context)
         history.append(['', ''])
         eos_token = '\n' if check else None
-        for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
-            reply = i[0]
+        for reply in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             next_character_found = False

+            previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
+            idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", reply)]
+            idx = idx[len(previous_idx)-1]
+            reply = reply[idx + len(f"\n{name2}:"):]
+
             if check:
-                idx = reply.rfind(question[-1024:])
-                reply = reply[idx+min(1024, len(question)):].split('\n')[0].strip()
+                reply = reply.split('\n')[0].strip()
             else:
-                idx = reply.rfind(question[-1024:])
-                reply = reply[idx+min(1024, len(question)):]
                 idx = reply.find(f"\n{name1}:")
                 if idx != -1:
                     reply = reply[:idx]
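The core of the change is how the generated reply is located in the model output. The old code searched the output for the last 1024 characters of the prompt (reply.rfind(question[-1024:])) and sliced after them; the new code instead counts how many "\n<name2>:" turn markers the prompt already contained and slices after the marker with that ordinal in the output, so everything that follows is newly generated text. A self-contained sketch of the technique (the function name and sample strings are illustrative, not from the repository; like the original, it interpolates name2 into the pattern unescaped, so re.escape would be needed for names containing regex metacharacters):

import re

def find_generated_reply(question, output, name2):
    # The model output echoes the prompt, so the new reply begins right
    # after the last "\n<name2>:" marker that was already in the prompt.
    previous_idx = [m.start() for m in re.finditer(f"\n{name2}:", question)]
    # All turn markers in the full output, including any extra turns the
    # model may have hallucinated after its own reply.
    idx = [m.start() for m in re.finditer(f"(^|\n){name2}:", output)]
    # Marker number len(previous_idx) - 1 is the last one inherited from
    # the prompt; slice off everything up to and including it.
    start = idx[len(previous_idx) - 1]
    return output[start + len(f"\n{name2}:"):]

prompt = "Persona...\nYou: Hi\nBot:"
output = prompt + " Hello there!\nYou: And you?"
print(repr(find_generated_reply(prompt, output, "Bot")))  # ' Hello there!\nYou: And you?'

Downstream, the loop then trims this slice further: with check enabled it keeps only the first line, and otherwise it cuts at the first "\n<name1>:" marker so the bot cannot continue speaking for the user.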