
Reorganize some chat functions

oobabooga 2 years ago
parent commit a453d4e9c4
1 file changed with 9 additions and 8 deletions

modules/chat.py  +9 -8

@@ -105,14 +105,16 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
 
-    eos_token = '\n' if generate_state['stop_at_newline'] else None
+    # Defining some variables
+    cumulative_reply = ''
+    just_started = True
     name1_original = name1
+    visible_text = custom_generate_chat_prompt = None
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
     # Check if any extension wants to hijack this function call
-    visible_text = None
-    custom_generate_chat_prompt = None
     for extension, _ in extensions_module.iterator():
         if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
             extension.input_hijack['state'] = False
@@ -124,6 +126,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
         visible_text = text
     text = apply_extensions(text, "input")
 
+    # Generating the prompt
     kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
@@ -135,8 +138,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
     # Generate
-    cumulative_reply = ''
-    just_started = True
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
@@ -175,6 +176,8 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
 
+    # Defining some variables
+    cumulative_reply = ''
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
@@ -184,7 +187,6 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
     # Yield *Is typing...*
     yield shared.processing_message
 
-    cumulative_reply = ''
     for i in range(generate_state['chat_generation_attempts']):
         reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
@@ -264,7 +266,7 @@ def redraw_html(name1, name2, mode):
 
 def tokenize_dialogue(dialogue, name1, name2, mode):
     history = []
-
+    messages = []
     dialogue = re.sub('<START>', '', dialogue)
     dialogue = re.sub('<start>', '', dialogue)
     dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
@@ -273,7 +275,6 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
     if len(idx) == 0:
         return history
 
-    messages = []
     for i in range(len(idx) - 1):
         messages.append(dialogue[idx[i]:idx[i + 1]].strip())
     messages.append(dialogue[idx[-1]:].strip())
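
The whole diff boils down to one refactor pattern: hoist each function's variable initialization into a single "Defining some variables" block at the top, rather than defining each variable immediately before its first use (cumulative_reply and just_started in chatbot_wrapper, cumulative_reply in impersonate_wrapper, messages in tokenize_dialogue). A rough sketch of the pattern, with simplified names and placeholder logic rather than the project's actual code:

# Sketch of the commit's refactor pattern (placeholder logic, not the
# real modules/chat.py): all mutable state is initialized up front, so
# the sections below only read or update variables that already exist.
def chatbot_wrapper_sketch(text, generate_state):
    # Defining some variables
    cumulative_reply = ''
    just_started = True
    visible_text = custom_generate_chat_prompt = None
    eos_token = '\n' if generate_state['stop_at_newline'] else None

    # In the real function, extension hooks may overwrite visible_text
    # and custom_generate_chat_prompt here (omitted in this sketch)
    if visible_text is None:
        visible_text = text

    # Generating the prompt (placeholder for generate_chat_prompt)
    prompt = custom_generate_chat_prompt or f"{visible_text}:"

    # Generate -- before this commit, cumulative_reply and just_started
    # were defined here; now the loop assumes they already exist
    for _ in range(generate_state['chat_generation_attempts']):
        reply = ' reply'  # placeholder for generate_reply(...)
        if just_started:
            just_started = False
        cumulative_reply += reply

    return prompt, cumulative_reply, eos_token

print(chatbot_wrapper_sketch('Hello', {'stop_at_newline': True, 'chat_generation_attempts': 1}))

Grouping the initializations this way makes each function's mutable state visible at a glance, and lets later blocks such as the generation loop be moved or reordered without chasing down definitions scattered through the body.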