Ver código fonte

Make the code more like PEP8 for readability (#862)

oobabooga 3 anos atrás
pai
commit
ea6e77df72

+ 4 - 2
api-example-stream.py

@@ -17,6 +17,7 @@ def random_hash():
     letters = string.ascii_lowercase + string.digits
     letters = string.ascii_lowercase + string.digits
     return ''.join(random.choice(letters) for i in range(9))
     return ''.join(random.choice(letters) for i in range(9))
 
 
+
 async def run(context):
 async def run(context):
     server = "127.0.0.1"
     server = "127.0.0.1"
     params = {
     params = {
@@ -41,7 +42,7 @@ async def run(context):
 
 
     async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
     async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
         while content := json.loads(await websocket.recv()):
         while content := json.loads(await websocket.recv()):
-            #Python3.10 syntax, replace with if elif on older
+            # Python3.10 syntax, replace with if elif on older
             match content["msg"]:
             match content["msg"]:
                 case "send_hash":
                 case "send_hash":
                     await websocket.send(json.dumps({
                     await websocket.send(json.dumps({
@@ -62,13 +63,14 @@ async def run(context):
                     pass
                     pass
                 case "process_generating" | "process_completed":
                 case "process_generating" | "process_completed":
                     yield content["output"]["data"][0]
                     yield content["output"]["data"][0]
-                    # You can search for your desired end indicator and 
+                    # You can search for your desired end indicator and
                     #  stop generation by closing the websocket here
                     #  stop generation by closing the websocket here
                     if (content["msg"] == "process_completed"):
                     if (content["msg"] == "process_completed"):
                         break
                         break
 
 
 prompt = "What I would like to say is the following: "
 prompt = "What I would like to say is the following: "
 
 
+
 async def get_result():
 async def get_result():
     async for response in run(prompt):
     async for response in run(prompt):
         # Print intermediate steps
         # Print intermediate steps

+ 6 - 3
convert-to-flexgen.py

@@ -13,10 +13,11 @@ import torch
 from tqdm import tqdm
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
-parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 args = parser.parse_args()
 args = parser.parse_args()
 
 
+
 def disable_torch_init():
 def disable_torch_init():
     """
     """
     Disable the redundant torch default initialization to accelerate model creation.
     Disable the redundant torch default initialization to accelerate model creation.
@@ -31,20 +32,22 @@ def disable_torch_init():
     torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
     torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
 
 
+
 def restore_torch_init():
 def restore_torch_init():
     """Rollback the change made by disable_torch_init."""
     """Rollback the change made by disable_torch_init."""
     import torch
     import torch
     setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
     setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
     setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
     setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
 
 
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     path = Path(args.MODEL)
     path = Path(args.MODEL)
     model_name = path.name
     model_name = path.name
 
 
     print(f"Loading {model_name}...")
     print(f"Loading {model_name}...")
-    #disable_torch_init()
+    # disable_torch_init()
     model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
     model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
-    #restore_torch_init()
+    # restore_torch_init()
 
 
     tokenizer = AutoTokenizer.from_pretrained(path)
     tokenizer = AutoTokenizer.from_pretrained(path)
 
 

+ 1 - 1
convert-to-safetensors.py

@@ -17,7 +17,7 @@ from pathlib import Path
 import torch
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
-parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
 parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
 parser.add_argument('--output', type=str, default=None, help='Path to the output folder (default: models/{model_name}_safetensors).')
 parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")
 parser.add_argument("--max-shard-size", type=str, default="2GB", help="Maximum size of a shard in GB or MB (default: %(default)s).")

+ 11 - 5
download-model.py

@@ -29,6 +29,7 @@ parser.add_argument('--clean', action='store_true', help='Does not resume the pr
 parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
 parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
 args = parser.parse_args()
 args = parser.parse_args()
 
 
+
 def get_file(url, output_folder):
 def get_file(url, output_folder):
     filename = Path(url.rsplit('/', 1)[1])
     filename = Path(url.rsplit('/', 1)[1])
     output_path = output_folder / filename
     output_path = output_folder / filename
@@ -54,6 +55,7 @@ def get_file(url, output_folder):
                 t.update(len(data))
                 t.update(len(data))
                 f.write(data)
                 f.write(data)
 
 
+
 def sanitize_branch_name(branch_name):
 def sanitize_branch_name(branch_name):
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
     pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
     if pattern.match(branch_name):
     if pattern.match(branch_name):
@@ -61,6 +63,7 @@ def sanitize_branch_name(branch_name):
     else:
     else:
         raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
         raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
 
 
+
 def select_model_from_default_options():
 def select_model_from_default_options():
     models = {
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -78,11 +81,11 @@ def select_model_from_default_options():
     choices = {}
     choices = {}
 
 
     print("Select the model that you want to download:\n")
     print("Select the model that you want to download:\n")
-    for i,name in enumerate(models):
-        char = chr(ord('A')+i)
+    for i, name in enumerate(models):
+        char = chr(ord('A') + i)
         choices[char] = name
         choices[char] = name
         print(f"{char}) {name}")
         print(f"{char}) {name}")
-    char = chr(ord('A')+len(models))
+    char = chr(ord('A') + len(models))
     print(f"{char}) None of the above")
     print(f"{char}) None of the above")
 
 
     print()
     print()
@@ -106,6 +109,7 @@ EleutherAI/pythia-1.4b-deduped
 
 
     return model, branch
     return model, branch
 
 
+
 def get_download_links_from_huggingface(model, branch):
 def get_download_links_from_huggingface(model, branch):
     base = "https://huggingface.co"
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
     page = f"/api/models/{model}/tree/{branch}?cursor="
@@ -166,15 +170,17 @@ def get_download_links_from_huggingface(model, branch):
 
 
     # If both pytorch and safetensors are available, download safetensors only
     # If both pytorch and safetensors are available, download safetensors only
     if (has_pytorch or has_pt) and has_safetensors:
     if (has_pytorch or has_pt) and has_safetensors:
-        for i in range(len(classifications)-1, -1, -1):
+        for i in range(len(classifications) - 1, -1, -1):
             if classifications[i] in ['pytorch', 'pt']:
             if classifications[i] in ['pytorch', 'pt']:
                 links.pop(i)
                 links.pop(i)
 
 
     return links, sha256, is_lora
     return links, sha256, is_lora
 
 
+
 def download_files(file_list, output_folder, num_threads=8):
 def download_files(file_list, output_folder, num_threads=8):
     thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
     thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
 
 
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     model = args.MODEL
     model = args.MODEL
     branch = args.branch
     branch = args.branch
@@ -224,7 +230,7 @@ if __name__ == '__main__':
                     validated = False
                     validated = False
                 else:
                 else:
                     print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
                     print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
-        
+
         if validated:
         if validated:
             print('[+] Validated checksums of all model files!')
             print('[+] Validated checksums of all model files!')
         else:
         else:

+ 15 - 13
extensions/api/script.py

@@ -9,6 +9,7 @@ params = {
     'port': 5000,
     'port': 5000,
 }
 }
 
 
+
 class Handler(BaseHTTPRequestHandler):
 class Handler(BaseHTTPRequestHandler):
     def do_GET(self):
     def do_GET(self):
         if self.path == '/api/v1/model':
         if self.path == '/api/v1/model':
@@ -32,7 +33,7 @@ class Handler(BaseHTTPRequestHandler):
             self.end_headers()
             self.end_headers()
 
 
             prompt = body['prompt']
             prompt = body['prompt']
-            prompt_lines = [l.strip() for l in prompt.split('\n')]
+            prompt_lines = [k.strip() for k in prompt.split('\n')]
 
 
             max_context = body.get('max_context_length', 2048)
             max_context = body.get('max_context_length', 2048)
 
 
@@ -40,18 +41,18 @@ class Handler(BaseHTTPRequestHandler):
                 prompt_lines.pop(0)
                 prompt_lines.pop(0)
 
 
             prompt = '\n'.join(prompt_lines)
             prompt = '\n'.join(prompt_lines)
-            generate_params =  {
-                'max_new_tokens': int(body.get('max_length', 200)), 
+            generate_params = {
+                'max_new_tokens': int(body.get('max_length', 200)),
                 'do_sample': bool(body.get('do_sample', True)),
                 'do_sample': bool(body.get('do_sample', True)),
-                'temperature': float(body.get('temperature', 0.5)), 
-                'top_p': float(body.get('top_p', 1)), 
-                'typical_p': float(body.get('typical', 1)), 
-                'repetition_penalty': float(body.get('rep_pen', 1.1)), 
+                'temperature': float(body.get('temperature', 0.5)),
+                'top_p': float(body.get('top_p', 1)),
+                'typical_p': float(body.get('typical', 1)),
+                'repetition_penalty': float(body.get('rep_pen', 1.1)),
                 'encoder_repetition_penalty': 1,
                 'encoder_repetition_penalty': 1,
-                'top_k': int(body.get('top_k', 0)), 
+                'top_k': int(body.get('top_k', 0)),
                 'min_length': int(body.get('min_length', 0)),
                 'min_length': int(body.get('min_length', 0)),
-                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
-                'num_beams': int(body.get('num_beams',1)),
+                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
+                'num_beams': int(body.get('num_beams', 1)),
                 'penalty_alpha': float(body.get('penalty_alpha', 0)),
                 'penalty_alpha': float(body.get('penalty_alpha', 0)),
                 'length_penalty': float(body.get('length_penalty', 1)),
                 'length_penalty': float(body.get('length_penalty', 1)),
                 'early_stopping': bool(body.get('early_stopping', False)),
                 'early_stopping': bool(body.get('early_stopping', False)),
@@ -59,7 +60,7 @@ class Handler(BaseHTTPRequestHandler):
             }
             }
 
 
             generator = generate_reply(
             generator = generate_reply(
-                prompt, 
+                prompt,
                 generate_params,
                 generate_params,
                 stopping_strings=body.get('stopping_strings', []),
                 stopping_strings=body.get('stopping_strings', []),
             )
             )
@@ -84,9 +85,9 @@ class Handler(BaseHTTPRequestHandler):
 def run_server():
 def run_server():
     server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
     server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
     server = ThreadingHTTPServer(server_addr, Handler)
     server = ThreadingHTTPServer(server_addr, Handler)
-    if shared.args.share: 
+    if shared.args.share:
         try:
         try:
-            from flask_cloudflared import  _run_cloudflared
+            from flask_cloudflared import _run_cloudflared
             public_url = _run_cloudflared(params['port'], params['port'] + 1)
             public_url = _run_cloudflared(params['port'], params['port'] + 1)
             print(f'Starting KoboldAI compatible api at {public_url}/api')
             print(f'Starting KoboldAI compatible api at {public_url}/api')
         except ImportError:
         except ImportError:
@@ -95,5 +96,6 @@ def run_server():
         print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
         print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
     server.serve_forever()
     server.serve_forever()
 
 
+
 def setup():
 def setup():
     Thread(target=run_server, daemon=True).start()
     Thread(target=run_server, daemon=True).start()

+ 6 - 2
extensions/character_bias/script.py

@@ -5,14 +5,16 @@ params = {
     "bias string": " *I am so happy*",
     "bias string": " *I am so happy*",
 }
 }
 
 
+
 def input_modifier(string):
 def input_modifier(string):
     """
     """
     This function is applied to your text inputs before
     This function is applied to your text inputs before
     they are fed into the model.
     they are fed into the model.
-    """ 
+    """
 
 
     return string
     return string
 
 
+
 def output_modifier(string):
 def output_modifier(string):
     """
     """
     This function is applied to the model outputs.
     This function is applied to the model outputs.
@@ -20,6 +22,7 @@ def output_modifier(string):
 
 
     return string
     return string
 
 
+
 def bot_prefix_modifier(string):
 def bot_prefix_modifier(string):
     """
     """
     This function is only applied in chat mode. It modifies
     This function is only applied in chat mode. It modifies
@@ -27,11 +30,12 @@ def bot_prefix_modifier(string):
     behavior.
     behavior.
     """
     """
 
 
-    if params['activate'] == True:
+    if params['activate']:
         return f'{string} {params["bias string"].strip()} '
         return f'{string} {params["bias string"].strip()} '
     else:
     else:
         return string
         return string
 
 
+
 def ui():
 def ui():
     # Gradio elements
     # Gradio elements
     activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
     activate = gr.Checkbox(value=params['activate'], label='Activate character bias')

+ 21 - 13
extensions/elevenlabs_tts/script.py

@@ -20,16 +20,18 @@ user_info = None
 if not shared.args.no_stream:
 if not shared.args.no_stream:
     print("Please add --no-stream. This extension is not meant to be used with streaming.")
     print("Please add --no-stream. This extension is not meant to be used with streaming.")
     raise ValueError
     raise ValueError
-    
+
 # Check if the API is valid and refresh the UI accordingly.
 # Check if the API is valid and refresh the UI accordingly.
+
+
 def check_valid_api():
 def check_valid_api():
-    
+
     global user, user_info, params
     global user, user_info, params
 
 
     user = ElevenLabsUser(params['api_key'])
     user = ElevenLabsUser(params['api_key'])
     user_info = user._get_subscription_data()
     user_info = user._get_subscription_data()
     print('checking api')
     print('checking api')
-    if params['activate'] == False:
+    if not params['activate']:
         return gr.update(value='Disconnected')
         return gr.update(value='Disconnected')
     elif user_info is None:
     elif user_info is None:
         print('Incorrect API Key')
         print('Incorrect API Key')
@@ -37,24 +39,28 @@ def check_valid_api():
     else:
     else:
         print('Got an API Key!')
         print('Got an API Key!')
         return gr.update(value='Connected')
         return gr.update(value='Connected')
-    
+
 # Once the API is verified, get the available voices and update the dropdown list
 # Once the API is verified, get the available voices and update the dropdown list
+
+
 def refresh_voices():
 def refresh_voices():
-    
+
     global user, user_info
     global user, user_info
-    
+
     your_voices = [None]
     your_voices = [None]
     if user_info is not None:
     if user_info is not None:
         for voice in user.get_available_voices():
         for voice in user.get_available_voices():
             your_voices.append(voice.initialName)
             your_voices.append(voice.initialName)
-        return  gr.Dropdown.update(choices=your_voices)
+        return gr.Dropdown.update(choices=your_voices)
     else:
     else:
         return
         return
 
 
+
 def remove_surrounded_chars(string):
 def remove_surrounded_chars(string):
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
-    return re.sub('\*[^\*]*?(\*|$)','',string)
+    return re.sub('\*[^\*]*?(\*|$)', '', string)
+
 
 
 def input_modifier(string):
 def input_modifier(string):
     """
     """
@@ -64,16 +70,17 @@ def input_modifier(string):
 
 
     return string
     return string
 
 
+
 def output_modifier(string):
 def output_modifier(string):
     """
     """
     This function is applied to the model outputs.
     This function is applied to the model outputs.
     """
     """
 
 
     global params, wav_idx, user, user_info
     global params, wav_idx, user, user_info
-    
-    if params['activate'] == False:
+
+    if not params['activate']:
         return string
         return string
-    elif user_info == None:
+    elif user_info is None:
         return string
         return string
 
 
     string = remove_surrounded_chars(string)
     string = remove_surrounded_chars(string)
@@ -84,7 +91,7 @@ def output_modifier(string):
 
 
     if string == '':
     if string == '':
         string = 'empty reply, try regenerating'
         string = 'empty reply, try regenerating'
-        
+
     output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
     output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.wav'.format(wav_idx))
     voice = user.get_voices_by_name(params['selected_voice'])[0]
     voice = user.get_voices_by_name(params['selected_voice'])[0]
     audio_data = voice.generate_audio_bytes(string)
     audio_data = voice.generate_audio_bytes(string)
@@ -94,6 +101,7 @@ def output_modifier(string):
     wav_idx += 1
     wav_idx += 1
     return string
     return string
 
 
+
 def ui():
 def ui():
 
 
     # Gradio elements
     # Gradio elements
@@ -110,4 +118,4 @@ def ui():
     voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
     voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
     api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
     api_key.change(lambda x: params.update({'api_key': x}), api_key, None)
     connect.click(check_valid_api, [], connection_status)
     connect.click(check_valid_api, [], connection_status)
-    connect.click(refresh_voices, [], voice)
+    connect.click(refresh_voices, [], voice)

+ 2 - 2
extensions/gallery/script.py

@@ -85,7 +85,7 @@ def select_character(evt: gr.SelectData):
 def ui():
 def ui():
     with gr.Accordion("Character gallery", open=False):
     with gr.Accordion("Character gallery", open=False):
         update = gr.Button("Refresh")
         update = gr.Button("Refresh")
-        gr.HTML(value="<style>"+generate_css()+"</style>")
+        gr.HTML(value="<style>" + generate_css() + "</style>")
         gallery = gr.Dataset(components=[gr.HTML(visible=False)],
         gallery = gr.Dataset(components=[gr.HTML(visible=False)],
             label="",
             label="",
             samples=generate_html(),
             samples=generate_html(),
@@ -93,4 +93,4 @@ def ui():
             samples_per_page=50
             samples_per_page=50
         )
         )
     update.click(generate_html, [], gallery)
     update.click(generate_html, [], gallery)
-    gallery.select(select_character, None, gradio['character_menu'])
+    gallery.select(select_character, None, gradio['character_menu'])

+ 5 - 1
extensions/google_translate/script.py

@@ -7,14 +7,16 @@ params = {
 
 
 language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
 language_codes = {'Afrikaans': 'af', 'Albanian': 'sq', 'Amharic': 'am', 'Arabic': 'ar', 'Armenian': 'hy', 'Azerbaijani': 'az', 'Basque': 'eu', 'Belarusian': 'be', 'Bengali': 'bn', 'Bosnian': 'bs', 'Bulgarian': 'bg', 'Catalan': 'ca', 'Cebuano': 'ceb', 'Chinese (Simplified)': 'zh-CN', 'Chinese (Traditional)': 'zh-TW', 'Corsican': 'co', 'Croatian': 'hr', 'Czech': 'cs', 'Danish': 'da', 'Dutch': 'nl', 'English': 'en', 'Esperanto': 'eo', 'Estonian': 'et', 'Finnish': 'fi', 'French': 'fr', 'Frisian': 'fy', 'Galician': 'gl', 'Georgian': 'ka', 'German': 'de', 'Greek': 'el', 'Gujarati': 'gu', 'Haitian Creole': 'ht', 'Hausa': 'ha', 'Hawaiian': 'haw', 'Hebrew': 'iw', 'Hindi': 'hi', 'Hmong': 'hmn', 'Hungarian': 'hu', 'Icelandic': 'is', 'Igbo': 'ig', 'Indonesian': 'id', 'Irish': 'ga', 'Italian': 'it', 'Japanese': 'ja', 'Javanese': 'jw', 'Kannada': 'kn', 'Kazakh': 'kk', 'Khmer': 'km', 'Korean': 'ko', 'Kurdish': 'ku', 'Kyrgyz': 'ky', 'Lao': 'lo', 'Latin': 'la', 'Latvian': 'lv', 'Lithuanian': 'lt', 'Luxembourgish': 'lb', 'Macedonian': 'mk', 'Malagasy': 'mg', 'Malay': 'ms', 'Malayalam': 'ml', 'Maltese': 'mt', 'Maori': 'mi', 'Marathi': 'mr', 'Mongolian': 'mn', 'Myanmar (Burmese)': 'my', 'Nepali': 'ne', 'Norwegian': 'no', 'Nyanja (Chichewa)': 'ny', 'Pashto': 'ps', 'Persian': 'fa', 'Polish': 'pl', 'Portuguese (Portugal, Brazil)': 'pt', 'Punjabi': 'pa', 'Romanian': 'ro', 'Russian': 'ru', 'Samoan': 'sm', 'Scots Gaelic': 'gd', 'Serbian': 'sr', 'Sesotho': 'st', 'Shona': 'sn', 'Sindhi': 'sd', 'Sinhala (Sinhalese)': 'si', 'Slovak': 'sk', 'Slovenian': 'sl', 'Somali': 'so', 'Spanish': 'es', 'Sundanese': 'su', 'Swahili': 'sw', 'Swedish': 'sv', 'Tagalog (Filipino)': 'tl', 'Tajik': 'tg', 'Tamil': 'ta', 'Telugu': 'te', 'Thai': 'th', 'Turkish': 'tr', 'Ukrainian': 'uk', 'Urdu': 'ur', 'Uzbek': 'uz', 'Vietnamese': 'vi', 'Welsh': 'cy', 'Xhosa': 'xh', 'Yiddish': 'yi', 'Yoruba': 'yo', 'Zulu': 'zu'}
 
 
+
 def input_modifier(string):
 def input_modifier(string):
     """
     """
     This function is applied to your text inputs before
     This function is applied to your text inputs before
     they are fed into the model.
     they are fed into the model.
-    """ 
+    """
 
 
     return GoogleTranslator(source=params['language string'], target='en').translate(string)
     return GoogleTranslator(source=params['language string'], target='en').translate(string)
 
 
+
 def output_modifier(string):
 def output_modifier(string):
     """
     """
     This function is applied to the model outputs.
     This function is applied to the model outputs.
@@ -22,6 +24,7 @@ def output_modifier(string):
 
 
     return GoogleTranslator(source='en', target=params['language string']).translate(string)
     return GoogleTranslator(source='en', target=params['language string']).translate(string)
 
 
+
 def bot_prefix_modifier(string):
 def bot_prefix_modifier(string):
     """
     """
     This function is only applied in chat mode. It modifies
     This function is only applied in chat mode. It modifies
@@ -31,6 +34,7 @@ def bot_prefix_modifier(string):
 
 
     return string
     return string
 
 
+
 def ui():
 def ui():
     # Finding the language name from the language code to use as the default value
     # Finding the language name from the language code to use as the default value
     language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]
     language_name = list(language_codes.keys())[list(language_codes.values()).index(params['language string'])]

+ 2 - 0
extensions/llama_prompts/script.py

@@ -4,12 +4,14 @@ import pandas as pd
 
 
 df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
 df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv")
 
 
+
 def get_prompt_by_name(name):
 def get_prompt_by_name(name):
     if name == 'None':
     if name == 'None':
         return ''
         return ''
     else:
     else:
         return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
         return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n')
 
 
+
 def ui():
 def ui():
     if not shared.is_chat():
     if not shared.is_chat():
         choices = ['None'] + list(df['Prompt name'])
         choices = ['None'] + list(df['Prompt name'])

+ 24 - 14
extensions/sd_api_pictures/script.py

@@ -12,30 +12,33 @@ from PIL import Image
 
 
 torch._C._jit_set_profiling_mode(False)
 torch._C._jit_set_profiling_mode(False)
 
 
-# parameters which can be customized in settings.json of webui  
+# parameters which can be customized in settings.json of webui
 params = {
 params = {
     'enable_SD_api': False,
     'enable_SD_api': False,
     'address': 'http://127.0.0.1:7860',
     'address': 'http://127.0.0.1:7860',
     'save_img': False,
     'save_img': False,
-    'SD_model': 'NeverEndingDream', # not really used right now
+    'SD_model': 'NeverEndingDream',  # not really used right now
     'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
     'prompt_prefix': '(Masterpiece:1.1), (solo:1.3), detailed, intricate, colorful',
     'negative_prompt': '(worst quality, low quality:1.3)',
     'negative_prompt': '(worst quality, low quality:1.3)',
     'side_length': 512,
     'side_length': 512,
     'restore_faces': False
     'restore_faces': False
 }
 }
 
 
-SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
+SD_models = ['NeverEndingDream']  # TODO: get with http://{address}}/sdapi/v1/sd-models and allow user to select
 
 
-streaming_state = shared.args.no_stream # remember if chat streaming was enabled
-picture_response = False # specifies if the next model response should appear as a picture
+streaming_state = shared.args.no_stream  # remember if chat streaming was enabled
+picture_response = False  # specifies if the next model response should appear as a picture
 pic_id = 0
 pic_id = 0
 
 
+
 def remove_surrounded_chars(string):
 def remove_surrounded_chars(string):
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
     # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
-    return re.sub('\*[^\*]*?(\*|$)','',string)
+    return re.sub('\*[^\*]*?(\*|$)', '', string)
 
 
 # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
 # I don't even need input_hijack for this as visible text will be commited to history as the unmodified string
+
+
 def input_modifier(string):
 def input_modifier(string):
     """
     """
     This function is applied to your text inputs before
     This function is applied to your text inputs before
@@ -51,7 +54,7 @@ def input_modifier(string):
     lowstr = string.lower()
     lowstr = string.lower()
 
 
     # TODO: refactor out to separate handler and also replace detection with a regexp
     # TODO: refactor out to separate handler and also replace detection with a regexp
-    if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums): # trigger the generation if a command signature and a medium signature is found
+    if any(command in lowstr for command in commands) and any(case in lowstr for case in mediums):  # trigger the generation if a command signature and a medium signature is found
         picture_response = True
         picture_response = True
         shared.args.no_stream = True                                                               # Disable streaming cause otherwise the SD-generated picture would return as a dud
         shared.args.no_stream = True                                                               # Disable streaming cause otherwise the SD-generated picture would return as a dud
         shared.processing_message = "*Is sending a picture...*"
         shared.processing_message = "*Is sending a picture...*"
@@ -62,6 +65,8 @@ def input_modifier(string):
     return string
     return string
 
 
 # Get and save the Stable Diffusion-generated picture
 # Get and save the Stable Diffusion-generated picture
+
+
 def get_SD_pictures(description):
 def get_SD_pictures(description):
 
 
     global params, pic_id
     global params, pic_id
@@ -77,13 +82,13 @@ def get_SD_pictures(description):
         "restore_faces": params['restore_faces'],
         "restore_faces": params['restore_faces'],
         "negative_prompt": params['negative_prompt']
         "negative_prompt": params['negative_prompt']
     }
     }
-    
+
     response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
     response = requests.post(url=f'{params["address"]}/sdapi/v1/txt2img', json=payload)
     r = response.json()
     r = response.json()
 
 
     visible_result = ""
     visible_result = ""
     for img_str in r['images']:
     for img_str in r['images']:
-        image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",",1)[0])))
+        image = Image.open(io.BytesIO(base64.b64decode(img_str.split(",", 1)[0])))
         if params['save_img']:
         if params['save_img']:
             output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
             output_file = Path(f'extensions/sd_api_pictures/outputs/{pic_id:06d}.png')
             image.save(output_file.as_posix())
             image.save(output_file.as_posix())
@@ -96,11 +101,13 @@ def get_SD_pictures(description):
         image_bytes = buffered.getvalue()
         image_bytes = buffered.getvalue()
         img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
         img_str = "data:image/jpeg;base64," + base64.b64encode(image_bytes).decode()
         visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
         visible_result = visible_result + f'<img src="{img_str}" alt="{description}">\n'
-    
+
     return visible_result
     return visible_result
 
 
 # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
 # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
 # and replace it with 'text' for the purposes of logging?
 # and replace it with 'text' for the purposes of logging?
+
+
 def output_modifier(string):
 def output_modifier(string):
     """
     """
     This function is applied to the model outputs.
     This function is applied to the model outputs.
@@ -130,6 +137,7 @@ def output_modifier(string):
     shared.args.no_stream = streaming_state
     shared.args.no_stream = streaming_state
     return image + "\n" + text
     return image + "\n" + text
 
 
+
 def bot_prefix_modifier(string):
 def bot_prefix_modifier(string):
     """
     """
     This function is only applied in chat mode. It modifies
     This function is only applied in chat mode. It modifies
@@ -139,10 +147,12 @@ def bot_prefix_modifier(string):
 
 
     return string
     return string
 
 
+
 def force_pic():
 def force_pic():
     global picture_response
     global picture_response
     picture_response = True
     picture_response = True
 
 
+
 def ui():
 def ui():
 
 
     # Gradio elements
     # Gradio elements
@@ -153,7 +163,7 @@ def ui():
                 save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
                 save_img = gr.Checkbox(value=params['save_img'], label='Keep original received images in the outputs subdir')
             with gr.Column():
             with gr.Column():
                 address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
                 address = gr.Textbox(placeholder=params['address'], value=params['address'], label='Stable Diffusion host address')
-        
+
         with gr.Row():
         with gr.Row():
             force_btn = gr.Button("Force the next response to be a picture")
             force_btn = gr.Button("Force the next response to be a picture")
             generate_now_btn = gr.Button("Generate an image response to the input")
             generate_now_btn = gr.Button("Generate an image response to the input")
@@ -162,9 +172,9 @@ def ui():
             prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
             prompt_prefix = gr.Textbox(placeholder=params['prompt_prefix'], value=params['prompt_prefix'], label='Prompt Prefix (best used to describe the look of the character)')
             with gr.Row():
             with gr.Row():
                 negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
                 negative_prompt = gr.Textbox(placeholder=params['negative_prompt'], value=params['negative_prompt'], label='Negative Prompt')
-                dimensions = gr.Slider(256,702,value=params['side_length'],step=64,label='Image dimensions')
+                dimensions = gr.Slider(256, 702, value=params['side_length'], step=64, label='Image dimensions')
                 # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
                 # model = gr.Dropdown(value=SD_models[0], choices=SD_models, label='Model')
-    
+
     # Event functions to update the parameters in the backend
     # Event functions to update the parameters in the backend
     enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
     enable.change(lambda x: params.update({"enable_SD_api": x}), enable, None)
     save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
     save_img.change(lambda x: params.update({"save_img": x}), save_img, None)
@@ -176,4 +186,4 @@ def ui():
 
 
     force_btn.click(force_pic)
     force_btn.click(force_pic)
     generate_now_btn.click(force_pic)
     generate_now_btn.click(force_pic)
-    generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

+ 4 - 1
extensions/send_pictures/script.py

@@ -17,11 +17,13 @@ input_hijack = {
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
 model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
 model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float32).to("cpu")
 
 
+
 def caption_image(raw_image):
 def caption_image(raw_image):
     inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
     inputs = processor(raw_image.convert('RGB'), return_tensors="pt").to("cpu", torch.float32)
     out = model.generate(**inputs, max_new_tokens=100)
     out = model.generate(**inputs, max_new_tokens=100)
     return processor.decode(out[0], skip_special_tokens=True)
     return processor.decode(out[0], skip_special_tokens=True)
 
 
+
 def generate_chat_picture(picture, name1, name2):
 def generate_chat_picture(picture, name1, name2):
     text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
     text = f'*{name1} sends {name2} a picture that contains the following: "{caption_image(picture)}"*'
     # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
     # lower the resolution of sent images for the chat, otherwise the log size gets out of control quickly with all the base64 values in visible history
@@ -32,6 +34,7 @@ def generate_chat_picture(picture, name1, name2):
     visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
     visible_text = f'<img src="data:image/jpeg;base64,{img_str}" alt="{text}">'
     return text, visible_text
     return text, visible_text
 
 
+
 def ui():
 def ui():
     picture_select = gr.Image(label='Send a picture', type='pil')
     picture_select = gr.Image(label='Send a picture', type='pil')
 
 
@@ -42,4 +45,4 @@ def ui():
     picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
     picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
 
 
     # Clear the picture from the upload field
     # Clear the picture from the upload field
-    picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
+    picture_select.upload(lambda: None, [], [picture_select], show_progress=False)

+ 11 - 8
modules/GPTQ_loader.py

@@ -17,9 +17,11 @@ from quant import make_quant
 
 
 
 
 def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
 def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exclude_layers=['lm_head'], kernel_switch_threshold=128):
-    config = AutoConfig.from_pretrained(model)
+
     def noop(*args, **kwargs):
     def noop(*args, **kwargs):
         pass
         pass
+
+    config = AutoConfig.from_pretrained(model)
     torch.nn.init.kaiming_uniform_ = noop
     torch.nn.init.kaiming_uniform_ = noop
     torch.nn.init.uniform_ = noop
     torch.nn.init.uniform_ = noop
     torch.nn.init.normal_ = noop
     torch.nn.init.normal_ = noop
@@ -34,11 +36,11 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
     for name in exclude_layers:
         if name in layers:
         if name in layers:
             del layers[name]
             del layers[name]
-    
+
     gptq_args = inspect.getfullargspec(make_quant).args
     gptq_args = inspect.getfullargspec(make_quant).args
 
 
     make_quant_kwargs = {
     make_quant_kwargs = {
-        'module': model, 
+        'module': model,
         'names': layers,
         'names': layers,
         'bits': wbits,
         'bits': wbits,
     }
     }
@@ -48,7 +50,7 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
         make_quant_kwargs['faster'] = faster_kernel
         make_quant_kwargs['faster'] = faster_kernel
     if 'kernel_switch_threshold' in gptq_args:
     if 'kernel_switch_threshold' in gptq_args:
         make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
         make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
-    
+
     make_quant(**make_quant_kwargs)
     make_quant(**make_quant_kwargs)
 
 
     del layers
     del layers
@@ -56,14 +58,15 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     print('Loading model ...')
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint), strict = False)
+        model.load_state_dict(safe_load(checkpoint), strict=False)
     else:
     else:
