Explorar el Código

The soft prompt length must be considered here too

oobabooga hace 3 años
padre
commit
596732a981
Se ha modificado 1 fichero con 6 adiciones y 0 borrados
  1. 6 0
      server.py

+ 6 - 0
server.py

@@ -505,11 +505,17 @@ def clean_chat_message(text):
     return text
     return text
 
 
 def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
 def generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
+    global soft_prompt, soft_prompt_tensor
+
     text = clean_chat_message(text)
     text = clean_chat_message(text)
     rows = [f"{context.strip()}\n"]
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     i = len(history['internal'])-1
     count = 0
     count = 0
+
+    if soft_prompt:
+        chat_prompt_size -= soft_prompt_tensor.shape[1]
     max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
     max_length = min(get_max_prompt_length(tokens), chat_prompt_size)
+
     while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
     while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
         count += 1