Reformat everything

This commit is contained in:
oobabooga
2023-04-07 00:08:46 -03:00
parent 848c4edfd5
commit 01cacfc14f
29 changed files with 334 additions and 193 deletions

View File

@@ -17,6 +17,7 @@ def random_hash():
letters = string.ascii_lowercase + string.digits letters = string.ascii_lowercase + string.digits
return ''.join(random.choice(letters) for i in range(9)) return ''.join(random.choice(letters) for i in range(9))
async def run(context): async def run(context):
server = "127.0.0.1" server = "127.0.0.1"
params = { params = {
@@ -69,6 +70,7 @@ async def run(context):
prompt = "What I would like to say is the following: " prompt = "What I would like to say is the following: "
async def get_result(): async def get_result():
async for response in run(prompt): async for response in run(prompt):
# Print intermediate steps # Print intermediate steps

View File

@@ -17,6 +17,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.") parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args() args = parser.parse_args()
def disable_torch_init(): def disable_torch_init():
""" """
Disable the redundant torch default initialization to accelerate model creation. Disable the redundant torch default initialization to accelerate model creation.
@@ -31,12 +32,14 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init(): def restore_torch_init():
"""Rollback the change made by disable_torch_init.""" """Rollback the change made by disable_torch_init."""
import torch import torch
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup) setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup) setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
if __name__ == '__main__': if __name__ == '__main__':
path = Path(args.MODEL) path = Path(args.MODEL)
model_name = path.name model_name = path.name

View File

@@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.') parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
args = parser.parse_args() args = parser.parse_args()
def get_file(url, output_folder): def get_file(url, output_folder):
filename = Path(url.rsplit('/', 1)[1]) filename = Path(url.rsplit('/', 1)[1])
output_path = output_folder / filename output_path = output_folder / filename
@@ -54,6 +55,7 @@ def get_file(url, output_folder):
t.update(len(data)) t.update(len(data))
f.write(data) f.write(data)
def sanitize_branch_name(branch_name): def sanitize_branch_name(branch_name):
pattern = re.compile(r"^[a-zA-Z0-9._-]+$") pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
if pattern.match(branch_name): if pattern.match(branch_name):
@@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
else: else:
raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
def select_model_from_default_options(): def select_model_from_default_options():
models = { models = {
"OPT 6.7B": ("facebook", "opt-6.7b", "main"), "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped
return model, branch return model, branch
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
@@ -172,9 +176,11 @@ def get_download_links_from_huggingface(model, branch):
return links, sha256, is_lora return links, sha256, is_lora
def download_files(file_list, output_folder, num_threads=8): def download_files(file_list, output_folder, num_threads=8):
thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
if __name__ == '__main__': if __name__ == '__main__':
model = args.MODEL model = args.MODEL
branch = args.branch branch = args.branch

View File

@@ -9,6 +9,7 @@ params = {
'port': 5000, 'port': 5000,
} }
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def do_GET(self): def do_GET(self):
if self.path == '/api/v1/model': if self.path == '/api/v1/model':
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
prompt = body['prompt'] prompt = body['prompt']
prompt_lines = [l.strip() for l in prompt.split('\n')] prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048) max_context = body.get('max_context_length', 2048)
@@ -95,5 +96,6 @@ def run_server():
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever() server.serve_forever()
def setup(): def setup():
Thread(target=run_server, daemon=True).start() Thread(target=run_server, daemon=True).start()

View File

@@ -5,6 +5,7 @@ params = {
"bias string": " *I am so happy*", "bias string": " *I am so happy*",
} }
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -13,6 +14,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -20,6 +22,7 @@ def output_modifier(string):
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
behavior. behavior.
""" """
if params['activate'] == True: if params['activate']:
return f'{string} {params["bias string"].strip()} ' return f'{string} {params["bias string"].strip()} '
else: else:
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements
activate = gr.Checkbox(value=params['activate'], label='Activate character bias') activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

View File