-        model.load_state_dict(torch.load(checkpoint), strict = False)
+        model.load_state_dict(torch.load(checkpoint), strict=False)
     model.seqlen = 2048
     model.seqlen = 2048
     print('Done.')
     print('Done.')
 
 
     return model
     return model
 
 
+
 def load_quantized(model_name):
 def load_quantized(model_name):
     if not shared.args.model_type:
     if not shared.args.model_type:
         # Try to determine model type from model name
         # Try to determine model type from model name
@@ -114,7 +117,7 @@ def load_quantized(model_name):
             pt_model = f'{model_name}-{shared.args.wbits}bit'
             pt_model = f'{model_name}-{shared.args.wbits}bit'
 
 
         # Try to find the .safetensors or .pt both in the model dir and in the subfolder
         # Try to find the .safetensors or .pt both in the model dir and in the subfolder
-        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+        for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
             if path.exists():
             if path.exists():
                 print(f"Found {path}")
                 print(f"Found {path}")
                 pt_path = path
                 pt_path = path
@@ -133,7 +136,7 @@ def load_quantized(model_name):
 
 
         # accelerate offload (doesn't work properly)
         # accelerate offload (doesn't work properly)
         if shared.args.gpu_memory:
         if shared.args.gpu_memory:
-            memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+            memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
             max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
             max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
             max_memory = {}
             max_memory = {}
             for i in range(len(memory_map)):
             for i in range(len(memory_map)):

