From 98dcfb8e1277cf6f262822bf9a32d5597cddffe1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 11 Apr 2023 11:24:55 -0300 Subject: [PATCH] Minor fixes --- modules/chat.py | 6 +++--- modules/text_generation.py | 2 ++ server.py | 19 ++++++++++--------- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 7a8e911..02b1847 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -168,7 +168,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): shared.history['visible'].append(['', '']) if _continue: - sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply)) + sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply)) shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}'] shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}'] else: @@ -278,7 +278,7 @@ def clear_chat_log(name1, name2, greeting, mode): if greeting != '': shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - + # Save cleared logs save_history(mode) @@ -446,7 +446,7 @@ def load_character(character, name1, name2, mode): if greeting != "": shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - + # Create .json log files since they don't already exist save_history(mode) diff --git a/modules/text_generation.py b/modules/text_generation.py index c3daafa..e8f283a 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids): # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens return inputs_embeds, filler_input_ids + # Removes empty replies from gpt4chan outputs def fix_gpt4chan(s): for i in range(10): @@ -77,6 +78,7 @@ def fix_gpt4chan(s): s = re.sub("--- [0-9]*\n\n\n---", "---", s) return s + # Fix the LaTeX equations in galactica def fix_galactica(s): s = s.replace(r'\[', r'$') diff --git a/server.py b/server.py index 232eba3..d88a865 100644 --- a/server.py +++ b/server.py @@ -184,22 +184,22 @@ def download_model_wrapper(repo_id): branch = "main" check = False - yield("Cleaning up the model/branch names") + yield ("Cleaning up the model/branch names") model, branch = downloader.sanitize_model_and_branch_names(model, branch) - yield("Getting the download links from Hugging Face") + yield ("Getting the download links from Hugging Face") links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False) - yield("Getting the output folder") + yield ("Getting the output folder") output_folder = downloader.get_output_folder(model, branch, is_lora) if check: - yield("Checking previously downloaded files") + yield ("Checking previously downloaded files") downloader.check_model_files(model, branch, links, sha256, output_folder) else: - yield(f"Downloading files to {output_folder}") + yield (f"Downloading files to {output_folder}") downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1) - yield("Done!") + yield ("Done!") except: yield traceback.format_exc() @@ -377,11 +377,12 @@ def create_interface(): extensions_module.load_extensions() with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']: - shared.input_elements = list_interface_input_elements(chat=True) - shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) - if shared.is_chat(): + + shared.input_elements = list_interface_input_elements(chat=True) + shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements}) shared.gradio['Chat input'] = gr.State() + with gr.Tab("Text generation", elem_id="main"): shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['textbox'] = gr.Textbox(label='Input')