oobabooga 2 лет назад
Родитель
Сommit
af65c12900
4 измененных файлов с 10 добавлено и 9 удалено
  1. 1 1
      modules/callbacks.py
  2. 0 4
      modules/chat.py
  3. 4 0
      modules/text_generation.py
  4. 5 4
      server.py

+ 1 - 1
modules/callbacks.py

@@ -54,7 +54,7 @@ class Iteratorize:
         self.stop_now = False
 
         def _callback(val):
-            if self.stop_now:
+            if self.stop_now or shared.stop_everything:
                 raise ValueError
             self.q.put(val)
 

+ 0 - 4
modules/chat.py

@@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def stop_everything_event():
-    shared.stop_everything = True
-
 def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
-    shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
     name1_original = name1

+ 4 - 0
modules/text_generation.py

@@ -99,9 +99,13 @@ def set_manual_seed(seed):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(seed)
 
+def stop_everything_event():
+    shared.stop_everything = True
+
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     set_manual_seed(seed)
+    shared.stop_everything = False
     t0 = time.time()
 
     original_question = question

+ 5 - 4
server.py

@@ -16,7 +16,8 @@ import modules.ui as ui
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import clear_torch_cache, generate_reply
+from modules.text_generation import (clear_torch_cache, generate_reply,
+                                     stop_everything_event)
 
 # Loading custom settings
 settings_file = None
@@ -366,7 +367,7 @@ def create_interface():
             gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
 
             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@@ -437,7 +438,7 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
@@ -471,7 +472,7 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         with gr.Tab("Interface mode", elem_id="interface-mode"):