+ 3 - 2
modules/LoRA.py

@@ -13,6 +13,7 @@ def reload_model():
     clear_torch_cache()
     clear_torch_cache()
     shared.model, shared.tokenizer = load_model(shared.model_name)
     shared.model, shared.tokenizer = load_model(shared.model_name)
 
 
+
 def add_lora_to_model(lora_name):
 def add_lora_to_model(lora_name):
 
 
     # If a LoRA had been previously loaded, or if we want
     # If a LoRA had been previously loaded, or if we want
@@ -27,10 +28,10 @@ def add_lora_to_model(lora_name):
         if not shared.args.cpu:
         if not shared.args.cpu:
             params['dtype'] = shared.model.dtype
             params['dtype'] = shared.model.dtype
             if hasattr(shared.model, "hf_device_map"):
             if hasattr(shared.model, "hf_device_map"):
-                params['device_map'] = {"base_model.model."+k: v for k, v in shared.model.hf_device_map.items()}
+                params['device_map'] = {"base_model.model." + k: v for k, v in shared.model.hf_device_map.items()}
             elif shared.args.load_in_8bit:
             elif shared.args.load_in_8bit:
                 params['device_map'] = {'': 0}
                 params['device_map'] = {'': 0}
-            
+
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
         shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
         if not shared.args.load_in_8bit and not shared.args.cpu:
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             shared.model.half()

