
Merge branch 'main' into UsamaKenway-main

oobabooga 3 years ago
Parent
Commit
c6e9ba20a4
11 changed files with 434 additions and 147 deletions
  1. README.md (+2, -0)
  2. api-example.py (+3, -3)
  3. download-model.py (+115, -101)
  4. extensions/character_bias/script.py (+42, -6)
  5. modules/GPTQ_loader.py (+6, -5)
  6. modules/chat.py (+49, -13)
  7. modules/llama_attn_hijack.py (+176, -0)
  8. modules/models.py (+13, -0)
  9. modules/shared.py (+2, -0)
  10. requirements.txt (+3, -1)
  11. server.py (+23, -18)

+ 2 - 0
README.md

@@ -215,6 +215,8 @@ Optionally, you can use the following command-line flags:
 | `--load-in-8bit`                            | Load the model with 8-bit precision.|
 | `--bf16`                                    | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--no-cache`                                | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
+| `--xformers`                                | Use xformer's memory efficient attention. This should increase your tokens/s. |
+| `--sdp-attention`                           | Use torch 2.0's sdp attention. |
 
 #### llama.cpp
 
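
The two new flags only help when their backends are actually present. A minimal sketch (not part of the commit) for checking that, assuming only that torch is installed; xformers is the optional extra:

```python
# Sketch only: verify the optional backends behind --xformers and --sdp-attention.
import torch

try:
    import xformers.ops  # noqa: F401  (only needed for --xformers)
    has_xformers = True
except ImportError:
    has_xformers = False

# torch.nn.functional.scaled_dot_product_attention exists only in torch >= 2.0
has_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention")

print(f"xformers: {has_xformers}, sdp attention: {has_sdp}")
```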

+ 3 - 3
api-example.py

@@ -22,10 +22,10 @@ server = "127.0.0.1"
 params = {
     'max_new_tokens': 200,
     'do_sample': True,
-    'temperature': 0.5,
-    'top_p': 0.9,
+    'temperature': 0.72,
+    'top_p': 0.73,
     'typical_p': 1,
-    'repetition_penalty': 1.05,
+    'repetition_penalty': 1.1,
     'encoder_repetition_penalty': 1.0,
     'top_k': 0,
     'min_length': 0,
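
The retuned defaults are ordinary Hugging Face sampling settings, so the same behaviour can be reproduced directly with transformers. A hedged sketch (not from the commit), with "gpt2" standing in for whatever model the server actually loads:

```python
# Sketch only: the api-example.py defaults expressed as transformers generate() kwargs.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Common sense questions and answers", return_tensors="pt")
output = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=True,
    temperature=0.72,        # was 0.5
    top_p=0.73,              # was 0.9
    typical_p=1.0,
    repetition_penalty=1.1,  # was 1.05
    top_k=0,                 # 0 disables top-k filtering
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```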

+ 115 - 101
download-model.py

@@ -19,6 +19,7 @@ import requests
 import tqdm
 from tqdm.contrib.concurrent import thread_map
 
+
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
 parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
@@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum
 args = parser.parse_args()
 
 
-def get_file(url, output_folder):
-    filename = Path(url.rsplit('/', 1)[1])
-    output_path = output_folder / filename
-    if output_path.exists() and not args.clean:
-        # Check if the file has already been downloaded completely
-        r = requests.get(url, stream=True)
-        total_size = int(r.headers.get('content-length', 0))
-        if output_path.stat().st_size >= total_size:
-            return
-        # Otherwise, resume the download from where it left off
-        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-        mode = 'ab'
-    else:
-        headers = {}
-        mode = 'wb'
-
-    r = requests.get(url, stream=True, headers=headers)
-    with open(output_path, mode) as f:
-        total_size = int(r.headers.get('content-length', 0))
-        block_size = 1024
-        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
-            for data in r.iter_content(block_size):
-                t.update(len(data))
-                f.write(data)
-
-
-def sanitize_branch_name(branch_name):
-    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
-    if pattern.match(branch_name):
-        return branch_name
-    else:
-        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped
     return model, branch
 
 
-def get_download_links_from_huggingface(model, branch):
+def sanitize_model_and_branch_names(model, branch):
+    if model[-1] == '/':
+        model = model[:-1]
+    if branch is None:
+        branch = "main"
+    else:
+        pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
+        if not pattern.match(branch):
+            raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
+
+    return model, branch
+
+
+def get_download_links_from_huggingface(model, branch, text_only=False):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
     cursor = b""
@@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch):
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
                     continue
-                if not args.text_only:
+                if not text_only:
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     if is_safetensors:
                         has_safetensors = True
@@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch):
     return links, sha256, is_lora
 
 
-def download_files(file_list, output_folder, num_threads=8):
-    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
-
-
-if __name__ == '__main__':
-    model = args.MODEL
-    branch = args.branch
-    if model is None:
-        model, branch = select_model_from_default_options()
-    else:
-        if model[-1] == '/':
-            model = model[:-1]
-            branch = args.branch
-        if branch is None:
-            branch = "main"
-        else:
-            try:
-                branch = sanitize_branch_name(branch)
-            except ValueError as err_branch:
-                print(f"Error: {err_branch}")
-                sys.exit()
-
-    links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
-
-    if args.output is not None:
-        base_folder = args.output
-    else:
+def get_output_folder(model, branch, is_lora, base_folder=None):
+    if base_folder is None:
         base_folder = 'models' if not is_lora else 'loras'
 
     output_folder = f"{'_'.join(model.split('/')[-2:])}"
     if branch != 'main':
         output_folder += f'_{branch}'
     output_folder = Path(base_folder) / output_folder
+    return output_folder
+
+
+def get_single_file(url, output_folder, start_from_scratch=False):
+    filename = Path(url.rsplit('/', 1)[1])
+    output_path = output_folder / filename
+    if output_path.exists() and not start_from_scratch:
+        # Check if the file has already been downloaded completely
+        r = requests.get(url, stream=True)
+        total_size = int(r.headers.get('content-length', 0))
+        if output_path.stat().st_size >= total_size:
+            return
+        # Otherwise, resume the download from where it left off
+        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+        mode = 'ab'
+    else:
+        headers = {}
+        mode = 'wb'
+
+    r = requests.get(url, stream=True, headers=headers)
+    with open(output_path, mode) as f:
+        total_size = int(r.headers.get('content-length', 0))
+        block_size = 1024
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
-    if args.check:
-        # Validate the checksums
-        validated = True
-        for i in range(len(sha256)):
-            fpath = (output_folder / sha256[i][0])
 
-            if not fpath.exists():
-                print(f"The following file is missing: {fpath}")
+def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
+    thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
+
+
+def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
+    # Creating the folder and writing the metadata
+    if not output_folder.exists():
+        output_folder.mkdir()
+    with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+        f.write(f'url: https://huggingface.co/{model}\n')
+        f.write(f'branch: {branch}\n')
+        f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+        sha256_str = ''
+        for i in range(len(sha256)):
+            sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
+        if sha256_str != '':
+            f.write(f'sha256sum:\n{sha256_str}')
+
+    # Downloading the files
+    print(f"Downloading the model to {output_folder}")
+    start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
+
+
+def check_model_files(model, branch, links, sha256, output_folder):
+    # Validate the checksums
+    validated = True
+    for i in range(len(sha256)):
+        fpath = (output_folder / sha256[i][0])
+
+        if not fpath.exists():
+            print(f"The following file is missing: {fpath}")
+            validated = False
+            continue
+
+        with open(output_folder / sha256[i][0], "rb") as f:
+            bytes = f.read()
+            file_hash = hashlib.sha256(bytes).hexdigest()
+            if file_hash != sha256[i][1]:
+                print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
                 validated = False
-                continue
-
-            with open(output_folder / sha256[i][0], "rb") as f:
-                bytes = f.read()
-                file_hash = hashlib.sha256(bytes).hexdigest()
-                if file_hash != sha256[i][1]:
-                    print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
-                    validated = False
-                else:
-                    print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
-
-        if validated:
-            print('[+] Validated checksums of all model files!')
-        else:
-            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+            else:
+                print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
 
+    if validated:
+        print('[+] Validated checksums of all model files!')
     else:
+        print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
 
-        # Creating the folder and writing the metadata
-        if not output_folder.exists():
-            output_folder.mkdir()
-        with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
-            f.write(f'url: https://huggingface.co/{model}\n')
-            f.write(f'branch: {branch}\n')
-            f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
-            sha256_str = ''
-            for i in range(len(sha256)):
-                sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
-            if sha256_str != '':
-                f.write(f'sha256sum:\n{sha256_str}')
-
-        # Downloading the files
-        print(f"Downloading the model to {output_folder}")
-        download_files(links, output_folder, args.threads)
+
+if __name__ == '__main__':
+    branch = args.branch
+    model = args.MODEL
+    if model is None:
+        model, branch = select_model_from_default_options()
+
+    # Cleaning up the model/branch names
+    try:
+        model, branch = sanitize_model_and_branch_names(model, branch)
+    except ValueError as err_branch:
+        print(f"Error: {err_branch}")
+        sys.exit()
+
+    # Getting the download links from Hugging Face
+    links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
+
+    # Getting the output folder
+    output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
+
+    if args.check:
+        # Check previously downloaded files
+        check_model_files(model, branch, links, sha256, output_folder)
+    else:
+        # Download files
+        download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)
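
With the script split into functions, the same flow can be driven from other code. A sketch under the assumption that download-model.py sits in the working directory; since the filename contains a dash it is loaded with importlib, and note that its module-level argparse still runs on load. "facebook/opt-1.3b" is only an example repository:

```python
# Sketch only: reuse the refactored helpers from another script.
import importlib.util

spec = importlib.util.spec_from_file_location("downloader", "download-model.py")
downloader = importlib.util.module_from_spec(spec)
spec.loader.exec_module(downloader)  # the script's argparse runs here

model, branch = downloader.sanitize_model_and_branch_names("facebook/opt-1.3b", "main")
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
output_folder = downloader.get_output_folder(model, branch, is_lora)
downloader.download_model_files(model, branch, links, sha256, output_folder, threads=4)
```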

+ 42 - 6
extensions/character_bias/script.py

@@ -1,8 +1,23 @@
 import gradio as gr
+import os
+
+# get the current directory of the script
+current_dir = os.path.dirname(os.path.abspath(__file__))
+
+# check if the bias_options.txt file exists, if not, create it
+bias_file = os.path.join(current_dir, "bias_options.txt")
+if not os.path.isfile(bias_file):
+    with open(bias_file, "w") as f:
+        f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
+
+# read bias options from the text file
+with open(bias_file, "r") as f:
+    bias_options = [line.strip() for line in f.readlines()]
 
 params = {
     "activate": True,
     "bias string": " *I am so happy*",
+    "use custom string": False,
 }
 
 
@@ -11,7 +26,6 @@ def input_modifier(string):
     This function is applied to your text inputs before
     they are fed into the model.
     """
-
     return string
 
 
@@ -19,7 +33,6 @@ def output_modifier(string):
     """
     This function is applied to the model outputs.
     """
-
     return string
 
 
@@ -29,9 +42,11 @@ def bot_prefix_modifier(string):
     the prefix text for the Bot and can be used to bias its
     behavior.
     """
-
     if params['activate']:
-        return f'{string} {params["bias string"].strip()} '
+        if params['use custom string']:
+            return f'{string} {params["custom string"].strip()} '
+        else:
+            return f'{string} {params["bias string"].strip()} '
     else:
         return string
 
@@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
 def ui():
     # Gradio elements
     activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
-    string = gr.Textbox(value=params["bias string"], label='Character bias')
+    dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
+    use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
+    custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
 
     # Event functions to update the parameters in the backend
-    string.change(lambda x: params.update({"bias string": x}), string, None)
+    def update_bias_string(x):
+        if x:
+            params.update({"bias string": x})
+        else:
+            params.update({"bias string": dropdown_string.get()})
+        return x
+
+    def update_custom_string(x):
+        params.update({"custom string": x})
+
+    dropdown_string.change(update_bias_string, dropdown_string, None)
+    custom_string.change(update_custom_string, custom_string, None)
     activate.change(lambda x: params.update({"activate": x}), activate, None)
+    use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
+
+    # Group elements together depending on the selected option
+    def bias_string_group():
+        if use_custom_string.value:
+            return gr.Group([use_custom_string, custom_string])
+        else:
+            return dropdown_string
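
The hook itself is plain Python, so the new dropdown/custom-string switch can be exercised without the UI. A sketch assuming the extension file is importable as `script` (for example when run from inside extensions/character_bias/):

```python
# Sketch only: calling the modified hook directly.
import script  # extensions/character_bias/script.py

script.params["activate"] = True
script.params["use custom string"] = False
print(script.bot_prefix_modifier("Bot:"))   # "Bot: *I am so happy* "

script.params["use custom string"] = True
script.params["custom string"] = "*I speak like a pirate*"
print(script.bot_prefix_modifier("Bot:"))   # "Bot: *I speak like a pirate* "
```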

+ 6 - 5
modules/GPTQ_loader.py

@@ -100,10 +100,10 @@ def load_quantized(model_name):
     found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
 
-    if len(found_pts) == 1:
-        pt_path = found_pts[0]
-    elif len(found_safetensors) == 1:
-        pt_path = found_safetensors[0]
+    if len(found_pts) > 0:
+        pt_path = found_pts[-1]
+    elif len(found_safetensors) > 0:
+        pt_path = found_safetensors[-1]
     else:
         if path_to_model.name.lower().startswith('llama-7b'):
             pt_model = f'llama-7b-{shared.args.wbits}bit'
@@ -119,13 +119,14 @@ def load_quantized(model_name):
         # Try to find the .safetensors or .pt both in the model dir and in the subfolder
         for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
             if path.exists():
-                print(f"Found {path}")
                 pt_path = path
                 break
 
     if not pt_path:
         print("Could not find the quantized model in .pt or .safetensors format, exiting...")
         exit()
+    else:
+        print(f"Found the following quantized model: {pt_path}")
 
     # qwopqwop200's offload
     if model_type == 'llama' and shared.args.pre_layer:

+ 49 - 13
modules/chat.py

@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
     rows = [f"{context.strip()}\n"]
 
@@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
 
     i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
-        rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+        if _continue and i == len(shared.history['internal']) - 1:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
+        else:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
             rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     if impersonate:
         rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
         limit = 2
+    elif _continue:
+        limit = 3
     else:
         # Adding the user message
         user_input = fix_newlines(user_input)
@@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     return reply, next_character_found
 
 
-def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
@@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     # Defining some variables
     cumulative_reply = ''
+    last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
     just_started = True
     name1_original = name1
     visible_text = custom_generate_chat_prompt = None
@@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     if visible_text is None:
         visible_text = text
-    text = apply_extensions(text, "input")
+    if not _continue:
+        text = apply_extensions(text, "input")
 
     # Generating the prompt
-    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
+    kwargs = {
+        'end_of_turn': end_of_turn,
+        'is_instruct': mode == 'instruct',
+        '_continue': _continue
+    }
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
         prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
 
     # Yield *Is typing...*
-    if not regenerate:
+    if not any((regenerate, _continue)):
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
     # Generate
@@ -154,11 +166,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
                 return shared.history['visible']
             if just_started:
                 just_started = False
-                shared.history['internal'].append(['', ''])
-                shared.history['visible'].append(['', ''])
-
-            shared.history['internal'][-1] = [text, reply]
-            shared.history['visible'][-1] = [visible_text, visible_reply]
+                if not _continue:
+                    shared.history['internal'].append(['', ''])
+                    shared.history['visible'].append(['', ''])
+
+            if _continue:
+                sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
+                shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
+                shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
+            else:
+                shared.history['internal'][-1] = [text, reply]
+                shared.history['visible'][-1] = [visible_text, visible_reply]
             if not shared.args.no_stream:
                 yield shared.history['visible']
             if next_character_found:
@@ -220,6 +238,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+    else:
+        # Yield ' ...'
+        yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
+        for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+
+
 def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
@@ -256,6 +284,9 @@ def clear_chat_log(name1, name2, greeting, mode):
     if greeting != '':
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+
+    # Save cleared logs
+    save_history(timestamp=False)
 
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
@@ -406,9 +437,14 @@ def load_character(character, name1, name2, mode):
 
     if Path(f'logs/{shared.character}_persistent.json').exists():
         load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
-    elif greeting != "":
-        shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
-        shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+    else:
+        # Insert greeting if it exists
+        if greeting != "":
+            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+            shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+
+        # Create .json log files since they don't already exist
+        save_history(timestamp=False)
 
     return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
 
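
The Continue path splices the newly generated tokens onto the previous reply, inserting a space only when the old text does not already end with one. The separator rule in isolation, as a small sketch (not part of the commit):

```python
# Sketch only: the separator rule used when continuing the last reply.
last_reply = ["The quick brown fox", "The quick brown fox"]   # [internal, visible]
reply = "jumps over the lazy dog."

sep = list(map(lambda x: ' ' if x[-1] != ' ' else '', last_reply))
print(f"{last_reply[0]}{sep[0]}{reply}")  # "The quick brown fox jumps over the lazy dog."
```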

+ 176 - 0
modules/llama_attn_hijack.py

@@ -0,0 +1,176 @@
+import math
+import sys
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+from typing import Optional
+from typing import Tuple
+
+import modules.shared as shared
+
+
+if shared.args.xformers:
+    try:
+        import xformers.ops
+    except Exception:
+        print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
+
+
+def hijack_llama_attention():
+    if shared.args.xformers:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+        print("Replaced attention with xformers_attention")
+    elif shared.args.sdp_attention:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+        print("Replaced attention with sdp_attention")
+
+
+def xformers_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    #We only apply xformers optimizations if we don't need to output the whole attention matrix
+    if not output_attentions:
+        dtype = query_states.dtype
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        
+        #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+        else:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    #We only apply sdp attention if we don't need to output the whole attention matrix
+    if not output_attentions:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
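
Both hijacked forwards fall back to the explicit softmax formulation when output_attentions is requested, so the fused and manual paths should agree numerically. A small sanity-check sketch (not part of the commit), requiring torch >= 2.0:

```python
# Sketch only: the fused torch 2.0 kernel vs. the manual softmax attention above.
import math
import torch

bsz, num_heads, q_len, head_dim = 1, 4, 8, 16
q = torch.randn(bsz, num_heads, q_len, head_dim)
k = torch.randn(bsz, num_heads, q_len, head_dim)
v = torch.randn(bsz, num_heads, q_len, head_dim)

# Manual attention, as in the output_attentions branch
weights = torch.softmax(q @ k.transpose(2, 3) / math.sqrt(head_dim), dim=-1)
manual = weights @ v

# Fused kernel used by --sdp-attention (no mask, non-causal for simplicity)
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)

print(torch.allclose(manual, fused, atol=1e-5))  # True
```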

+ 13 - 0
modules/models.py

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
+from modules import llama_attn_hijack
 
 transformers.logging.set_verbosity_error()
 
@@ -169,11 +170,23 @@ def load_model(model_name):
 
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
+    # Hijack attention with xformers
+    if any((shared.args.xformers, shared.args.sdp_attention)):
+        llama_attn_hijack.hijack_llama_attention()
+
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM:
         tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
+        # Leaving this here until the LLaMA tokenizer gets figured out.
+        # For some people this fixes things, for others it causes an error.
+        try:
+            tokenizer.eos_token_id = 2
+            tokenizer.bos_token_id = 1
+            tokenizer.pad_token_id = 0
+        except:
+            pass
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'
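
hijack_llama_attention() relies on nothing more than assigning a module-level function to the class attribute LlamaAttention.forward. The same monkey-patching pattern on a toy class, as a sketch (not part of the commit):

```python
# Sketch only: monkey-patching a method the way hijack_llama_attention() does.
class Greeter:
    def greet(self):
        return "hello"

def patched_greet(self):
    return "hello (patched)"

g = Greeter()
print(g.greet())          # "hello"
Greeter.greet = patched_greet
print(g.greet())          # "hello (patched)" - existing instances pick up the new method
```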

+ 2 - 0
modules/shared.py

@@ -98,6 +98,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
+parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
 
 # llama.cpp
 parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')

+ 3 - 1
requirements.txt

@@ -1,5 +1,4 @@
 accelerate==0.18.0
-bitsandbytes==0.37.2
 datasets
 flexgen==0.1.7
 gradio==3.24.1
@@ -14,3 +13,6 @@ sentencepiece
 pyyaml
 tqdm
 git+https://github.com/huggingface/transformers
+bitsandbytes==0.37.2; platform_system != "Windows"
+llama-cpp-python==0.1.30; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"

+ 23 - 18
server.py

@@ -394,8 +394,9 @@ def create_interface():
                     shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
                     shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
                 with gr.Row():
-                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                     shared.gradio['Regenerate'] = gr.Button('Regenerate')
+                    shared.gradio['Continue'] = gr.Button('Continue')
+                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                 with gr.Row():
                     shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
                     shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@@ -467,53 +468,57 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['textbox'].submit(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['Regenerate'].click(
                 chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
+            )
+
+            gen_events.append(shared.gradio['Continue'].click(
+                chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             shared.gradio['Replace last reply'].click(
                 chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
 
             shared.gradio['Clear history-confirm'].click(
                 lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
                 chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
 
             shared.gradio['Stop'].click(
-                stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['Chat mode'].change(
                 lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['Instruction templates'].change(
                 lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['upload_chat_history'].upload(
-                chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
+            shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
             shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
             shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
-            shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
+            shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']])
             shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
             shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
@@ -521,7 +526,7 @@ def create_interface():
 
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
             shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
-            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
+            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
 
         elif shared.args.notebook:
             with gr.Tab("Text generation", elem_id="main"):
@@ -555,7 +560,7 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
+            shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
@@ -589,7 +594,7 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
+            shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         with gr.Tab("Model", elem_id="model-tab"):
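
Throughout server.py the commit also switches unused Gradio inputs/outputs from `[]` to `None` and keeps chaining side effects with `.then()`. A minimal sketch of that pattern (not from the commit), written against gradio 3.24.x as pinned in requirements.txt:

```python
# Sketch only: .click(...).then(...) chaining with None for unused inputs/outputs.
import gradio as gr

def generate(prompt):
    return f"Echo: {prompt}"

def save_history():
    print("history saved")

with gr.Blocks() as demo:
    textbox = gr.Textbox()
    output = gr.Textbox()
    button = gr.Button("Generate")
    # update the output first, then run a side effect that needs no inputs/outputs
    button.click(generate, textbox, output).then(lambda: save_history(), None, None)

demo.launch()
```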