Просмотр исходного кода

Manual eos_token implementation

oobabooga 3 лет назад
Родитель
Сommit
116299b3ad
1 измененных файлов с 3 добавлено и 7 удалено
  1. 3 7
      server.py

+ 3 - 7
server.py

@@ -142,16 +142,12 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
     input_ids = encode(question, 1)
     preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1')
     cuda = "" if args.cpu else ".cuda()"
-    if eos_token is not None:
-        n = tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
     for i in range(tokens):
-        if eos_token is None:
-            output = eval(f"model.generate(input_ids, {preset}){cuda}")
-        else:
-            output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
-
+        output = eval(f"model.generate(input_ids, {preset}){cuda}")
         reply = tokenizer.decode(output[0], skip_special_tokens=True)
         reply = reply.replace(r'<|endoftext|>', '')
+        if eos_token is not None and reply[-1] == eos_token:
+            break
         if model_name.lower().startswith('galactica'):
             reply = fix_galactica(reply)
             yield reply, reply, generate_basic_html(reply)