Procházet zdrojové kódy

Consider the softprompt in the maximum prompt length calculation

oobabooga před 3 roky
rodič
revize
d910d435cd
1 změnil soubory, kde provedl 11 přidání a 3 odebrání
  1. 11 3
      server.py

+ 11 - 3
server.py

@@ -247,8 +247,15 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def get_max_prompt_length(tokens):
+    global soft_prompt, soft_prompt_tensor
+    max_length = 2048-tokens
+    if soft_prompt:
+        max_length -= soft_prompt_tensor.shape[1]
+    return max_length
+
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
     if args.cpu:
         return input_ids
     elif args.deepspeed:
@@ -497,7 +504,8 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     count = 0
-    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+    max_length = get_max_prompt_length(tokens)
+    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
         if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
@@ -515,7 +523,7 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
         rows.append(f"{name1}:")
         limit = 2
 
-    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
         rows.pop(1)
         rows.pop(1)