@@ -22,6 +22,8 @@ if not shared.args.no_stream:
raise ValueError raise ValueError
# Check if the API is valid and refresh the UI accordingly. # Check if the API is valid and refresh the UI accordingly.
def check_valid_api(): def check_valid_api():
global user, user_info, params global user, user_info, params
@@ -29,7 +31,7 @@ def check_valid_api():
user = ElevenLabsUser(params['api_key']) user = ElevenLabsUser(params['api_key'])
user_info = user._get_subscription_data() user_info = user._get_subscription_data()
print('checking api') print('checking api')
if params['activate'] == False: if not params['activate']:
return gr.update(value='Disconnected') return gr.update(value='Disconnected')
elif user_info is None: elif user_info is None:
print('Incorrect API Key') print('Incorrect API Key')
@@ -39,6 +41,8 @@ def check_valid_api():
return gr.update(value='Connected') return gr.update(value='Connected')
# Once the API is verified, get the available voices and update the dropdown list # Once the API is verified, get the available voices and update the dropdown list
def refresh_voices(): def refresh_voices():
global user, user_info global user, user_info
@@ -51,11 +55,13 @@ def refresh_voices():
else: else:
return return
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -64,6 +70,7 @@ def input_modifier(string):
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -71,9 +78,9 @@ def output_modifier(string):
global params, wav_idx, user, user_info global params, wav_idx, user, user_info
if params['activate'] == False: if not params['activate']:
return string return string
elif user_info == None: elif user_info is None:
return string return string
string = remove_surrounded_chars(string) string = remove_surrounded_chars(string)
@@ -94,6 +101,7 @@ def output_modifier(string):
wav_idx += 1 wav_idx += 1
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements

View File

@@ -7,6 +7,7 @@ params = {
language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'} language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -15,6 +16,7 @@ def input_modifier(string):
return GoogleTranslator(source=params['language string'], target='en').translate(string) return GoogleTranslator(source=params['language string'], target='en').translate(string)
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -22,6 +24,7 @@ def output_modifier(string):
return GoogleTranslator(source='en', target=params['language string']).translate(string) return GoogleTranslator(source='en', target=params['language string']).translate(string)
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
return string return string
def ui(): def ui():
# Finding the language name from the language code to use as the default value # Finding the language name from the language code to use as the default value
language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])] language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]

View File

@@ -4,12 +4,14 @@ import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
def get_prompt_by_name(name): def get_prompt_by_name(name):
if name == 'None': if name == 'None':
return '' return ''
else: else:
return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
def ui(): def ui():
if not shared.is_chat(): if not shared.is_chat():
choices = ['None'] + list(df['Prompt name']) choices = ['None'] + list(df['Prompt name'])

View File

@@ -30,12 +30,15 @@ streaming_state = shared.args.no_stream # remember if chat streaming was enabled
picture_response = False # specifies if the next model response should appear as a picture picture_response = False # specifies if the next model response should appear as a picture
pic_id = 0 pic_id = 0
def remove_surrounded_chars(string): def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)', '', string) return re.sub('\*[^\*]*?(\*|$)', '', string)
# I don't even need input_hijack for this as visible text will be commited to history as the unmodified string # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
def input_modifier(string): def input_modifier(string):
""" """
This function is applied to your text inputs before This function is applied to your text inputs before
@@ -62,6 +65,8 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description): def get_SD_pictures(description):
global params, pic_id global params, pic_id
@@ -101,6 +106,8 @@ def get_SD_pictures(description):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -130,6 +137,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state shared.args.no_stream = streaming_state
return image + "\n" + text return image + "\n" + text
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
return string return string
def force_pic(): def force_pic():
global picture_response global picture_response
picture_response = True picture_response = True
def ui(): def ui():
# Gradio elements # Gradio elements

View File

@@ -17,11 +17,13 @@ input_hijack = {
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu") model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
def caption_image(raw_image): def caption_image(raw_image):
inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32) inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
out = model.generate(**inputs, max_new_tokens=100) out = model.generate(**inputs, max_new_tokens=100)
return processor.decode(out[0], skip_special_tokens=True) return processor.decode(out[0], skip_special_tokens=True)
def generate_chat_picture(picture, name1, name2): def generate_chat_picture(picture, name1, name2):
text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*' text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
# lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">' visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
return text, visible_text return text, visible_text
def ui(): def ui():
picture_select = gr.Image(label='Send a picture', type='pil') picture_select = gr.Image(label='Send a picture', type='pil')

View File