+ 9 - 8
modules/RWKV.py

@@ -10,7 +10,7 @@ from modules.callbacks import Iteratorize
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 np.set_printoptions(precision=4, suppress=True, linewidth=200)
 
 
 os.environ['RWKV_JIT_ON'] = '1'
 os.environ['RWKV_JIT_ON'] = '1'
-os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0' # use CUDA kernel for seq mode (much faster)
+os.environ["RWKV_CUDA_ON"] = '1' if shared.args.rwkv_cuda_on else '0'  # use CUDA kernel for seq mode (much faster)
 
 
 from rwkv.model import RWKV
 from rwkv.model import RWKV
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
 from rwkv.utils import PIPELINE, PIPELINE_ARGS
@@ -36,13 +36,13 @@ class RWKVModel:
 
 
     def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
     def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=[0], token_stop=[], callback=None):
         args = PIPELINE_ARGS(
         args = PIPELINE_ARGS(
-            temperature = temperature,
-            top_p = top_p,
-            top_k = top_k,
-            alpha_frequency = alpha_frequency, # Frequency Penalty (as in GPT-3)
-            alpha_presence = alpha_presence, # Presence Penalty (as in GPT-3)
-            token_ban = token_ban, # ban the generation of some tokens
-            token_stop = token_stop
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            alpha_frequency=alpha_frequency,  # Frequency Penalty (as in GPT-3)
+            alpha_presence=alpha_presence,  # Presence Penalty (as in GPT-3)
+            token_ban=token_ban,  # ban the generation of some tokens
+            token_stop=token_stop
         )
         )
 
 
         return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
         return self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
@@ -54,6 +54,7 @@ class RWKVModel:
                 reply += token
                 reply += token
                 yield reply
                 yield reply
 
 
+
 class RWKVTokenizer:
 class RWKVTokenizer:
     def __init__(self):
     def __init__(self):
         pass
         pass

+ 1 - 0
modules/api.py

@@ -28,6 +28,7 @@ def generate_reply_wrapper(string):
     for i in generate_reply(params[0], generate_params):
     for i in generate_reply(params[0], generate_params):
         yield i
         yield i
 
 
+
 def create_apis():
 def create_apis():
     t1 = gr.Textbox(visible=False)
     t1 = gr.Textbox(visible=False)
     t2 = gr.Textbox(visible=False)
     t2 = gr.Textbox(visible=False)

+ 6 - 3
modules/callbacks.py

@@ -30,6 +30,7 @@ class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
                         return True
                         return True
         return False
         return False
 
 
+
 class Stream(transformers.StoppingCriteria):
 class Stream(transformers.StoppingCriteria):
     def __init__(self, callback_func=None):
     def __init__(self, callback_func=None):
         self.callback_func = callback_func
         self.callback_func = callback_func
@@ -39,6 +40,7 @@ class Stream(transformers.StoppingCriteria):
             self.callback_func(input_ids[0])
             self.callback_func(input_ids[0])
         return False
         return False
 
 
+
 class Iteratorize:
 class Iteratorize:
 
 
     """
     """
@@ -47,8 +49,8 @@ class Iteratorize:
     """
     """
 
 
     def __init__(self, func, kwargs={}, callback=None):
     def __init__(self, func, kwargs={}, callback=None):
-        self.mfunc=func
-        self.c_callback=callback
+        self.mfunc = func
+        self.c_callback = callback
         self.q = Queue()
         self.q = Queue()
         self.sentinel = object()
         self.sentinel = object()
         self.kwargs = kwargs
         self.kwargs = kwargs
@@ -80,7 +82,7 @@ class Iteratorize:
         return self
         return self
 
 
     def __next__(self):
     def __next__(self):
-        obj = self.q.get(True,None)
+        obj = self.q.get(True, None)
         if obj is self.sentinel:
         if obj is self.sentinel:
             raise StopIteration
             raise StopIteration
         else:
         else:
@@ -96,6 +98,7 @@ class Iteratorize:
         self.stop_now = True
         self.stop_now = True
         clear_torch_cache()
         clear_torch_cache()
 
 
+
 def clear_torch_cache():
 def clear_torch_cache():
     gc.collect()
     gc.collect()
     if not shared.args.cpu:
     if not shared.args.cpu:

+ 39 - 18
modules/chat.py

@@ -23,12 +23,11 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
-
     rows = [f"{context.strip()}\n"]
     rows = [f"{context.strip()}\n"]
 
 
     # Finding the maximum prompt size
     # Finding the maximum prompt size
     if shared.soft_prompt:
     if shared.soft_prompt:
-       chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
+        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
     max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
     max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
 
 
     if is_instruct:
     if is_instruct:
@@ -38,7 +37,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
         prefix1 = f"{name1}: "
         prefix1 = f"{name1}: "
         prefix2 = f"{name2}: "
         prefix2 = f"{name2}: "
 
 
-    i = len(shared.history['internal'])-1
+    i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
         rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         string = shared.history['internal'][i][0]
         string = shared.history['internal'][i][0]
@@ -68,6 +67,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     else:
     else:
         return prompt
         return prompt
 
 
+
 def extract_message_from_reply(reply, name1, name2, stop_at_newline):
 def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     next_character_found = False
     next_character_found = False
 
 
@@ -87,7 +87,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
         # is completed, trim it
         # is completed, trim it
         if not next_character_found:
         if not next_character_found:
             for string in [f"\n{name1}:", f"\n{name2}:"]:
             for string in [f"\n{name1}:", f"\n{name2}:"]:
-                for j in range(len(string)-1, 0, -1):
+                for j in range(len(string) - 1, 0, -1):
                     if reply[-j:] == string[:j]:
                     if reply[-j:] == string[:j]:
                         reply = reply[:-j]
                         reply = reply[:-j]
                         break
                         break
@@ -98,12 +98,13 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     reply = fix_newlines(reply)
     reply = fix_newlines(reply)
     return reply, next_character_found
     return reply, next_character_found
 
 
+
 def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
 def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
     if mode == 'instruct':
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
-        
+
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     name1_original = name1
     name1_original = name1
     if 'pygmalion' in shared.model_name.lower():
     if 'pygmalion' in shared.model_name.lower():
@@ -113,7 +114,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
     visible_text = None
     visible_text = None
     custom_generate_chat_prompt = None
     custom_generate_chat_prompt = None
     for extension, _ in extensions_module.iterator():
     for extension, _ in extensions_module.iterator():
-        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
             extension.input_hijack['state'] = False
             extension.input_hijack['state'] = False
             text, visible_text = extension.input_hijack['value']
             text, visible_text = extension.input_hijack['value']
         if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
         if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
@@ -131,7 +132,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
 
     # Yield *Is typing...*
     # Yield *Is typing...*
     if not regenerate:
     if not regenerate:
-        yield shared.history['visible']+[[visible_text, shared.processing_message]]
+        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
 
     # Generate
     # Generate
     cumulative_reply = ''
     cumulative_reply = ''
@@ -167,12 +168,13 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
 
     yield shared.history['visible']
     yield shared.history['visible']
 
 
+
 def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
 def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     if mode == 'instruct':
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
     else:
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
         stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
-        
+
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
         name1 = "You"
@@ -197,10 +199,12 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o
 
 
     yield reply
     yield reply
 
 
