Sfoglia il codice sorgente

Merge branch 'main' into mcmonkey4eva-add-train-lora-tab

oobabooga 2 anni fa
parent
commit
c2cad30772
6 ha cambiato i file con 25 aggiunte e 22 eliminazioni
  1. 1 0
      .gitignore
  2. 5 6
      css/main.css
  3. 1 1
      modules/callbacks.py
  4. 0 4
      modules/chat.py
  5. 4 2
      modules/text_generation.py
  6. 14 9
      server.py

+ 1 - 0
.gitignore

@@ -19,3 +19,4 @@ repositories
 settings.json
 img_bot*
 img_me*
+prompts/[0-9]*

+ 5 - 6
css/main.css

@@ -37,12 +37,6 @@
   text-decoration: none !important;
 }
 
-svg {
-  display: unset !important;
-  vertical-align: middle !important;
-  margin: 5px;
-}
-
 ol li p, ul li p {
     display: inline-block;
 }
@@ -64,3 +58,8 @@ ol li p, ul li p {
   padding: 15px;
   padding: 15px;
 }
+
+span.math.inline {
+  font-size: 27px;
+  vertical-align: baseline !important;
+}

+ 1 - 1
modules/callbacks.py

@@ -54,7 +54,7 @@ class Iteratorize:
         self.stop_now = False
 
         def _callback(val):
-            if self.stop_now:
+            if self.stop_now or shared.stop_everything:
                 raise ValueError
             self.q.put(val)
 

+ 0 - 4
modules/chat.py

@@ -80,11 +80,7 @@ def extract_message_from_reply(reply, name1, name2, check):
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def stop_everything_event():
-    shared.stop_everything = True
-
 def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
-    shared.stop_everything = False
     just_started = True
     eos_token = '\n' if check else None
     name1_original = name1

+ 4 - 2
modules/text_generation.py

@@ -99,9 +99,13 @@ def set_manual_seed(seed):
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(seed)
 
+def stop_everything_event():
+    shared.stop_everything = True
+
 def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     set_manual_seed(seed)
+    shared.stop_everything = False
     t0 = time.time()
 
     original_question = question
@@ -236,8 +240,6 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
                         break
                     yield formatted_outputs(reply, shared.model_name)
 
-                yield formatted_outputs(reply, shared.model_name)
-
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
             for i in range(max_new_tokens//8+1):

+ 14 - 9
server.py

@@ -14,7 +14,8 @@ import modules.extensions as extensions_module
 from modules.html_generator import generate_chat_html
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
-from modules.text_generation import clear_torch_cache, generate_reply
+from modules.text_generation import (clear_torch_cache, generate_reply,
+                                     stop_everything_event)
 
 # Loading custom settings
 settings_file = None
@@ -133,7 +134,7 @@ def save_prompt(text):
     fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
     with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
         f.write(text)
-    return f"Saved prompt to prompts/{fname}"
+    return f"Saved to prompts/{fname}"
 
 def load_prompt(fname):
     if fname in ['None', '']:
@@ -154,7 +155,7 @@ def create_prompt_menus():
                 shared.gradio['save_prompt'] = gr.Button('Save prompt')
                 shared.gradio['status'] = gr.Markdown('Ready')
 
-    shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
+    shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
 def create_settings_menus(default_preset):
@@ -364,7 +365,7 @@ def create_interface():
             gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
 
             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
@@ -415,11 +416,15 @@ def create_interface():
                             shared.gradio['html'] = gr.HTML()
 
                         with gr.Row():
-                            shared.gradio['Generate'] = gr.Button('Generate')
-                            shared.gradio['Stop'] = gr.Button('Stop')
+                            with gr.Column():
+                                with gr.Row():
+                                    shared.gradio['Generate'] = gr.Button('Generate')
+                                    shared.gradio['Stop'] = gr.Button('Stop')
+                            with gr.Column():
+                                pass
 
                     with gr.Column(scale=1):
-                        gr.Markdown("\n")
+                        gr.HTML('<div style="padding-bottom: 13px"></div>')
                         shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
 
                         create_prompt_menus()
@@ -431,7 +436,7 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
@@ -465,7 +470,7 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
+            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         with gr.Tab("Training", elem_id="training-tab"):