Selaa lähdekoodia

Properly separate the original prompt from the reply

oobabooga 2 vuotta sitten
vanhempi
commit
de6a09dc7f
1 muutettua tiedostoa jossa 19 lisäystä ja 11 poistoa
  1. 19 11
      modules/text_generation.py

+ 19 - 11
modules/text_generation.py

@@ -136,6 +136,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     input_ids = encode(question, max_new_tokens)
     original_input_ids = input_ids
     output = input_ids[0]
+
     cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
@@ -146,9 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = encode(stopping_string, 0, add_special_tokens=False)
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
-    generate_params = {
-        'use_cache': not shared.args.no_cache,
-    }
+    generate_params = {}
     if not shared.args.flexgen:
         generate_params.update({
             "max_new_tokens": max_new_tokens,
@@ -175,6 +174,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             "temperature": temperature,
             "stop": eos_token_ids[-1],
         })
+    if shared.args.no_cache:
+        generate_params.update({"use_cache": False})
     if shared.args.deepspeed:
         generate_params.update({"synced_gpus": True})
     if shared.soft_prompt:
@@ -194,9 +195,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             if shared.soft_prompt:
                 output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
 
-            reply = decode(output)
             if not (shared.args.chat or shared.args.cai_chat):
-                reply = original_question + apply_extensions(reply[len(question):], "output")
+                new_tokens = len(output) - len(input_ids[0])
+                reply = decode(output[-new_tokens:])
+                reply = original_question + apply_extensions(reply, "output")
+            else:
+                reply = decode(output)
 
             yield formatted_outputs(reply, shared.model_name)
 
@@ -219,10 +223,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 for output in generator:
                     if shared.soft_prompt:
                         output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-                    reply = decode(output)
-
                     if not (shared.args.chat or shared.args.cai_chat):
-                        reply = original_question + apply_extensions(reply[len(question):], "output")
+                        new_tokens = len(output) - len(input_ids[0])
+                        reply = decode(output[-new_tokens:])
+                        reply = original_question + apply_extensions(reply, "output")
+                    else:
+                        reply = decode(output)
 
                     if output[-1] in eos_token_ids:
                         break
@@ -238,10 +244,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                     output = shared.model.generate(**generate_params)[0]
                 if shared.soft_prompt:
                     output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
-                reply = decode(output)
-
                 if not (shared.args.chat or shared.args.cai_chat):
-                    reply = original_question + apply_extensions(reply[len(question):], "output")
+                    new_tokens = len(output) - len(original_input_ids[0])
+                    reply = decode(output[-new_tokens:])
+                    reply = original_question + apply_extensions(reply, "output")
+                else:
+                    reply = decode(output)
 
                 if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break