Minor fixes
This commit is contained in:
@@ -168,7 +168,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
|||||||
shared.history['visible'].append(['', ''])
|
shared.history['visible'].append(['', ''])
|
||||||
|
|
||||||
if _continue:
|
if _continue:
|
||||||
sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
|
sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
|
||||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
|
|||||||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||||
return inputs_embeds, filler_input_ids
|
return inputs_embeds, filler_input_ids
|
||||||
|
|
||||||
|
|
||||||
# Removes empty replies from gpt4chan outputs
|
# Removes empty replies from gpt4chan outputs
|
||||||
def fix_gpt4chan(s):
|
def fix_gpt4chan(s):
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
@@ -77,6 +78,7 @@ def fix_gpt4chan(s):
|
|||||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# Fix the LaTeX equations in galactica
|
# Fix the LaTeX equations in galactica
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
|
|||||||
17
server.py
17
server.py
@@ -184,22 +184,22 @@ def download_model_wrapper(repo_id):
|
|||||||
branch = "main"
|
branch = "main"
|
||||||
check = False
|
check = False
|
||||||
|
|
||||||
yield("Cleaning up the model/branch names")
|
yield ("Cleaning up the model/branch names")
|
||||||
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
model, branch = downloader.sanitize_model_and_branch_names(model, branch)
|
||||||
|
|
||||||
yield("Getting the download links from Hugging Face")
|
yield ("Getting the download links from Hugging Face")
|
||||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
yield("Getting the output folder")
|
yield ("Getting the output folder")
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
yield("Checking previously downloaded files")
|
yield ("Checking previously downloaded files")
|
||||||
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
downloader.check_model_files(model, branch, links, sha256, output_folder)
|
||||||
else:
|
else:
|
||||||
yield(f"Downloading files to {output_folder}")
|
yield (f"Downloading files to {output_folder}")
|
||||||
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
|
||||||
yield("Done!")
|
yield ("Done!")
|
||||||
except:
|
except:
|
||||||
yield traceback.format_exc()
|
yield traceback.format_exc()
|
||||||
|
|
||||||
@@ -377,11 +377,12 @@ def create_interface():
|
|||||||
extensions_module.load_extensions()
|
extensions_module.load_extensions()
|
||||||
|
|
||||||
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||||
|
if shared.is_chat():
|
||||||
|
|
||||||
shared.input_elements = list_interface_input_elements(chat=True)
|
shared.input_elements = list_interface_input_elements(chat=True)
|
||||||
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
shared.gradio['interface_state'] = gr.State({k: None for k in shared.input_elements})
|
||||||
|
|
||||||
if shared.is_chat():
|
|
||||||
shared.gradio['Chat input'] = gr.State()
|
shared.gradio['Chat input'] = gr.State()
|
||||||
|
|
||||||
with gr.Tab("Text generation", elem_id="main"):
|
with gr.Tab("Text generation", elem_id="main"):
|
||||||
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
|
||||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||||
|
|||||||
Reference in New Issue
Block a user