فهرست منبع

Stop generating at \n in chat mode

Makes it a lot more efficient.
oobabooga 3 سال پیش
والد
کامیت
f2a548c098
1فایلهای تغییر یافته به همراه9 افزوده شده و 5 حذف شده
  1. 9 5
      server.py

+ 9 - 5
server.py

@@ -69,7 +69,7 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
-def generate_reply(question, temperature, max_length, inference_settings, selected_model):
+def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
     global model, tokenizer, model_name, loaded_preset, preset
 
     if selected_model != model_name:
@@ -86,7 +86,11 @@ def generate_reply(question, temperature, max_length, inference_settings, select
     torch.cuda.empty_cache()
     input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
 
-    output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    if eos_token is None:
+        output = eval(f"model.generate(input_ids, {preset}).cuda()")
+    else:
+        n = tokenizer.encode(eos_token, return_tensors='pt')[0][1]
+        output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}).cuda()")
     reply = tokenizer.decode(output[0], skip_special_tokens=True)
 
     if model_name.lower().startswith('galactica'):
@@ -159,7 +163,7 @@ elif args.chat:
         question += f"{name1}: {text.strip()}\n"
         question += f"{name2}:"
 
-        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model)[0]
+        reply = generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token='\n')[0]
         reply = reply[len(question):].split('\n')[0].strip()
         history.append((text, reply))
         return history
@@ -175,7 +179,7 @@ elif args.chat:
             with gr.Column():
                 with gr.Row(equal_height=True):
                     with gr.Column():
-                        length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+                        length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                         preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                     with gr.Column():
                         temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
@@ -203,7 +207,7 @@ else:
             with gr.Column():
                 textbox = gr.Textbox(value=default_text, lines=15, label='Input')
                 temp_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7)
-                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=100)
+                length_slider = gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200)
                 preset_menu = gr.Dropdown(choices=available_presets, value="NovelAI-Sphinx Moth", label='Preset')
                 model_menu = gr.Dropdown(choices=available_models, value=model_name, label='Model')
                 btn = gr.Button("Generate")