@@ -1,14 +1,16 @@
import re
import time import time
from pathlib import Path from pathlib import Path
import gradio as gr import gradio as gr
import modules.chat as chat
import modules.shared as shared
import torch import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper
torch._C._jit_set_profiling_mode(False) torch._C._jit_set_profiling_mode(False)
params = { params = {
'activate': True, 'activate': True,
'speaker': 'en_56', 'speaker': 'en_56',
@@ -37,26 +39,25 @@ table = str.maketrans({
'"': "&quot;", '"': "&quot;",
}) })
def xmlesc(txt): def xmlesc(txt):
return txt.translate(table) return txt.translate(table)
def load_model(): def load_model():
model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id']) model, example_text = torch.hub.load(repo_or_dir='snakers4/silero-models', model='silero_tts', language=params['language'], speaker=params['model_id'])
model.to(params['device']) model.to(params['device'])
return model return model
model = load_model() model = load_model()
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
return re.sub('\*[^\*]*?(\*|$)','',string)
def remove_tts_from_history(name1, name2): def remove_tts_from_history(name1, name2, mode):
for i, entry in enumerate(shared.history['internal']): for i, entry in enumerate(shared.history['internal']):
shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]] shared.history['visible'][i] = [shared.history['visible'][i][0], entry[1]]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def toggle_text_in_history(name1, name2):
def toggle_text_in_history(name1, name2, mode):
for i, entry in enumerate(shared.history['visible']): for i, entry in enumerate(shared.history['visible']):
visible_reply = entry[1] visible_reply = entry[1]
if visible_reply.startswith('<audio'): if visible_reply.startswith('<audio'):
@@ -65,7 +66,8 @@ def toggle_text_in_history(name1, name2):
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"] shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>\n\n{reply}"]
else: else:
shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"] shared.history['visible'][i] = [shared.history['visible'][i][0], f"{visible_reply.split('</audio>')[0]}</audio>"]
return chat.generate_chat_output(shared.history['visible'], name1, name2, shared.character) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def input_modifier(string): def input_modifier(string):
""" """
@@ -81,6 +83,7 @@ def input_modifier(string):
shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated shared.args.no_stream = True # Disable streaming cause otherwise the audio output will stutter and begin anew every time the message is being updated
return string return string
def output_modifier(string): def output_modifier(string):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.
@@ -98,11 +101,7 @@ def output_modifier(string):
return string return string
original_string = string original_string = string
string = remove_surrounded_chars(string) string = tts_preprocessor.preprocess(string)
string = string.replace('"', '')
string = string.replace('', '')
string = string.replace('\n', ' ')
string = string.strip()
if string == '': if string == '':
string = '*Empty reply, try regenerating*' string = '*Empty reply, try regenerating*'
@@ -121,6 +120,7 @@ def output_modifier(string):
shared.args.no_stream = streaming_state # restore the streaming option to the previous value shared.args.no_stream = streaming_state # restore the streaming option to the previous value
return string return string
def bot_prefix_modifier(string): def bot_prefix_modifier(string):
""" """
This function is only applied in chat mode. It modifies This function is only applied in chat mode. It modifies
@@ -130,33 +130,37 @@ def bot_prefix_modifier(string):
return string return string
def ui(): def ui():
# Gradio elements # Gradio elements
with gr.Accordion("Silero TTS"): with gr.Accordion("Silero TTS"):
with gr.Row(): with gr.Row():
activate = gr.Checkbox(value=params['activate'], label='Activate TTS') activate = gr.Checkbox(value=params['activate'], label='Activate TTS')
autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically') autoplay = gr.Checkbox(value=params['autoplay'], label='Play TTS automatically')
show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player') show_text = gr.Checkbox(value=params['show_text'], label='Show message text under audio player')
voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice') voice = gr.Dropdown(value=params['speaker'], choices=voices_by_gender, label='TTS voice')
with gr.Row(): with gr.Row():
v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch') v_pitch = gr.Dropdown(value=params['voice_pitch'], choices=voice_pitches, label='Voice pitch')
v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed') v_speed = gr.Dropdown(value=params['voice_speed'], choices=voice_speeds, label='Voice speed')
with gr.Row(): with gr.Row():
convert = gr.Button('Permanently replace audios with the message texts') convert = gr.Button('Permanently replace audios with the message texts')
convert_cancel = gr.Button('Cancel', visible=False) convert_cancel = gr.Button('Cancel', visible=False)
convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False) convert_confirm = gr.Button('Confirm (cannot be undone)', variant="stop", visible=False)
# Convert history with confirmation # Convert history with confirmation
convert_arr = [convert_confirm, convert, convert_cancel] convert_arr = [convert_confirm, convert, convert_cancel]
convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr) convert.click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, convert_arr)
convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_confirm.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
convert_confirm.click(remove_tts_from_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) convert_confirm.click(remove_tts_from_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) convert_confirm.click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr) convert_cancel.click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, convert_arr)
# Toggle message text in history # Toggle message text in history
show_text.change(lambda x: params.update({"show_text": x}), show_text, None) show_text.change(lambda x: params.update({"show_text": x}), show_text, None)
show_text.change(toggle_text_in_history, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display']) show_text.change(toggle_text_in_history, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], shared.gradio['display'])
show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False) show_text.change(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
# Event functions to update the parameters in the backend # Event functions to update the parameters in the backend

View File

@@ -17,9 +17,11 @@ from quant import make_quant
def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128): def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
config = AutoConfig.from_pretrained(model)
def noop(*args, **kwargs): def noop(*args, **kwargs):
pass pass
config = AutoConfig.from_pretrained(model)
torch.nn.init.kaiming_uniform_ = noop torch.nn.init.kaiming_uniform_ = noop
torch.nn.init.uniform_ = noop torch.nn.init.uniform_ = noop
torch.nn.init.normal_ = noop torch.nn.init.normal_ = noop
@@ -64,6 +66,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
return model return model
def load_quantized(model_name): def load_quantized(model_name):
if not shared.args.model_type: if not shared.args.model_type:
# Try to determine model type from model name # Try to determine model type from model name

