Selaa lähdekoodia

Fix galactica equations

oobabooga 3 vuotta sitten
vanhempi
commit
eeb63b1b8a
1 muutettua tiedostoa jossa 7 lisäystä ja 3 poistoa
  1. 7 3
      server.py

+ 7 - 3
server.py

@@ -60,6 +60,11 @@ def fix_gpt4chan(s):
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
         s = re.sub("--- [0-9]*\n\n\n---", "---", s)
     return s
     return s
 
 
+def fix_galactica(s):
+    s = s.replace(r'\[', r'$')
+    s = s.replace(r'\]', r'$')
+    return s
+
 def generate_reply(question, temperature, max_length, inference_settings, selected_model):
 def generate_reply(question, temperature, max_length, inference_settings, selected_model):
     global model, tokenizer, model_name, loaded_preset, preset
     global model, tokenizer, model_name, loaded_preset, preset
 
 
@@ -81,12 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     output = eval(f"model.generate(input_ids, {preset}).cuda()")
     output = eval(f"model.generate(input_ids, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
-    if model_name.startswith('gpt4chan'):
-        reply = fix_gpt4chan(reply)
-
     if model_name.lower().startswith('galactica'):
     if model_name.lower().startswith('galactica'):
+        reply = fix_galactica(reply)
         return reply, reply, 'Only applicable for gpt4chan.'
         return reply, reply, 'Only applicable for gpt4chan.'
     elif model_name.lower().startswith('gpt4chan'):
     elif model_name.lower().startswith('gpt4chan'):
+        reply = fix_gpt4chan(reply)
         return reply, 'Only applicable for galactica models.', generate_html(reply)
         return reply, 'Only applicable for galactica models.', generate_html(reply)
     else:
     else:
         return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'
         return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'