oobabooga 3 лет назад
Родитель
Сommit
eeb63b1b8a
1 измененных файлов с 7 добавлено и 3 удалено
  1. 7 3
      server.py

+ 7 - 3
server.py

@@ -60,6 +60,11 @@ def fix_gpt4chan(s):
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s
 
+def fix_galactica(s):
+    s = s.replace(r'\[', r'$')
+    s = s.replace(r'\]', r'$')
+    return s
+
 def generate_reply(question, temperature, max_length, inference_settings, selected_model):
     global model, tokenizer, model_name, loaded_preset, preset
 
@@ -81,12 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     output = eval(f"model.generate(input_ids, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
-    if model_name.startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-
     if model_name.lower().startswith('galactica'):
+        reply = fix_galactica(reply)
         return reply, reply, 'Only applicable for gpt4chan.'
     elif model_name.lower().startswith('gpt4chan'):
+        reply = fix_gpt4chan(reply)
         return reply, 'Only applicable for galactica models.', generate_html(reply)
     else:
         return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'