View File

@@ -13,6 +13,7 @@ def reload_model():
clear_torch_cache() clear_torch_cache()
shared.model, shared.tokenizer = load_model(shared.model_name) shared.model, shared.tokenizer = load_model(shared.model_name)
def add_lora_to_model(lora_name): def add_lora_to_model(lora_name):
# If a LoRA had been previously loaded, or if we want # If a LoRA had been previously loaded, or if we want

View File

@@ -54,6 +54,7 @@ class RWKVModel:
reply += token reply += token
yield reply yield reply
class RWKVTokenizer: class RWKVTokenizer:
def __init__(self): def __init__(self):
pass pass

View File

@@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
for i in generate_reply(params[0], generate_params): for i in generate_reply(params[0], generate_params):
yield i yield i
def create_apis(): def create_apis():
t1 = gr.Textbox(visible=False) t1 = gr.Textbox(visible=False)
t2 = gr.Textbox(visible=False) t2 = gr.Textbox(visible=False)

View File

@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
return True return True
return False return False
class Stream(transformers.StoppingCriteria): class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None): def __init__(self, callback_func=None):
self.callback_func = callback_func self.callback_func = callback_func
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
self.callback_func(input_ids[0]) self.callback_func(input_ids[0])
return False return False
class Iteratorize: class Iteratorize:
""" """
@@ -96,6 +98,7 @@ class Iteratorize:
self.stop_now = True self.stop_now = True
clear_torch_cache() clear_torch_cache()
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:

View File