+
 def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
 def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
         yield chat_html_wrapper(history, name1, name2, mode)
         yield chat_html_wrapper(history, name1, name2, mode)
 
 
+
 def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
 def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
         yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
         yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -208,11 +212,12 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
         last_visible = shared.history['visible'].pop()
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
         # Yield '*Is typing...*'
-        yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
+        yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], name1, name2, mode)
         for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
         for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
             shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
             shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+
 def remove_last_message(name1, name2, mode):
 def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
         last = shared.history['visible'].pop()
@@ -222,12 +227,14 @@ def remove_last_message(name1, name2, mode):
 
 
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
 
 
+
 def send_last_reply_to_input():
 def send_last_reply_to_input():
     if len(shared.history['internal']) > 0:
     if len(shared.history['internal']) > 0:
         return shared.history['internal'][-1][1]
         return shared.history['internal'][-1][1]
     else:
     else:
         return ''
         return ''
 
 
+
 def replace_last_reply(text, name1, name2, mode):
 def replace_last_reply(text, name1, name2, mode):
     if len(shared.history['visible']) > 0:
     if len(shared.history['visible']) > 0:
         shared.history['visible'][-1][1] = text
         shared.history['visible'][-1][1] = text
@@ -235,9 +242,11 @@ def replace_last_reply(text, name1, name2, mode):
 
 
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+
 def clear_html():
 def clear_html():
     return chat_html_wrapper([], "", "")
     return chat_html_wrapper([], "", "")
 
 
+
 def clear_chat_log(name1, name2, greeting, mode):
 def clear_chat_log(name1, name2, greeting, mode):
     shared.history['visible'] = []
     shared.history['visible'] = []
     shared.history['internal'] = []
     shared.history['internal'] = []
@@ -248,9 +257,11 @@ def clear_chat_log(name1, name2, greeting, mode):
 
 
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+
 def redraw_html(name1, name2, mode):
 def redraw_html(name1, name2, mode):
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+
 def tokenize_dialogue(dialogue, name1, name2, mode):
 def tokenize_dialogue(dialogue, name1, name2, mode):
     history = []
     history = []
 
 
@@ -263,8 +274,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
         return history
         return history
 
 
     messages = []
     messages = []
-    for i in range(len(idx)-1):
-        messages.append(dialogue[idx[i]:idx[i+1]].strip())
+    for i in range(len(idx) - 1):
+        messages.append(dialogue[idx[i]:idx[i + 1]].strip())
     messages.append(dialogue[idx[-1]:].strip())
     messages.append(dialogue[idx[-1]:].strip())
 
 
     entry = ['', '']
     entry = ['', '']
@@ -282,12 +293,13 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
         for column in row:
         for column in row:
             print("\n")
             print("\n")
             for line in column.strip().split('\n'):
             for line in column.strip().split('\n'):
-                print("|  "+line+"\n")
+                print("|  " + line + "\n")
             print("|\n")
             print("|\n")
         print("------------------------------")
         print("------------------------------")
 
 
     return history
     return history
 
 
+
 def save_history(timestamp=True):
 def save_history(timestamp=True):
     if timestamp:
     if timestamp:
         fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
         fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
@@ -299,6 +311,7 @@ def save_history(timestamp=True):
         f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
         f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
     return Path(f'logs/{fname}')
     return Path(f'logs/{fname}')
 
 
+
 def load_history(file, name1, name2):
 def load_history(file, name1, name2):
     file = file.decode('utf-8')
     file = file.decode('utf-8')
     try:
     try:
@@ -313,20 +326,22 @@ def load_history(file, name1, name2):
         elif 'chat' in j:
         elif 'chat' in j:
             shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
             shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
             if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
             if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
-                shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
+                shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(1, len(shared.history['internal']) - 1, 2)]
                 shared.history['visible'] = copy.deepcopy(shared.history['internal'])
                 shared.history['visible'] = copy.deepcopy(shared.history['internal'])
                 shared.history['visible'][0][0] = ''
                 shared.history['visible'][0][0] = ''
             else:
             else:
-                shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
+                shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i + 1]] for i in range(0, len(shared.history['internal']) - 1, 2)]
                 shared.history['visible'] = copy.deepcopy(shared.history['internal'])
                 shared.history['visible'] = copy.deepcopy(shared.history['internal'])
     except:
     except:
         shared.history['internal'] = tokenize_dialogue(file, name1, name2)
         shared.history['internal'] = tokenize_dialogue(file, name1, name2)
         shared.history['visible'] = copy.deepcopy(shared.history['internal'])
         shared.history['visible'] = copy.deepcopy(shared.history['internal'])
 
 
+
 def replace_character_names(text, name1, name2):
 def replace_character_names(text, name1, name2):
     text = text.replace('{{user}}', name1).replace('{{char}}', name2)
     text = text.replace('{{user}}', name1).replace('{{char}}', name2)
     return text.replace('<USER>', name1).replace('<BOT>', name2)
     return text.replace('<USER>', name1).replace('<BOT>', name2)
 
 
+
 def build_pygmalion_style_context(data):
 def build_pygmalion_style_context(data):
     context = ""
     context = ""
     if 'char_persona' in data and data['char_persona'] != '':
     if 'char_persona' in data and data['char_persona'] != '':
@@ -336,6 +351,7 @@ def build_pygmalion_style_context(data):
     context = f"{context.strip()}\n<START>\n"
     context = f"{context.strip()}\n<START>\n"
     return context
     return context
 
 
+
 def generate_pfp_cache(character):
 def generate_pfp_cache(character):
     cache_folder = Path("cache")
     cache_folder = Path("cache")
     if not cache_folder.exists():
     if not cache_folder.exists():
@@ -348,6 +364,7 @@ def generate_pfp_cache(character):
             return img
             return img
     return None
     return None
 
 
+
 def load_character(character, name1, name2, mode):
 def load_character(character, name1, name2, mode):
     shared.character = character
     shared.character = character
     shared.history['internal'] = []
     shared.history['internal'] = []
@@ -387,13 +404,13 @@ def load_character(character, name1, name2, mode):
         if 'example_dialogue' in data:
         if 'example_dialogue' in data:
             context += f"{data['example_dialogue'].strip()}\n"
             context += f"{data['example_dialogue'].strip()}\n"
         if greeting_field in data:
         if greeting_field in data:
-            greeting = data[greeting_field]  
+            greeting = data[greeting_field]
         if 'end_of_turn' in data:
         if 'end_of_turn' in data:
-            end_of_turn = data['end_of_turn']  
+            end_of_turn = data['end_of_turn']
     else:
     else:
         context = shared.settings['context']
         context = shared.settings['context']
         name2 = shared.settings['name2']
         name2 = shared.settings['name2']
-        greeting = shared.settings['greeting'] 
+        greeting = shared.settings['greeting']
         end_of_turn = shared.settings['end_of_turn']
         end_of_turn = shared.settings['end_of_turn']
 
 
     if Path(f'logs/{shared.character}_persistent.json').exists():
     if Path(f'logs/{shared.character}_persistent.json').exists():
@@ -404,9 +421,11 @@ def load_character(character, name1, name2, mode):
 
 
     return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
     return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
 
 
+
 def load_default_history(name1, name2):
 def load_default_history(name1, name2):
     load_character("None", name1, name2, "chat")
     load_character("None", name1, name2, "chat")
 
 
+
 def upload_character(json_file, img, tavern=False):
 def upload_character(json_file, img, tavern=False):
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
     data = json.loads(json_file)
     data = json.loads(json_file)
@@ -425,6 +444,7 @@ def upload_character(json_file, img, tavern=False):
     print(f'New character saved to "characters/{outfile_name}.json".')
     print(f'New character saved to "characters/{outfile_name}.json".')
     return outfile_name
     return outfile_name
 
 
+
 def upload_tavern_character(img, name1, name2):
 def upload_tavern_character(img, name1, name2):
     _img = Image.open(io.BytesIO(img))
     _img = Image.open(io.BytesIO(img))
     _img.getexif()
     _img.getexif()
@@ -433,12 +453,13 @@ def upload_tavern_character(img, name1, name2):
     _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
     _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
     return upload_character(json.dumps(_json), img, tavern=True)
     return upload_character(json.dumps(_json), img, tavern=True)
 
 
+
 def upload_your_profile_picture(img, name1, name2, mode):
 def upload_your_profile_picture(img, name1, name2, mode):
     cache_folder = Path("cache")
     cache_folder = Path("cache")
     if not cache_folder.exists():
     if not cache_folder.exists():
         cache_folder.mkdir()
         cache_folder.mkdir()
 
 
-    if img == None:
+    if img is None:
         if Path("cache/pfp_me.png").exists():
         if Path("cache/pfp_me.png").exists():
             Path("cache/pfp_me.png").unlink()
             Path("cache/pfp_me.png").unlink()
     else:
     else:

+ 7 - 1
modules/extensions.py

@@ -9,6 +9,7 @@ state = {}
 available_extensions = []
 available_extensions = []
 setup_called = set()
 setup_called = set()
 
 
+
 def load_extensions():
 def load_extensions():
     global state
     global state
     for i, name in enumerate(shared.args.extensions):
     for i, name in enumerate(shared.args.extensions):
@@ -23,12 +24,16 @@ def load_extensions():
                 traceback.print_exc()
                 traceback.print_exc()
 
 
 # This iterator returns the extensions in the order specified in the command-line
 # This iterator returns the extensions in the order specified in the command-line
+
+
 def iterator():
 def iterator():
-    for name in sorted(state, key=lambda x : state[x][1]):
+    for name in sorted(state, key=lambda x: state[x][1]):
         if state[name][0] == True:
         if state[name][0] == True:
             yield eval(f"extensions.{name}.script"), name
             yield eval(f"extensions.{name}.script"), name
 
 
 # Extension functions that map string -> string
 # Extension functions that map string -> string
+
+
 def apply_extensions(text, typ):
 def apply_extensions(text, typ):
     for extension, _ in iterator():
     for extension, _ in iterator():
         if typ == "input" and hasattr(extension, "input_modifier"):
         if typ == "input" and hasattr(extension, "input_modifier"):
@@ -39,6 +44,7 @@ def apply_extensions(text, typ):
             text = extension.bot_prefix_modifier(text)
             text = extension.bot_prefix_modifier(text)
     return text
     return text
 
 
+
 def create_extensions_block():
 def create_extensions_block():
     global setup_called
     global setup_called
 
 

+ 19 - 7
modules/html_generator.py

@@ -24,6 +24,7 @@ with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as
 with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
 with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
     instruct_css = f.read()
     instruct_css = f.read()
 
 
+
 def fix_newlines(string):
 def fix_newlines(string):
     string = string.replace('\n', '\n\n')
     string = string.replace('\n', '\n\n')
     string = re.sub(r"\n{3,}", "\n\n", string)
     string = re.sub(r"\n{3,}", "\n\n", string)
