oobabooga 3 years ago
Parent
Commit 8f27d33034
1 changed file with 3 additions and 2 deletions

server.py  + 3 - 2

@@ -142,11 +142,12 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     input_ids = encode(question, 1)
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = "" if args.cpu else ".cuda()"
+    if eos_token is not None:
+        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     for i in range(tokens):
         if eos_token is None:
             output = eval(f"model.generate(input_ids, {preset}){cuda}")
         else:
-            n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
             output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
 
         reply = tokenizer.decode(output[0], skip_special_tokens=True)
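The first hunk hoists the eos_token_id lookup out of the generation loop: the id depends only on eos_token, which never changes between iterations, so encoding it once before the loop avoids a redundant tokenizer.encode call per generated token. A minimal sketch of the pattern, with a stand-in tokenizer (FakeTokenizer and the yielded strings are illustrative, not the real transformers API):

    class FakeTokenizer:
        def encode(self, text):
            # Stand-in: map each character to its code point.
            return [ord(c) for c in text]

    def generate(eos_token, tokens):
        tokenizer = FakeTokenizer()
        n = None
        if eos_token is not None:
            # Hoisted: encoded once, reused on every iteration below.
            n = tokenizer.encode(eos_token)[-1]
        for _ in range(tokens):
            if n is None:
                yield "model.generate(input_ids, ...)"
            else:
                yield f"model.generate(input_ids, eos_token_id={n}, ...)"

    for step in generate('\n', 3):
        print(step)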
@@ -240,8 +241,8 @@ elif args.chat or args.cai_chat:
         return question
 
     def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
-        history.append(['', ''])
         question = generate_chat_prompt(text, tokens, name1, name2, context)
+        history.append(['', ''])
         eos_token = '\n' if check else None
         for i in generate_reply(question, tokens, inference_settings, selected_model, eos_token=eos_token):
             reply = i[0]
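The second hunk reorders chatbot_wrapper so the prompt is built before the empty ['', ''] placeholder row is appended to history. Assuming, as the surrounding code suggests, that generate_chat_prompt reads the shared history list, the old order would leak a blank exchange into the prompt. A minimal sketch of the difference, using simplified stand-in names:

    history = [['Hi', 'Hello!']]

    def generate_chat_prompt(text):
        # Stand-in: render every past exchange, then the new user message.
        lines = []
        for user_msg, bot_msg in history:
            lines.append(f'You: {user_msg}')
            lines.append(f'Bot: {bot_msg}')
        lines.append(f'You: {text}')
        return '\n'.join(lines)

    # Old order: the placeholder leaks a blank exchange into the prompt.
    history.append(['', ''])
    old_prompt = generate_chat_prompt('How are you?')
    history.pop()

    # New order: build the prompt first, then reserve the row for the reply.
    new_prompt = generate_chat_prompt('How are you?')
    history.append(['', ''])

    print(old_prompt)
    print('---')
    print(new_prompt)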