@@ -23,7 +23,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{context.strip()}\n"] rows = [f"{context.strip()}\n"]
# Finding the maximum prompt size # Finding the maximum prompt size
@@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
else: else:
return prompt return prompt
def extract_message_from_reply(reply, name1, name2, stop_at_newline): def extract_message_from_reply(reply, name1, name2, stop_at_newline):
next_character_found = False next_character_found = False
@@ -98,6 +98,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
reply = fix_newlines(reply) reply = fix_newlines(reply)
return reply, next_character_found return reply, next_character_found
def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
if mode == 'instruct': if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{name1}", f"\n{name2}"]
@@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
visible_text = None visible_text = None
custom_generate_chat_prompt = None custom_generate_chat_prompt = None
for extension, _ in extensions_module.iterator(): for extension, _ in extensions_module.iterator():
if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True: if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
extension.input_hijack['state'] = False extension.input_hijack['state'] = False
text, visible_text = extension.input_hijack['value'] text, visible_text = extension.input_hijack['value']
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
@@ -167,6 +168,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
yield shared.history['visible'] yield shared.history['visible']
def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if mode == 'instruct': if mode == 'instruct':
stopping_strings = [f"\n{name1}", f"\n{name2}"] stopping_strings = [f"\n{name1}", f"\n{name2}"]
@@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
yield reply yield reply
def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
yield chat_html_wrapper(history, name1, name2, mode) yield chat_html_wrapper(history, name1, name2, mode)
def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -213,6 +217,7 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
shared.history['visible'][-1] = [last_visible[0], history[-1][1]] shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def remove_last_message(name1, name2, mode): def remove_last_message(name1, name2, mode):
if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
last = shared.history['visible'].pop() last = shared.history['visible'].pop()
@@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0] return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
def send_last_reply_to_input(): def send_last_reply_to_input():
if len(shared.history['internal']) > 0: if len(shared.history['internal']) > 0:
return shared.history['internal'][-1][1] return shared.history['internal'][-1][1]
else: else:
return '' return ''
def replace_last_reply(text, name1, name2, mode): def replace_last_reply(text, name1, name2, mode):
if len(shared.history['visible']) > 0: if len(shared.history['visible']) > 0:
shared.history['visible'][-1][1] = text shared.history['visible'][-1][1] = text
@@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def clear_html(): def clear_html():
return chat_html_wrapper([], "", "") return chat_html_wrapper([], "", "")
def clear_chat_log(name1, name2, greeting, mode): def clear_chat_log(name1, name2, greeting, mode):
shared.history['visible'] = [] shared.history['visible'] = []
shared.history['internal'] = [] shared.history['internal'] = []
@@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def redraw_html(name1, name2, mode): def redraw_html(name1, name2, mode):
return chat_html_wrapper(shared.history['visible'], name1, name2, mode) return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
def tokenize_dialogue(dialogue, name1, name2, mode): def tokenize_dialogue(dialogue, name1, name2, mode):
history = [] history = []
@@ -288,6 +299,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
return history return history
def save_history(timestamp=True): def save_history(timestamp=True):
if timestamp: if timestamp:
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
@@ -299,6 +311,7 @@ def save_history(timestamp=True):
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
return Path(f'logs/{fname}') return Path(f'logs/{fname}')
def load_history(file, name1, name2): def load_history(file, name1, name2):
file = file.decode('utf-8') file = file.decode('utf-8')
try: try:
@@ -323,10 +336,12 @@ def load_history(file, name1, name2):
shared.history['internal'] = tokenize_dialogue(file, name1, name2) shared.history['internal'] = tokenize_dialogue(file, name1, name2)
shared.history['visible'] = copy.deepcopy(shared.history['internal']) shared.history['visible'] = copy.deepcopy(shared.history['internal'])
def replace_character_names(text, name1, name2): def replace_character_names(text, name1, name2):
text = text.replace('{{user}}', name1).replace('{{char}}', name2) text = text.replace('{{user}}', name1).replace('{{char}}', name2)
return text.replace('<USER>', name1).replace('<BOT>', name2) return text.replace('<USER>', name1).replace('<BOT>', name2)
def build_pygmalion_style_context(data): def build_pygmalion_style_context(data):
context = "" context = ""
if 'char_persona' in data and data['char_persona'] != '': if 'char_persona' in data and data['char_persona'] != '':
@@ -336,6 +351,7 @@ def build_pygmalion_style_context(data):
context = f"{context.strip()}\n<START>\n" context = f"{context.strip()}\n<START>\n"
return context return context
def generate_pfp_cache(character): def generate_pfp_cache(character):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@@ -348,6 +364,7 @@ def generate_pfp_cache(character):
return img return img
return None return None
def load_character(character, name1, name2, mode): def load_character(character, name1, name2, mode):
shared.character = character shared.character = character
shared.history['internal'] = [] shared.history['internal'] = []
@@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode):
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
def load_default_history(name1, name2): def load_default_history(name1, name2):
load_character("None", name1, name2, "chat") load_character("None", name1, name2, "chat")
def upload_character(json_file, img, tavern=False): def upload_character(json_file, img, tavern=False):
json_file = json_file if type(json_file) == str else json_file.decode('utf-8') json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
data = json.loads(json_file) data = json.loads(json_file)
@@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
print(f'New character saved to "characters/{outfile_name}.json".') print(f'New character saved to "characters/{outfile_name}.json".')
return outfile_name return outfile_name
def upload_tavern_character(img, name1, name2): def upload_tavern_character(img, name1, name2):
_img = Image.open(io.BytesIO(img)) _img = Image.open(io.BytesIO(img))
_img.getexif() _img.getexif()
@@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
_json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
return upload_character(json.dumps(_json), img, tavern=True) return upload_character(json.dumps(_json), img, tavern=True)
def upload_your_profile_picture(img, name1, name2, mode): def upload_your_profile_picture(img, name1, name2, mode):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
cache_folder.mkdir() cache_folder.mkdir()
if img == None: if img is None:
if Path("cache/pfp_me.png").exists(): if Path("cache/pfp_me.png").exists():
Path("cache/pfp_me.png").unlink() Path("cache/pfp_me.png").unlink()
else: else:

View File

@@ -9,6 +9,7 @@ state = {}
available_extensions = [] available_extensions = []
setup_called = set() setup_called = set()
def load_extensions(): def load_extensions():
global state global state
for i, name in enumerate(shared.args.extensions): for i, name in enumerate(shared.args.extensions):
@@ -23,12 +24,16 @@ def load_extensions():
traceback.print_exc() traceback.print_exc()
# This iterator returns the extensions in the order specified in the command-line # This iterator returns the extensions in the order specified in the command-line
def iterator(): def iterator():
for name in sorted(state, key=lambda x: state[x][1]): for name in sorted(state, key=lambda x: state[x][1]):
if state[name][0] == True: if state[name][0] == True:
yield eval(f"extensions.{name}.script"), name yield eval(f"extensions.{name}.script"), name
# Extension functions that map string -> string # Extension functions that map string -> string
def apply_extensions(text, typ): def apply_extensions(text, typ):
for extension, _ in iterator(): for extension, _ in iterator():
if typ == "input" and hasattr(extension, "input_modifier"): if typ == "input" and hasattr(extension, "input_modifier"):
@@ -39,6 +44,7 @@ def apply_extensions(text, typ):
text = extension.bot_prefix_modifier(text) text = extension.bot_prefix_modifier(text)
return text return text
def create_extensions_block(): def create_extensions_block():
global setup_called global setup_called

View File

@@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f: with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
instruct_css = f.read() instruct_css = f.read()
def fix_newlines(string): def fix_newlines(string):
string = string.replace('\n', '\n\n') string = string.replace('\n', '\n\n')
string = re.sub(r"\n{3,}", "\n\n", string) string = re.sub(r"\n{3,}", "\n\n", string)
@@ -31,6 +32,8 @@ def fix_newlines(string):
return string return string
# This could probably be generalized and improved # This could probably be generalized and improved
def convert_to_markdown(string): def convert_to_markdown(string):
string = string.replace('\\begin{code}', '```') string = string.replace('\\begin{code}', '```')
string = string.replace('\\end{code}', '```') string = string.replace('\\end{code}', '```')
@@ -40,11 +43,13 @@ def convert_to_markdown(string):
string = fix_newlines(string) string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code']) return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string): def generate_basic_html(string):
string = convert_to_markdown(string) string = convert_to_markdown(string)
string = f'<style>{readable_css}</style><div class="container">{string}</div>' string = f'<style>{readable_css}</style><div class="container">{string}</div>'
return string return string
def process_post(post, c): def process_post(post, c):
t = post.split('\n') t = post.split('\n')
number = t[0].split(' ')[1] number = t[0].split(' ')[1]
@@ -59,6 +64,7 @@ def process_post(post, c):
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}' src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
return src return src
def generate_4chan_html(f): def generate_4chan_html(f):
posts = [] posts = []
post = '' post = ''
@@ -98,6 +104,7 @@ def generate_4chan_html(f):
return output return output
def make_thumbnail(image): def make_thumbnail(image):
image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS) image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
if image.size[1] > 470: if image.size[1] > 470:
@@ -105,6 +112,7 @@ def make_thumbnail(image):
return image return image
def get_image_cache(path): def get_image_cache(path):
cache_folder = Path("cache") cache_folder = Path("cache")
if not cache_folder.exists(): if not cache_folder.exists():
@@ -119,6 +127,7 @@ def get_image_cache(path):
return image_cache[path][1] return image_cache[path][1]
def generate_instruct_html(history): def generate_instruct_html(history):
output = f'<style>{instruct_css}</style><div class="chat" id="chat">' output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
for i, _row in enumerate(history[::-1]): for i, _row in enumerate(history[::-1]):
@@ -151,6 +160,7 @@ def generate_instruct_html(history):
return output return output
def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output = f'<style>{cai_css}</style><div class="chat" id="chat">' output = f'<style>{cai_css}</style><div class="chat" id="chat">'
@@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
output += "</div>" output += "</div>"
return output return output
def generate_chat_html(history, name1, name2): def generate_chat_html(history, name1, name2):
return generate_cai_chat_html(history, name1, name2) return generate_cai_chat_html(history, name1, name2)
def chat_html_wrapper(history, name1, name2, mode, reset_cache=False): def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
if mode == "cai-chat": if mode == "cai-chat":
return generate_cai_chat_html(history, name1, name2, reset_cache) return generate_cai_chat_html(history, name1, name2, reset_cache)

View File

@@ -6,8 +6,6 @@ Documentation:
https://abetlen.github.io/llama-cpp-python/ https://abetlen.github.io/llama-cpp-python/
''' '''
import multiprocessing
from llama_cpp import Llama from llama_cpp import Llama
from modules import shared from modules import shared

View File

@@ -181,6 +181,7 @@ def load_model(model_name):
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer return model, tokenizer
def load_soft_prompt(name): def load_soft_prompt(name):
if name == 'None': if name == 'None':
shared.soft_prompt = False shared.soft_prompt = False

View File

@@ -61,6 +61,7 @@ settings = {
} }
} }
def str2bool(v): def str2bool(v):
if isinstance(v, bool): if isinstance(v, bool):
return v return v
@@ -71,6 +72,7 @@ def str2bool(v):
else: else:
raise argparse.ArgumentTypeError('Boolean value expected.') raise argparse.ArgumentTypeError('Boolean value expected.')
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54)) parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
# Basic settings # Basic settings
@@ -145,5 +147,6 @@ if args.cai_chat:
print("Warning: --cai-chat is deprecated. Use --chat instead.") print("Warning: --cai-chat is deprecated. Use --chat instead.")
args.chat = True args.chat = True
def is_chat(): def is_chat():
return args.chat return args.chat

View File

@@ -21,6 +21,7 @@ def get_max_prompt_length(tokens):
max_length -= shared.soft_prompt_tensor.shape[1] max_length -= shared.soft_prompt_tensor.shape[1]
return max_length return max_length
def encode(prompt, tokens_to_generate=0, add_special_tokens=True): def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
if any((shared.is_RWKV, shared.is_llamacpp)): if any((shared.is_RWKV, shared.is_llamacpp)):
input_ids = shared.tokenizer.encode(str(prompt)) input_ids = shared.tokenizer.encode(str(prompt))
@@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
else: else:
return input_ids.cuda() return input_ids.cuda()
def decode(output_ids): def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|> # Open Assistant relies on special tokens like <|endoftext|>
if re.match('.*(oasst|galactica)-*', shared.model_name.lower()): if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@@ -53,6 +55,7 @@ def decode(output_ids):
reply = reply.replace(r'<|endoftext|>', '') reply = reply.replace(r'<|endoftext|>', '')
return reply return reply
def generate_softprompt_input_tensors(input_ids): def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids) inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1) inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
@@ -61,6 +64,8 @@ def generate_softprompt_input_tensors(input_ids):
return inputs_embeds, filler_input_ids return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs # Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s): def fix_gpt4chan(s):
for i in range(10): for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s) s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -69,6 +74,8 @@ def fix_gpt4chan(s):
return s return s
# Fix the LaTeX equations in galactica # Fix the LaTeX equations in galactica
def fix_galactica(s): def fix_galactica(s):
s = s.replace(r'\[', r'$') s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$') s = s.replace(r'\]', r'$')
@@ -79,6 +86,7 @@ def fix_galactica(s):
s = re.sub(r"\n{3,}", "\n\n", s) s = re.sub(r"\n{3,}", "\n\n", s)
return s return s
def formatted_outputs(reply, model_name): def formatted_outputs(reply, model_name):
if not shared.is_chat(): if not shared.is_chat():
if 'galactica' in model_name.lower(): if 'galactica' in model_name.lower():
@@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
else: else:
return reply return reply
def clear_torch_cache(): def clear_torch_cache():
gc.collect() gc.collect()
if not shared.args.cpu: if not shared.args.cpu:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def set_manual_seed(seed): def set_manual_seed(seed):
if seed != -1: if seed != -1:
torch.manual_seed(seed) torch.manual_seed(seed)
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
def stop_everything_event(): def stop_everything_event():
shared.stop_everything = True shared.stop_everything = True
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
clear_torch_cache() clear_torch_cache()
set_manual_seed(generate_state['seed']) set_manual_seed(generate_state['seed'])

View File

@@ -19,9 +19,11 @@ CURRENT_STEPS = 0
MAX_STEPS = 0 MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1 CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str): def get_dataset(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower) return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def create_train_interface(): def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'): with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file") lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
@@ -67,10 +69,12 @@ def create_train_interface():
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output]) cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False) stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt(): def do_interrupt():
global WANT_INTERRUPT global WANT_INTERRUPT
WANT_INTERRUPT = True WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback): class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS global CURRENT_STEPS, MAX_STEPS
@@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
if WANT_INTERRUPT: if WANT_INTERRUPT:
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs): def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS global CURRENT_STEPS
CURRENT_STEPS += 1 CURRENT_STEPS += 1
@@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
def clean_path(base_path: str, path: str): def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory.""" """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path. # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
return path return path
return f'{Path(base_path).absolute()}/{path}' return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float, def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int): cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
print("Training complete!") print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`" yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step): def split_chunks(arr, step):
for i in range(0, len(arr), step): for i in range(0, len(arr), step):
yield arr[i:i + step] yield arr[i:i + step]
def cut_chunk_for_newline(chunk: str, max_length: int): def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk: if '\n' not in chunk:
return chunk return chunk
@@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
chunk = chunk[:last_newline] chunk = chunk[:last_newline]
return chunk return chunk
def format_time(seconds: float): def format_time(seconds: float):
if seconds < 120: if seconds < 120:
return f"`{seconds:.0f}` seconds" return f"`{seconds:.0f}` seconds"

View File

@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f: with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read() chat_js = f.read()
class ToolButton(gr.Button, gr.components.FormComponent): class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms""" """Small button with single emoji as text, fits inside gradio forms"""
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
def get_block_name(self): def get_block_name(self):
return "button" return "button"
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh(): def refresh():
refresh_method() refresh_method()

View File

@@ -34,15 +34,18 @@ if settings_file is not None:
for item in new_settings: for item in new_settings:
shared.settings[item] = new_settings[item] shared.settings[item] = new_settings[item]
def get_available_models(): def get_available_models():
if shared.args.flexgen: if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower) return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else: else:
return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets(): def get_available_presets():
return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower) return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
def get_available_prompts(): def get_available_prompts():
prompts = [] prompts = []
prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True) prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
@@ -50,10 +53,12 @@ def get_available_prompts():
prompts += ['None'] prompts += ['None']
return prompts return prompts
def get_available_characters(): def get_available_characters():
paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
def get_available_instruction_templates(): def get_available_instruction_templates():
path = "characters/instruction-following" path = "characters/instruction-following"
paths = [] paths = []
@@ -61,19 +66,24 @@ def get_available_instruction_templates():
paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
def get_available_extensions(): def get_available_extensions():
return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower) return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
def get_available_softprompts(): def get_available_softprompts():
return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower) return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras(): def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def unload_model(): def unload_model():
shared.model = shared.tokenizer = None shared.model = shared.tokenizer = None
clear_torch_cache() clear_torch_cache()
def load_model_wrapper(selected_model): def load_model_wrapper(selected_model):
if selected_model != shared.model_name: if selected_model != shared.model_name:
shared.model_name = selected_model shared.model_name = selected_model
@@ -84,10 +94,12 @@ def load_model_wrapper(selected_model):
return selected_model return selected_model
def load_lora_wrapper(selected_lora): def load_lora_wrapper(selected_lora):
add_lora_to_model(selected_lora) add_lora_to_model(selected_lora)
return selected_lora return selected_lora
def load_preset_values(preset_menu, state, return_dict=False): def load_preset_values(preset_menu, state, return_dict=False):
generate_params = { generate_params = {
'do_sample': True, 'do_sample': True,
@@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
state.update(generate_params) state.update(generate_params)
return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
def upload_soft_prompt(file): def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf: with zipfile.ZipFile(io.BytesIO(file)) as zf:
zf.extract('meta.json') zf.extract('meta.json')
@@ -130,12 +143,14 @@ def upload_soft_prompt(file):
return name return name
def save_prompt(text): def save_prompt(text):
fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
return f"Saved to prompts/{fname}" return f"Saved to prompts/{fname}"
def load_prompt(fname): def load_prompt(fname):
if fname in ['None', '']: if fname in ['None', '']:
return '' return ''
@@ -146,6 +161,7 @@ def load_prompt(fname):
text = text[:-1] text = text[:-1]
return text return text
def create_prompt_menus(): def create_prompt_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -161,6 +177,7 @@ def create_prompt_menus():
shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_model_menus(): def create_model_menus():
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
@@ -175,6 +192,7 @@ def create_model_menus():
shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
def create_settings_menus(default_preset): def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
@@ -209,7 +227,6 @@ def create_settings_menus(default_preset):
with gr.Box(): with gr.Box():
gr.Markdown('Contrastive search') gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha') shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
with gr.Box(): with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)') gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row(): with gr.Row():
@@ -219,7 +236,6 @@ def create_settings_menus(default_preset):
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Accordion('Soft prompt', open=False): with gr.Accordion('Soft prompt', open=False):
with gr.Row(): with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt') shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -233,6 +249,7 @@ def create_settings_menus(default_preset):
shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
def set_interface_arguments(interface_mode, extensions, bool_active): def set_interface_arguments(interface_mode, extensions, bool_active):
modes = ["default", "notebook", "chat", "cai_chat"] modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args) cmd_list = vars(shared.args)
@@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
shared.need_restart = True shared.need_restart = True
available_models = get_available_models() available_models = get_available_models()
available_presets = get_available_presets() available_presets = get_available_presets()
available_characters = get_available_characters() available_characters = get_available_characters()
@@ -299,8 +317,8 @@ else:
default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]) default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
title = 'Text generation web UI' title = 'Text generation web UI'
def create_interface():
def create_interface():
gen_events = [] gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0: if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions() extensions_module.load_extensions()
@@ -562,6 +580,7 @@ def create_interface():
else: else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
create_interface() create_interface()
while True: while True: