فهرست منبع

This is the correct way of sampling 1 token at a time

oobabooga 3 سال پیش
والد
کامیت
022960a087
1فایلهای تغییر یافته به همراه5 افزوده شده و 4 حذف شده
  1. 5 4
      server.py

+ 5 - 4
server.py

@@ -139,11 +139,11 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
             preset = infile.read()
             preset = infile.read()
         loaded_preset = inference_settings
         loaded_preset = inference_settings
 
 
+    input_ids = encode(question, 1)
+    preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
+    cuda = ".cuda()" if args.cpu else ""
     for i in range(tokens):
     for i in range(tokens):
-        input_ids = encode(question, 1)
-        preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
 
 
-        cuda = ".cuda()" if args.cpu else ""
         if eos_token is None:
         if eos_token is None:
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
         else:
         else:
@@ -152,7 +152,6 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
 
 
         reply = tokenizer.decode(output[0], skip_special_tokens=True)
         reply = tokenizer.decode(output[0], skip_special_tokens=True)
         reply = reply.replace(r'<|endoftext|>', '')
         reply = reply.replace(r'<|endoftext|>', '')
-        question = reply
         if model_name.lower().startswith('galactica'):
         if model_name.lower().startswith('galactica'):
             reply = fix_galactica(reply)
             reply = fix_galactica(reply)
             yield reply, reply, generate_basic_html(reply)
             yield reply, reply, generate_basic_html(reply)
@@ -162,6 +161,8 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
         else:
         else:
             yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
             yield reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
 
 
+        input_ids = output
+
 # Choosing the default model
 # Choosing the default model
 if args.model is not None:
 if args.model is not None:
     model_name = args.model
     model_name = args.model