@@ -31,6 +32,8 @@ def fix_newlines(string):
     return string
     return string
 
 
 # This could probably be generalized and improved
 # This could probably be generalized and improved
+
+
 def convert_to_markdown(string):
 def convert_to_markdown(string):
     string = string.replace('\\begin{code}', '```')
     string = string.replace('\\begin{code}', '```')
     string = string.replace('\\end{code}', '```')
     string = string.replace('\\end{code}', '```')
@@ -38,13 +41,15 @@ def convert_to_markdown(string):
     string = string.replace('\\end{blockquote}', '')
     string = string.replace('\\end{blockquote}', '')
     string = re.sub(r"(.)```", r"\1\n```", string)
     string = re.sub(r"(.)```", r"\1\n```", string)
     string = fix_newlines(string)
     string = fix_newlines(string)
-    return markdown.markdown(string, extensions=['fenced_code']) 
+    return markdown.markdown(string, extensions=['fenced_code'])
+
 
 
 def generate_basic_html(string):
 def generate_basic_html(string):
     string = convert_to_markdown(string)
     string = convert_to_markdown(string)
     string = f'<style>{readable_css}</style><div class="container">{string}</div>'
     string = f'<style>{readable_css}</style><div class="container">{string}</div>'
     return string
     return string
 
 
+
 def process_post(post, c):
 def process_post(post, c):
     t = post.split('\n')
     t = post.split('\n')
     number = t[0].split(' ')[1]
     number = t[0].split(' ')[1]
@@ -59,6 +64,7 @@ def process_post(post, c):
     src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
     src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
     return src
     return src
 
 
+
 def generate_4chan_html(f):
 def generate_4chan_html(f):
     posts = []
     posts = []
     post = ''
     post = ''
@@ -84,7 +90,7 @@ def generate_4chan_html(f):
             posts[i] = f'<div class="op">{posts[i]}</div>\n'
             posts[i] = f'<div class="op">{posts[i]}</div>\n'
         else:
         else:
             posts[i] = f'<div class="reply">{posts[i]}</div>\n'
             posts[i] = f'<div class="reply">{posts[i]}</div>\n'
-    
+
     output = ''
     output = ''
     output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
     output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
     for post in posts:
     for post in posts:
@@ -98,13 +104,15 @@ def generate_4chan_html(f):
 
 
     return output
     return output
 
 
+
 def make_thumbnail(image):
 def make_thumbnail(image):
-    image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
+    image = image.resize((350, round(image.size[1] / image.size[0] * 350)), Image.Resampling.LANCZOS)
     if image.size[1] > 470:
     if image.size[1] > 470:
         image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
         image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
 
 
     return image
     return image
 
 
+
 def get_image_cache(path):
 def get_image_cache(path):
     cache_folder = Path("cache")
     cache_folder = Path("cache")
     if not cache_folder.exists():
     if not cache_folder.exists():
@@ -119,9 +127,10 @@ def get_image_cache(path):
 
 
     return image_cache[path][1]
     return image_cache[path][1]
 
 
+
 def generate_instruct_html(history):
 def generate_instruct_html(history):
     output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
     output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
-    for i,_row in enumerate(history[::-1]):
+    for i, _row in enumerate(history[::-1]):
         row = [convert_to_markdown(entry) for entry in _row]
         row = [convert_to_markdown(entry) for entry in _row]
 
 
         output += f"""
         output += f"""
@@ -134,7 +143,7 @@ def generate_instruct_html(history):
               </div>
               </div>
             """
             """
 
 
-        if len(row[0]) == 0: # don't display empty user messages
+        if len(row[0]) == 0:  # don't display empty user messages
             continue
             continue
 
 
         output += f"""
         output += f"""
@@ -151,6 +160,7 @@ def generate_instruct_html(history):
 
 
     return output
     return output
 
 
+
 def generate_cai_chat_html(history, name1, name2, reset_cache=False):
 def generate_cai_chat_html(history, name1, name2, reset_cache=False):
     output = f'<style>{cai_css}</style><div class="chat" id="chat">'
     output = f'<style>{cai_css}</style><div class="chat" id="chat">'
 
 
@@ -159,7 +169,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
     img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
     img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
     img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
     img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
 
 
-    for i,_row in enumerate(history[::-1]):
+    for i, _row in enumerate(history[::-1]):
         row = [convert_to_markdown(entry) for entry in _row]
         row = [convert_to_markdown(entry) for entry in _row]
 
 
         output += f"""
         output += f"""
@@ -178,7 +188,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
               </div>
               </div>
             """
             """
 
 
-        if len(row[0]) == 0: # don't display empty user messages
+        if len(row[0]) == 0:  # don't display empty user messages
             continue
             continue
 
 
         output += f"""
         output += f"""
@@ -200,9 +210,11 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False):
     output += "</div>"
     output += "</div>"
     return output
     return output
 
 
+
 def generate_chat_html(history, name1, name2):
 def generate_chat_html(history, name1, name2):
     return generate_cai_chat_html(history, name1, name2)
     return generate_cai_chat_html(history, name1, name2)
 
 
+
 def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
 def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
     if mode == "cai-chat":
     if mode == "cai-chat":
         return generate_cai_chat_html(history, name1, name2, reset_cache)
         return generate_cai_chat_html(history, name1, name2, reset_cache)

+ 2 - 2
modules/llamacpp_model.py

@@ -50,9 +50,9 @@ class LlamaCppModel:
         params.top_k = top_k
         params.top_k = top_k
         params.temp = temperature
         params.temp = temperature
         params.repeat_penalty = repetition_penalty
         params.repeat_penalty = repetition_penalty
-        #params.repeat_last_n = repeat_last_n
+        # params.repeat_last_n = repeat_last_n
 
 
-        #self.model.params = params
+        # self.model.params = params
         self.model.add_bos()
         self.model.add_bos()
         self.model.update_input(context)
         self.model.update_input(context)
 
 

+ 2 - 4
modules/llamacpp_model_alternative.py

@@ -1,13 +1,11 @@
 '''
 '''
-Based on 
+Based on
 https://github.com/abetlen/llama-cpp-python
 https://github.com/abetlen/llama-cpp-python
 
 
 Documentation:
 Documentation:
 https://abetlen.github.io/llama-cpp-python/
 https://abetlen.github.io/llama-cpp-python/
 '''
 '''
 
 
-import multiprocessing
-
 from llama_cpp import Llama
 from llama_cpp import Llama
 
 
 from modules import shared
 from modules import shared
@@ -31,7 +29,7 @@ class LlamaCppModel:
         self.model = Llama(**params)
         self.model = Llama(**params)
 
 
         # This is ugly, but the model and the tokenizer are the same object in this library.
         # This is ugly, but the model and the tokenizer are the same object in this library.
-        return result, result 
+        return result, result
 
 
     def encode(self, string):
     def encode(self, string):
         if type(string) is str:
         if type(string) is str:

+ 11 - 10
modules/models.py

@@ -34,7 +34,7 @@ if shared.args.deepspeed:
     torch.cuda.set_device(local_rank)
     torch.cuda.set_device(local_rank)
     deepspeed.init_distributed()
     deepspeed.init_distributed()
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
     ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
-    dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
+    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
 
 
 
 
 def load_model(model_name):
 def load_model(model_name):
@@ -83,7 +83,7 @@ def load_model(model_name):
     elif shared.args.deepspeed:
     elif shared.args.deepspeed:
         model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
-        model.module.eval() # Inference
+        model.module.eval()  # Inference
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
 
 
     # RMKV model (not on HuggingFace)
     # RMKV model (not on HuggingFace)
@@ -132,7 +132,7 @@ def load_model(model_name):
                 params["torch_dtype"] = torch.float16
                 params["torch_dtype"] = torch.float16
 
 
             if shared.args.gpu_memory:
             if shared.args.gpu_memory:
-                memory_map = list(map(lambda x : x.strip(), shared.args.gpu_memory))
+                memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
                 max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
                 max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
                 max_memory = {}
                 max_memory = {}
                 for i in range(len(memory_map)):
                 for i in range(len(memory_map)):
@@ -140,13 +140,13 @@ def load_model(model_name):
                 max_memory['cpu'] = max_cpu_memory
                 max_memory['cpu'] = max_cpu_memory
                 params['max_memory'] = max_memory
                 params['max_memory'] = max_memory
             elif shared.args.auto_devices:
             elif shared.args.auto_devices:
-                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
-                suggestion = round((total_mem-1000) / 1000) * 1000
+                total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
+                suggestion = round((total_mem - 1000) / 1000) * 1000
                 if total_mem - suggestion < 800:
                 if total_mem - suggestion < 800:
                     suggestion -= 1000
                     suggestion -= 1000
-                suggestion = int(round(suggestion/1000))
+                suggestion = int(round(suggestion / 1000))
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
                 print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
-                
+
                 max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
                 max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
                 params['max_memory'] = max_memory
                 params['max_memory'] = max_memory
 
 
@@ -161,10 +161,10 @@ def load_model(model_name):
                 model = AutoModelForCausalLM.from_config(config)
                 model = AutoModelForCausalLM.from_config(config)
             model.tie_weights()
             model.tie_weights()
             params['device_map'] = infer_auto_device_map(
             params['device_map'] = infer_auto_device_map(
-                model, 
-                dtype=torch.int8, 
+                model,
+                dtype=torch.int8,
                 max_memory=params['max_memory'],
                 max_memory=params['max_memory'],
-                no_split_module_classes = model._no_split_modules
+                no_split_module_classes=model._no_split_modules
             )
             )
 
 
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
@@ -181,6 +181,7 @@ def load_model(model_name):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
     return model, tokenizer
 
 
+
 def load_soft_prompt(name):
 def load_soft_prompt(name):
     if name == 'None':
     if name == 'None':
         shared.soft_prompt = False
         shared.soft_prompt = False

+ 4 - 1
modules/shared.py

@@ -61,6 +61,7 @@ settings = {
     }
     }
 }
 }
 
 
+
 def str2bool(v):
 def str2bool(v):
     if isinstance(v, bool):
     if isinstance(v, bool):
         return v
         return v
@@ -71,7 +72,8 @@ def str2bool(v):
     else:
     else:
         raise argparse.ArgumentTypeError('Boolean value expected.')
         raise argparse.ArgumentTypeError('Boolean value expected.')
 
 
-parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
+
+parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
 
 
 # Basic settings
 # Basic settings
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
@@ -145,5 +147,6 @@ if args.cai_chat:
     print("Warning: --cai-chat is deprecated. Use --chat instead.")
     print("Warning: --cai-chat is deprecated. Use --chat instead.")
     args.chat = True
     args.chat = True
 
 
+
 def is_chat():
 def is_chat():
     return args.chat
     return args.chat

+ 19 - 7
modules/text_generation.py

@@ -16,11 +16,12 @@ from modules.models import local_rank
 
 
 
 
 def get_max_prompt_length(tokens):
 def get_max_prompt_length(tokens):
-    max_length = 2048-tokens
+    max_length = 2048 - tokens
     if shared.soft_prompt:
     if shared.soft_prompt:
         max_length -= shared.soft_prompt_tensor.shape[1]
         max_length -= shared.soft_prompt_tensor.shape[1]
     return max_length
     return max_length
 
 
+
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
     if any((shared.is_RWKV, shared.is_llamacpp)):
     if any((shared.is_RWKV, shared.is_llamacpp)):
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = shared.tokenizer.encode(str(prompt))
@@ -30,7 +31,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
 
 
         if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
         if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
-            input_ids = input_ids[:,1:]
+            input_ids = input_ids[:, 1:]
 
 
         if shared.args.cpu:
         if shared.args.cpu:
             return input_ids
             return input_ids
@@ -44,6 +45,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
         else:
         else:
             return input_ids.cuda()
             return input_ids.cuda()
 
 
+
 def decode(output_ids):
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
     # Open Assistant relies on special tokens like <|endoftext|>
     if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
     if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
@@ -53,14 +55,17 @@ def decode(output_ids):
         reply = reply.replace(r'<|endoftext|>', '')
         reply = reply.replace(r'<|endoftext|>', '')
         return reply
         return reply
 
 
+
 def generate_softprompt_input_tensors(input_ids):
 def generate_softprompt_input_tensors(input_ids):
     inputs_embeds = shared.model.transformer.wte(input_ids)
     inputs_embeds = shared.model.transformer.wte(input_ids)
     inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
     inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
     filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
     filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
-    #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
+    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
     return inputs_embeds, filler_input_ids
     return inputs_embeds, filler_input_ids
 
 
 # Removes empty replies from gpt4chan outputs
 # Removes empty replies from gpt4chan outputs
+
+
 def fix_gpt4chan(s):
 def fix_gpt4chan(s):
     for i in range(10):
     for i in range(10):
         s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
         s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
@@ -69,6 +74,8 @@ def fix_gpt4chan(s):
     return s
     return s
 
 
 # Fix the LaTeX equations in galactica
 # Fix the LaTeX equations in galactica
+
+
 def fix_galactica(s):
 def fix_galactica(s):
     s = s.replace(r'\[', r'$')
     s = s.replace(r'\[', r'$')
     s = s.replace(r'\]', r'$')
     s = s.replace(r'\]', r'$')
@@ -79,6 +86,7 @@ def fix_galactica(s):
     s = re.sub(r"\n{3,}", "\n\n", s)
     s = re.sub(r"\n{3,}", "\n\n", s)
     return s
     return s
 
 
+
 def formatted_outputs(reply, model_name):
 def formatted_outputs(reply, model_name):
     if not shared.is_chat():
     if not shared.is_chat():
         if 'galactica' in model_name.lower():
         if 'galactica' in model_name.lower():
@@ -92,20 +100,24 @@ def formatted_outputs(reply, model_name):
     else:
     else:
         return reply
         return reply
 
 
+
 def clear_torch_cache():
 def clear_torch_cache():
     gc.collect()
     gc.collect()
     if not shared.args.cpu:
     if not shared.args.cpu:
         torch.cuda.empty_cache()
         torch.cuda.empty_cache()
 
 
+
 def set_manual_seed(seed):
 def set_manual_seed(seed):
     if seed != -1:
     if seed != -1:
         torch.manual_seed(seed)
         torch.manual_seed(seed)
         if torch.cuda.is_available():
         if torch.cuda.is_available():
             torch.cuda.manual_seed_all(seed)
             torch.cuda.manual_seed_all(seed)
 
 
+
 def stop_everything_event():
 def stop_everything_event():
     shared.stop_everything = True
     shared.stop_everything = True
 
 
+
 def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
 def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     clear_torch_cache()
     set_manual_seed(generate_state['seed'])
     set_manual_seed(generate_state['seed'])
@@ -128,7 +140,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
         try:
         try:
             if shared.args.no_stream:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, **generate_params)
                 reply = shared.model.generate(context=question, **generate_params)
-                output = original_question+reply
+                output = original_question + reply
                 if not shared.is_chat():
                 if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
                     reply = original_question + apply_extensions(reply, "output")
                 yield formatted_outputs(reply, shared.model_name)
                 yield formatted_outputs(reply, shared.model_name)
@@ -139,7 +151,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
                 # RWKV has proper streaming, which is very nice.
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
                 # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                 for reply in shared.model.generate_with_streaming(context=question, **generate_params):
-                    output = original_question+reply
+                    output = original_question + reply
                     if not shared.is_chat():
                     if not shared.is_chat():
                         reply = original_question + apply_extensions(reply, "output")
                         reply = original_question + apply_extensions(reply, "output")
                     yield formatted_outputs(reply, shared.model_name)
                     yield formatted_outputs(reply, shared.model_name)
@@ -240,7 +252,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
 
 
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
         else:
-            for i in range(generate_state['max_new_tokens']//8+1):
+            for i in range(generate_state['max_new_tokens'] // 8 + 1):
                 clear_torch_cache()
                 clear_torch_cache()
                 with torch.no_grad():
                 with torch.no_grad():
                     output = shared.model.generate(**generate_params)[0]
                     output = shared.model.generate(**generate_params)[0]
@@ -271,6 +283,6 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
     finally:
     finally:
         t1 = time.time()
         t1 = time.time()
         original_tokens = len(original_input_ids[0])
         original_tokens = len(original_input_ids[0])
-        new_tokens = len(output)-original_tokens
+        new_tokens = len(output) - original_tokens
         print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
         print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
         return
         return

+ 20 - 10
modules/training.py

@@ -19,9 +19,11 @@ CURRENT_STEPS = 0
 MAX_STEPS = 0
 MAX_STEPS = 0
 CURRENT_GRADIENT_ACCUM = 1
 CURRENT_GRADIENT_ACCUM = 1
 
 
+
 def get_dataset(path: str, ext: str):
 def get_dataset(path: str, ext: str):
     return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
     return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
 
 
+
 def create_train_interface():
 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
         lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
         lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
@@ -44,16 +46,16 @@ def create_train_interface():
         with gr.Tab(label="Formatted Dataset"):
         with gr.Tab(label="Formatted Dataset"):
             with gr.Row():
             with gr.Row():
                 dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
                 dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
-                ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
+                ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
                 eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
                 eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
-                ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
+                ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
                 format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
                 format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
-                ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
+                ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
 
 
         with gr.Tab(label="Raw Text File"):
         with gr.Tab(label="Raw Text File"):
             with gr.Row():
             with gr.Row():
                 raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
                 raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
-                ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
+                ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
             with gr.Row():
             with gr.Row():
                 overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
                 overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
                 newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
                 newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
@@ -67,10 +69,12 @@ def create_train_interface():
                                       cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
                                       cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
         stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
         stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
 
 
+
 def do_interrupt():
 def do_interrupt():
     global WANT_INTERRUPT
     global WANT_INTERRUPT
     WANT_INTERRUPT = True
     WANT_INTERRUPT = True
 
 
+
 class Callbacks(transformers.TrainerCallback):
 class Callbacks(transformers.TrainerCallback):
     def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
     def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
         global CURRENT_STEPS, MAX_STEPS
         global CURRENT_STEPS, MAX_STEPS
@@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
         if WANT_INTERRUPT:
         if WANT_INTERRUPT:
             control.should_epoch_stop = True
             control.should_epoch_stop = True
             control.should_training_stop = True
             control.should_training_stop = True
+
     def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
     def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
         global CURRENT_STEPS
         global CURRENT_STEPS
         CURRENT_STEPS += 1
         CURRENT_STEPS += 1
@@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
             control.should_epoch_stop = True
             control.should_epoch_stop = True
             control.should_training_stop = True
             control.should_training_stop = True
 
 
+
 def clean_path(base_path: str, path: str):
 def clean_path(base_path: str, path: str):
     """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
     """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
     # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
     # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
         return path
         return path
     return f'{Path(base_path).absolute()}/{path}'
     return f'{Path(base_path).absolute()}/{path}'
 
 
+
 def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
 def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
              cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
              cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
     global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
     global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
     elif not shared.args.load_in_8bit:
     elif not shared.args.load_in_8bit:
         yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
         yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
         print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
         print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
-        time.sleep(2) # Give it a moment for the message to show in UI before continuing
+        time.sleep(2)  # Give it a moment for the message to show in UI before continuing
 
 
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
         yield "Cannot input zeroes."
         yield "Cannot input zeroes."
@@ -148,7 +155,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
             raw_text = file.read()
             raw_text = file.read()
         tokens = shared.tokenizer.encode(raw_text)
         tokens = shared.tokenizer.encode(raw_text)
-        del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+        del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
 
 
         tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
         tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
         for i in range(1, len(tokens)):
         for i in range(1, len(tokens)):
@@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         else:
         else:
             eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
             eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
             eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
             eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
-    
+
     # == Start prepping the model itself ==
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         print("Getting model ready...")
         print("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
         prepare_model_for_int8_training(shared.model)
-    
+
     print("Prepping for training...")
     print("Prepping for training...")
     config = LoraConfig(
     config = LoraConfig(
         r=lora_rank,
         r=lora_rank,
         lora_alpha=lora_alpha,
         lora_alpha=lora_alpha,
         # TODO: Should target_modules be configurable?
         # TODO: Should target_modules be configurable?
-        target_modules=[ "q_proj", "v_proj" ],
+        target_modules=["q_proj", "v_proj"],
         lora_dropout=lora_dropout,
         lora_dropout=lora_dropout,
         bias="none",
         bias="none",
         task_type="CAUSAL_LM"
         task_type="CAUSAL_LM"
@@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
                     timer_info = f"`{its:.2f}` it/s"
                     timer_info = f"`{its:.2f}` it/s"
                 else:
                 else:
                     timer_info = f"`{1.0/its:.2f}` s/it"
                     timer_info = f"`{1.0/its:.2f}` s/it"
-                total_time_estimate = (1.0/its) * (MAX_STEPS)
+                total_time_estimate = (1.0 / its) * (MAX_STEPS)
             yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
             yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
 
 
     print("Training complete, saving...")
     print("Training complete, saving...")
@@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
         print("Training complete!")
         print("Training complete!")
         yield f"Done! LoRA saved to `{lora_name}`"
         yield f"Done! LoRA saved to `{lora_name}`"
 
 
+
 def split_chunks(arr, step):
 def split_chunks(arr, step):
     for i in range(0, len(arr), step):
     for i in range(0, len(arr), step):
         yield arr[i:i + step]
         yield arr[i:i + step]
 
 
+
 def cut_chunk_for_newline(chunk: str, max_length: int):
 def cut_chunk_for_newline(chunk: str, max_length: int):
     if '\n' not in chunk:
     if '\n' not in chunk:
         return chunk
         return chunk
@@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
         chunk = chunk[:last_newline]
         chunk = chunk[:last_newline]
     return chunk
     return chunk
 
 
+
 def format_time(seconds: float):
 def format_time(seconds: float):
     if seconds < 120:
     if seconds < 120:
         return f"`{seconds:.0f}` seconds"
         return f"`{seconds:.0f}` seconds"

+ 2 - 0
modules/ui.py

@@ -13,6 +13,7 @@ with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
 with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
 with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
     chat_js = f.read()
     chat_js = f.read()
 
 
+
 class ToolButton(gr.Button, gr.components.FormComponent):
 class ToolButton(gr.Button, gr.components.FormComponent):
     """Small button with single emoji as text, fits inside gradio forms"""
     """Small button with single emoji as text, fits inside gradio forms"""
 
 
@@ -22,6 +23,7 @@ class ToolButton(gr.Button, gr.components.FormComponent):
     def get_block_name(self):
     def get_block_name(self):
         return "button"
         return "button"
 
 
+
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
     def refresh():
     def refresh():
         refresh_method()
         refresh_method()

+ 46 - 27
server.py

@@ -34,15 +34,18 @@ if settings_file is not None:
     for item in new_settings:
     for item in new_settings:
         shared.settings[item] = new_settings[item]
         shared.settings[item] = new_settings[item]
 
 
+
 def get_available_models():
 def get_available_models():
     if shared.args.flexgen:
     if shared.args.flexgen:
         return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
         return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
     else:
         return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
         return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
 
+
 def get_available_presets():
 def get_available_presets():
     return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
     return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
 
 
+
 def get_available_prompts():
 def get_available_prompts():
     prompts = []
     prompts = []
     prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
     prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
@@ -50,10 +53,12 @@ def get_available_prompts():
     prompts += ['None']
     prompts += ['None']
     return prompts
     return prompts
 
 
+
 def get_available_characters():
 def get_available_characters():
     paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
     paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
     return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
     return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
 
 
+
 def get_available_instruction_templates():
 def get_available_instruction_templates():
     path = "characters/instruction-following"
     path = "characters/instruction-following"
     paths = []
     paths = []
@@ -61,19 +66,24 @@ def get_available_instruction_templates():
         paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
         paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
     return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
     return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
 
 
+
 def get_available_extensions():
 def get_available_extensions():
-    return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+    return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
+
 
 
 def get_available_softprompts():
 def get_available_softprompts():
     return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
     return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
 
 
+
 def get_available_loras():
 def get_available_loras():
     return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
     return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
 
 
+
 def unload_model():
 def unload_model():
     shared.model = shared.tokenizer = None
     shared.model = shared.tokenizer = None
     clear_torch_cache()
     clear_torch_cache()
 
 
+
 def load_model_wrapper(selected_model):
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
     if selected_model != shared.model_name:
         shared.model_name = selected_model
         shared.model_name = selected_model
@@ -84,10 +94,12 @@ def load_model_wrapper(selected_model):
 
 
     return selected_model
     return selected_model
 
 
+
 def load_lora_wrapper(selected_lora):
 def load_lora_wrapper(selected_lora):
     add_lora_to_model(selected_lora)
     add_lora_to_model(selected_lora)
     return selected_lora
     return selected_lora
 
 
+
 def load_preset_values(preset_menu, state, return_dict=False):
 def load_preset_values(preset_menu, state, return_dict=False):
     generate_params = {
     generate_params = {
         'do_sample': True,
         'do_sample': True,
@@ -118,6 +130,7 @@ def load_preset_values(preset_menu, state, return_dict=False):
         state.update(generate_params)
         state.update(generate_params)
         return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
         return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
 
 
+
 def upload_soft_prompt(file):
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
         zf.extract('meta.json')
         zf.extract('meta.json')
@@ -130,12 +143,14 @@ def upload_soft_prompt(file):
 
 
     return name
     return name
 
 
+
 def save_prompt(text):
 def save_prompt(text):
     fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
     fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
     with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
     with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
         f.write(text)
         f.write(text)
     return f"Saved to prompts/{fname}"
     return f"Saved to prompts/{fname}"
 
 
+
 def load_prompt(fname):
 def load_prompt(fname):
     if fname in ['None', '']:
     if fname in ['None', '']:
         return ''
         return ''
@@ -146,12 +161,13 @@ def load_prompt(fname):
                 text = text[:-1]
                 text = text[:-1]
             return text
             return text
 
 
+
 def create_prompt_menus():
 def create_prompt_menus():
     with gr.Row():
     with gr.Row():
         with gr.Column():
         with gr.Column():
             with gr.Row():
             with gr.Row():
                 shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
                 shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
-                ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
+                ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
 
 
         with gr.Column():
         with gr.Column():
             with gr.Column():
             with gr.Column():
@@ -161,20 +177,22 @@ def create_prompt_menus():
     shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
     shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
 
+
 def create_model_menus():
 def create_model_menus():
     with gr.Row():
     with gr.Row():
         with gr.Column():
         with gr.Column():
             with gr.Row():
             with gr.Row():
                 shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
                 shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
-                ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
+                ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
         with gr.Column():
         with gr.Column():
             with gr.Row():
             with gr.Row():
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
                 shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
-                ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+                ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
 
 
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
     shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
 
 
+
 def create_settings_menus(default_preset):
 def create_settings_menus(default_preset):
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
     for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
     for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
@@ -185,7 +203,7 @@ def create_settings_menus(default_preset):
         with gr.Column():
         with gr.Column():
             with gr.Row():
             with gr.Row():
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
-                ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
+                ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
         with gr.Column():
         with gr.Column():
             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
 
 
@@ -196,12 +214,12 @@ def create_settings_menus(default_preset):
                 with gr.Row():
                 with gr.Row():
                     with gr.Column():
                     with gr.Column():
                         shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
                         shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
-                        shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
-                        shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
-                        shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
+                        shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
+                        shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
+                        shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
                     with gr.Column():
                     with gr.Column():
-                        shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
-                        shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
+                        shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
+                        shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
                         shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
                         shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
                         shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
                         shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
                 shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
                 shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
@@ -209,7 +227,6 @@ def create_settings_menus(default_preset):
             with gr.Box():
             with gr.Box():
                 gr.Markdown('Contrastive search')
                 gr.Markdown('Contrastive search')
                 shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
                 shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
-
             with gr.Box():
             with gr.Box():
                 gr.Markdown('Beam search (uses a lot of VRAM)')
                 gr.Markdown('Beam search (uses a lot of VRAM)')
                 with gr.Row():
                 with gr.Row():
@@ -219,11 +236,10 @@ def create_settings_menus(default_preset):
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
 
-
     with gr.Accordion('Soft prompt', open=False):
     with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
         with gr.Row():
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
-            ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
+            ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
 
 
         gr.Markdown('Upload a soft prompt (.zip format):')
         gr.Markdown('Upload a soft prompt (.zip format):')
         with gr.Row():
         with gr.Row():
@@ -233,6 +249,7 @@ def create_settings_menus(default_preset):
     shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
 
+
 def set_interface_arguments(interface_mode, extensions, bool_active):
 def set_interface_arguments(interface_mode, extensions, bool_active):
     modes = ["default", "notebook", "chat", "cai_chat"]
     modes = ["default", "notebook", "chat", "cai_chat"]
     cmd_list = vars(shared.args)
     cmd_list = vars(shared.args)
@@ -251,6 +268,7 @@ def set_interface_arguments(interface_mode, extensions, bool_active):
 
 
     shared.need_restart = True
     shared.need_restart = True
 
 
+
 available_models = get_available_models()
 available_models = get_available_models()
 available_presets = get_available_presets()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_characters = get_available_characters()
@@ -284,7 +302,7 @@ else:
         for i, model in enumerate(available_models):
         for i, model in enumerate(available_models):
             print(f'{i+1}. {model}')
             print(f'{i+1}. {model}')
         print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
         print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
-        i = int(input())-1
+        i = int(input()) - 1
         print()
         print()
     shared.model_name = available_models[i]
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
 shared.model, shared.tokenizer = load_model(shared.model_name)
@@ -297,15 +315,15 @@ if shared.lora_name != "None":
     default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
     default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
 else:
 else:
     default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
     default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
-title ='Text generation web UI'
+title = 'Text generation web UI'
 
 
-def create_interface():
 
 
+def create_interface():
     gen_events = []
     gen_events = []
     if shared.args.extensions is not None and len(shared.args.extensions) > 0:
     if shared.args.extensions is not None and len(shared.args.extensions) > 0:
         extensions_module.load_extensions()
         extensions_module.load_extensions()
 
 
-    with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
+    with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
         if shared.is_chat():
         if shared.is_chat():
             shared.gradio['Chat input'] = gr.State()
             shared.gradio['Chat input'] = gr.State()
             with gr.Tab("Text generation", elem_id="main"):
             with gr.Tab("Text generation", elem_id="main"):
@@ -342,7 +360,7 @@ def create_interface():
                         shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
                         shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
                 with gr.Row():
                 with gr.Row():
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
-                    ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
+                    ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
 
 
                 with gr.Row():
                 with gr.Row():
                     with gr.Tab('Chat history'):
                     with gr.Tab('Chat history'):
@@ -399,11 +417,11 @@ def create_interface():
 
 
             # Clear history with confirmation
             # Clear history with confirmation
             clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
             clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
-            shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
-            shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+            shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
+            shared.gradio['Clear history-confirm'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
             shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
             shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
-            shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
-            shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
+            shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+            shared.gradio['Chat mode'].change(lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
 
 
             shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
             shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
             shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
             shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
@@ -412,10 +430,10 @@ def create_interface():
             # Clearing stuff and saving the history
             # Clearing stuff and saving the history
             for i in ['Generate', 'Regenerate', 'Replace last reply']:
             for i in ['Generate', 'Regenerate', 'Replace last reply']:
                 shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
                 shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
-                shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
-            shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+                shared.gradio[i].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+            shared.gradio['Clear history-confirm'].click(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
             shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
             shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
-            shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
+            shared.gradio['textbox'].submit(lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
 
 
             shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
@@ -430,7 +448,7 @@ def create_interface():
             shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
             shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
 
 
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
-            shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
+            shared.gradio['interface'].load(lambda: chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
             shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
             shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
 
 
         elif shared.args.notebook:
         elif shared.args.notebook:
@@ -526,7 +544,7 @@ def create_interface():
             shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
             shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
 
 
             shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
             shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None)
-            shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
+            shared.gradio['reset_interface'].click(lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
 
 
         if shared.args.extensions is not None:
         if shared.args.extensions is not None:
             extensions_module.create_extensions_block()
             extensions_module.create_extensions_block()
@@ -562,6 +580,7 @@ def create_interface():
     else:
     else:
         shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
         shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
 
 
+
 create_interface()
 create_interface()
 
 
 while True:
 while True: