
Various fixes in chat mode

oobabooga committed 2 years ago (commit 341e135036)
3 changed files with 22 additions and 24 deletions
  1. modules/callbacks.py        +1  -0
  2. modules/chat.py             +6  -10
  3. modules/text_generation.py  +15 -14

+ 1 - 0
modules/callbacks.py

@@ -64,6 +64,7 @@ class Iteratorize:
                 ret = self.mfunc(callback=_callback, **self.kwargs)
             except ValueError:
                 pass
+            clear_torch_cache()
             self.q.put(self.sentinel)
             if self.c_callback:
                 self.c_callback(ret)
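
The one-line change above runs clear_torch_cache() on the generation thread itself, after self.mfunc returns (or after the ValueError raised when generation is aborted) and before the stop sentinel is queued. A simplified, self-contained sketch of this Iteratorize pattern follows; clear_torch_cache() here is a hypothetical stand-in for the helper that modules/callbacks.py actually imports.

    import gc
    import threading
    from queue import Queue

    def clear_torch_cache():
        # Hypothetical stand-in: collect garbage and, when torch with CUDA
        # is available, release cached GPU memory.
        gc.collect()
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass

    class Iteratorize:
        """Turn a callback-based function into a lazy iterator (simplified)."""

        def __init__(self, func, **kwargs):
            self.q = Queue()
            self.sentinel = object()

            def _callback(value):
                # Each partial result reported by func is handed to the consumer.
                self.q.put(value)

            def gen():
                try:
                    func(callback=_callback, **kwargs)
                except ValueError:
                    pass
                # The fix in this hunk: drop cached tensors as soon as
                # generation ends, before the consumer is told to stop.
                clear_torch_cache()
                self.q.put(self.sentinel)

            threading.Thread(target=gen, daemon=True).start()

        def __iter__(self):
            return self

        def __next__(self):
            item = self.q.get()
            if item is self.sentinel:
                raise StopIteration
            return item

    # Example: wrap a callback-style generation function.
    def demo_generate(callback=None, n=3):
        for i in range(n):
            callback(f"chunk {i}")

    for chunk in Iteratorize(demo_generate, n=3):
        print(chunk)

Clearing the cache inside gen() means the memory is released as soon as generation finishes, regardless of how quickly the consumer drains the queue.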

+ 6 - 10
modules/chat.py

@@ -115,18 +115,14 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         visible_text = visible_text.replace('\n', '<br>')
     text = apply_extensions(text, "input")
 
+    if custom_generate_chat_prompt is None:
+        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+    else:
+        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+
     # Generate
     reply = ''
     for i in range(chat_generation_attempts):
-
-        #  The prompt needs to be generated here because, as the reply
-        #  grows, it may become necessary to remove more old messages to
-        #  fit into the 2048 tokens window.
-        if custom_generate_chat_prompt is None:
-            prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
-        else:
-            prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]))
-
         for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
 
             # Extracting the reply
@@ -160,10 +156,10 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
+    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
 
     reply = ''
     for i in range(chat_generation_attempts):
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size-len(encode(' '+reply)[0]), impersonate=True)
         for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
             reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
             if not substring_found:
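
Both chat.py hunks make the same simplification: the prompt is now built once, before the chat_generation_attempts loop, rather than regenerated on every attempt with the partial reply's token count subtracted from chat_prompt_size. A rough sketch of the resulting control flow; generate_chat_prompt and generate_reply below are hypothetical stand-ins, not the module's real implementations.

    def generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size):
        # Stand-in: the real helper also trims old history so the prompt
        # fits the model's context window.
        return f"{context}\n{name1}: {text}\n{name2}:"

    def generate_reply(prompt, max_new_tokens):
        # Stand-in: the real function streams progressively longer replies.
        for i in range(1, max_new_tokens + 1):
            yield "tok " * i

    def chatbot_wrapper(text, max_new_tokens, name1, name2, context, chat_prompt_size,
                        chat_generation_attempts=1, custom_generate_chat_prompt=None):
        # The prompt is computed a single time, honoring an extension's
        # override when the custom_generate_chat_prompt hook is installed.
        if custom_generate_chat_prompt is None:
            prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
        else:
            prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)

        reply = ''
        for _ in range(chat_generation_attempts):
            # Each retry continues from the reply left by the previous one.
            for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens):
                yield reply

    for partial in chatbot_wrapper("Hello", 4, "You", "Bot", "A chat.", 2048):
        print(partial)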

+ 15 - 14
modules/text_generation.py

@@ -92,21 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if shared.is_RWKV:
-        if shared.args.no_stream:
-            reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
-            yield formatted_outputs(reply, shared.model_name)
-        else:
-            yield formatted_outputs(question, shared.model_name)
-            # RWKV has proper streaming, which is very nice.
-            # No need to generate 8 tokens at a time.
-            for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+        try:
+            if shared.args.no_stream:
+                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
                 yield formatted_outputs(reply, shared.model_name)
-
-        t1 = time.time()
-        output = encode(reply)[0]
-        input_ids = encode(question)
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
-        return
+            else:
+                yield formatted_outputs(question, shared.model_name)
+                # RWKV has proper streaming, which is very nice.
+                # No need to generate 8 tokens at a time.
+                for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
+                    yield formatted_outputs(reply, shared.model_name)
+        finally:
+            t1 = time.time()
+            output = encode(reply)[0]
+            input_ids = encode(question)
+            print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
+            return
 
     original_question = question
     if not (shared.args.chat or shared.args.cai_chat):
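
The text_generation.py hunk re-indents the whole RWKV branch into a try block and moves the timing summary into a finally clause, so the tokens-per-second report is printed whether streaming completes, raises, or is abandoned by the caller. Because generate_reply is a generator, closing it early raises GeneratorExit at the yield and the finally block still runs. A small sketch of that behavior, with a fake streaming function standing in for shared.model.generate_with_streaming:

    import time

    def fake_generate_with_streaming(context, token_count):
        # Hypothetical stand-in for the RWKV streaming call.
        reply = context
        for i in range(token_count):
            reply += f" tok{i}"
            yield reply

    def generate_reply(question, max_new_tokens):
        t0 = time.time()
        try:
            for reply in fake_generate_with_streaming(question, max_new_tokens):
                yield reply
        finally:
            # Runs on normal completion, on an exception, and when the
            # caller stops consuming the generator early (GeneratorExit).
            t1 = time.time()
            print(f"Output generated in {t1 - t0:.2f} seconds")

    gen = generate_reply("Hello", 8)
    print(next(gen))   # consume only the first streamed chunk
    gen.close()        # the finally clause still prints the timing report

Note that the bare return at the end of the new finally block simply ends the generator early, keeping the RWKV path from falling through to the Hugging Face code below it.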