oobabooga преди 2 години
родител
ревизия
bfa81e105e
променени са 1 файла, в които са добавени 4 реда и са изтрити 0 реда
  1. 4 0
      modules/text_generation.py

+ 4 - 0
modules/text_generation.py

@@ -258,6 +258,10 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                 input_ids = np.reshape(output, (1, output.shape[0]))
                 if shared.soft_prompt:
                     inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+                    generate_params.update({"inputs_embeds": inputs_embeds})
+                    generate_params.update({"inputs": filler_input_ids})
+                else:
+                    generate_params.update({"inputs": input_ids})
 
             yield formatted_outputs(reply, shared.model_name)