Просмотр исходного кода

Truncate prompts to 2048 characters

oobabooga 3 лет назад
Родитель
Сommit
54bf55372b
1 измененных файлов с 7 добавлено и 4 удалено
  1. 7 4
      server.py

+ 7 - 4
server.py

@@ -96,6 +96,7 @@ def load_model(model_name):
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
+    tokenizer.truncation_side = 'left'
 
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
@@ -134,10 +135,10 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
 
     if not args.cpu:
         torch.cuda.empty_cache()
-        input_ids = tokenizer.encode(str(question), return_tensors='pt').cuda()
+        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens).cuda()
         cuda = ".cuda()"
     else:
-        input_ids = tokenizer.encode(str(question), return_tensors='pt')
+        input_ids = tokenizer.encode(str(question), return_tensors='pt', truncation=True, max_length=2048-tokens)
         cuda = ""
 
     if eos_token is None:
@@ -231,10 +232,12 @@ elif args.chat or args.cai_chat:
 
         if check:
             reply = generate_reply(question, tokens, inference_settings, selected_model, eos_token='\n')[0]
-            reply = reply[len(question):].split('\n')[0].strip()
+            idx = reply.rfind(question[-500:])
+            reply = reply[idx+min(500, len(question)):].split('\n')[0].strip()
         else:
             reply = generate_reply(question, tokens, inference_settings, selected_model)[0]
-            reply = reply[len(question):]
+            idx = reply.rfind(question[-500:])
+            reply = reply[idx+min(500, len(question)):]
             idx = reply.find(f"\n{name1}:")
             if idx != -1:
                 reply = reply[:idx]