Forráskód Böngészése

Merge branch 'main' into UsamaKenway-main

oobabooga 2 éve
szülő
commit
c6e9ba20a4

+ 2 - 0
README.md

@@ -215,6 +215,8 @@ Optionally, you can use the following command-line flags:
 | `--load-in-8bit`                            | Load the model with 8-bit precision.|
 | `--bf16`                                    | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
 | `--no-cache`                                | Set `use_cache` to False while generating text. This reduces the VRAM usage a bit with a performance cost. |
+| `--xformers`                                | Use xformer's memory efficient attention. This should increase your tokens/s. |
+| `--sdp-attention`                           | Use torch 2.0's sdp attention. |
 
 #### llama.cpp
 

+ 3 - 3
api-example.py

@@ -22,10 +22,10 @@ server = "127.0.0.1"
 params = {
     'max_new_tokens': 200,
     'do_sample': True,
-    'temperature': 0.5,
-    'top_p': 0.9,
+    'temperature': 0.72,
+    'top_p': 0.73,
     'typical_p': 1,
-    'repetition_penalty': 1.05,
+    'repetition_penalty': 1.1,
     'encoder_repetition_penalty': 1.0,
     'top_k': 0,
     'min_length': 0,

+ 115 - 101
download-model.py

@@ -19,6 +19,7 @@ import requests
 import tqdm
 from tqdm.contrib.concurrent import thread_map
 
+
 parser = argparse.ArgumentParser()
 parser.add_argument('MODEL', type=str, default=None, nargs='?')
 parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.')
@@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum
 args = parser.parse_args()
 
 
-def get_file(url, output_folder):
-    filename = Path(url.rsplit('/', 1)[1])
-    output_path = output_folder / filename
-    if output_path.exists() and not args.clean:
-        # Check if the file has already been downloaded completely
-        r = requests.get(url, stream=True)
-        total_size = int(r.headers.get('content-length', 0))
-        if output_path.stat().st_size >= total_size:
-            return
-        # Otherwise, resume the download from where it left off
-        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
-        mode = 'ab'
-    else:
-        headers = {}
-        mode = 'wb'
-
-    r = requests.get(url, stream=True, headers=headers)
-    with open(output_path, mode) as f:
-        total_size = int(r.headers.get('content-length', 0))
-        block_size = 1024
-        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
-            for data in r.iter_content(block_size):
-                t.update(len(data))
-                f.write(data)
-
-
-def sanitize_branch_name(branch_name):
-    pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
-    if pattern.match(branch_name):
-        return branch_name
-    else:
-        raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
-
-
 def select_model_from_default_options():
     models = {
         "OPT 6.7B": ("facebook", "opt-6.7b", "main"),
@@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped
     return model, branch
 
 
-def get_download_links_from_huggingface(model, branch):
+def sanitize_model_and_branch_names(model, branch):
+    if model[-1] == '/':
+        model = model[:-1]
+    if branch is None:
+        branch = "main"
+    else:
+        pattern = re.compile(r"^[a-zA-Z0-9._-]+$")
+        if not pattern.match(branch):
+            raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.")
+
+    return model, branch
+
+
+def get_download_links_from_huggingface(model, branch, text_only=False):
     base = "https://huggingface.co"
     page = f"/api/models/{model}/tree/{branch}?cursor="
     cursor = b""
@@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch):
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     classifications.append('text')
                     continue
-                if not args.text_only:
+                if not text_only:
                     links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
                     if is_safetensors:
                         has_safetensors = True
@@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch):
     return links, sha256, is_lora
 
 
-def download_files(file_list, output_folder, num_threads=8):
-    thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True)
-
-
-if __name__ == '__main__':
-    model = args.MODEL
-    branch = args.branch
-    if model is None:
-        model, branch = select_model_from_default_options()
-    else:
-        if model[-1] == '/':
-            model = model[:-1]
-            branch = args.branch
-        if branch is None:
-            branch = "main"
-        else:
-            try:
-                branch = sanitize_branch_name(branch)
-            except ValueError as err_branch:
-                print(f"Error: {err_branch}")
-                sys.exit()
-
-    links, sha256, is_lora = get_download_links_from_huggingface(model, branch)
-
-    if args.output is not None:
-        base_folder = args.output
-    else:
+def get_output_folder(model, branch, is_lora, base_folder=None):
+    if base_folder is None:
         base_folder = 'models' if not is_lora else 'loras'
 
     output_folder = f"{'_'.join(model.split('/')[-2:])}"
     if branch != 'main':
         output_folder += f'_{branch}'
     output_folder = Path(base_folder) / output_folder
+    return output_folder
+
+
+def get_single_file(url, output_folder, start_from_scratch=False):
+    filename = Path(url.rsplit('/', 1)[1])
+    output_path = output_folder / filename
+    if output_path.exists() and not start_from_scratch:
+        # Check if the file has already been downloaded completely
+        r = requests.get(url, stream=True)
+        total_size = int(r.headers.get('content-length', 0))
+        if output_path.stat().st_size >= total_size:
+            return
+        # Otherwise, resume the download from where it left off
+        headers = {'Range': f'bytes={output_path.stat().st_size}-'}
+        mode = 'ab'
+    else:
+        headers = {}
+        mode = 'wb'
+
+    r = requests.get(url, stream=True, headers=headers)
+    with open(output_path, mode) as f:
+        total_size = int(r.headers.get('content-length', 0))
+        block_size = 1024
+        with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t:
+            for data in r.iter_content(block_size):
+                t.update(len(data))
+                f.write(data)
 
-    if args.check:
-        # Validate the checksums
-        validated = True
-        for i in range(len(sha256)):
-            fpath = (output_folder / sha256[i][0])
 
-            if not fpath.exists():
-                print(f"The following file is missing: {fpath}")
+def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1):
+    thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True)
+
+
+def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1):
+    # Creating the folder and writing the metadata
+    if not output_folder.exists():
+        output_folder.mkdir()
+    with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
+        f.write(f'url: https://huggingface.co/{model}\n')
+        f.write(f'branch: {branch}\n')
+        f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
+        sha256_str = ''
+        for i in range(len(sha256)):
+            sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
+        if sha256_str != '':
+            f.write(f'sha256sum:\n{sha256_str}')
+
+    # Downloading the files
+    print(f"Downloading the model to {output_folder}")
+    start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads)
+
+
+def check_model_files(model, branch, links, sha256, output_folder):
+    # Validate the checksums
+    validated = True
+    for i in range(len(sha256)):
+        fpath = (output_folder / sha256[i][0])
+
+        if not fpath.exists():
+            print(f"The following file is missing: {fpath}")
+            validated = False
+            continue
+
+        with open(output_folder / sha256[i][0], "rb") as f:
+            bytes = f.read()
+            file_hash = hashlib.sha256(bytes).hexdigest()
+            if file_hash != sha256[i][1]:
+                print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
                 validated = False
-                continue
-
-            with open(output_folder / sha256[i][0], "rb") as f:
-                bytes = f.read()
-                file_hash = hashlib.sha256(bytes).hexdigest()
-                if file_hash != sha256[i][1]:
-                    print(f'Checksum failed: {sha256[i][0]}  {sha256[i][1]}')
-                    validated = False
-                else:
-                    print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
-
-        if validated:
-            print('[+] Validated checksums of all model files!')
-        else:
-            print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
+            else:
+                print(f'Checksum validated: {sha256[i][0]}  {sha256[i][1]}')
 
+    if validated:
+        print('[+] Validated checksums of all model files!')
     else:
+        print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.')
 
-        # Creating the folder and writing the metadata
-        if not output_folder.exists():
-            output_folder.mkdir()
-        with open(output_folder / 'huggingface-metadata.txt', 'w') as f:
-            f.write(f'url: https://huggingface.co/{model}\n')
-            f.write(f'branch: {branch}\n')
-            f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n')
-            sha256_str = ''
-            for i in range(len(sha256)):
-                sha256_str += f'    {sha256[i][1]} {sha256[i][0]}\n'
-            if sha256_str != '':
-                f.write(f'sha256sum:\n{sha256_str}')
-
-        # Downloading the files
-        print(f"Downloading the model to {output_folder}")
-        download_files(links, output_folder, args.threads)
+
+if __name__ == '__main__':
+    branch = args.branch
+    model = args.MODEL
+    if model is None:
+        model, branch = select_model_from_default_options()
+
+    # Cleaning up the model/branch names
+    try:
+        model, branch = sanitize_model_and_branch_names(model, branch)
+    except ValueError as err_branch:
+        print(f"Error: {err_branch}")
+        sys.exit()
+
+    # Getting the download links from Hugging Face
+    links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only)
+
+    # Getting the output folder
+    output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output)
+
+    if args.check:
+        # Check previously downloaded files
+        check_model_files(model, branch, links, sha256, output_folder)
+    else:
+        # Download files
+        download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)

+ 42 - 6
extensions/character_bias/script.py

@@ -1,8 +1,23 @@
 import gradio as gr
+import os
+
+# get the current directory of the script
+current_dir = os.path.dirname(os.path.abspath(__file__))
+
+# check if the bias_options.txt file exists, if not, create it
+bias_file = os.path.join(current_dir, "bias_options.txt")
+if not os.path.isfile(bias_file):
+    with open(bias_file, "w") as f:
+        f.write("*I am so happy*\n*I am so sad*\n*I am so excited*\n*I am so bored*\n*I am so angry*")
+
+# read bias options from the text file
+with open(bias_file, "r") as f:
+    bias_options = [line.strip() for line in f.readlines()]
 
 params = {
     "activate": True,
     "bias string": " *I am so happy*",
+    "use custom string": False,
 }
 
 
@@ -11,7 +26,6 @@ def input_modifier(string):
     This function is applied to your text inputs before
     they are fed into the model.
     """
-
     return string
 
 
@@ -19,7 +33,6 @@ def output_modifier(string):
     """
     This function is applied to the model outputs.
     """
-
     return string
 
 
@@ -29,9 +42,11 @@ def bot_prefix_modifier(string):
     the prefix text for the Bot and can be used to bias its
     behavior.
     """
-
     if params['activate']:
-        return f'{string} {params["bias string"].strip()} '
+        if params['use custom string']:
+            return f'{string} {params["custom string"].strip()} '
+        else:
+            return f'{string} {params["bias string"].strip()} '
     else:
         return string
 
@@ -39,8 +54,29 @@ def bot_prefix_modifier(string):
 def ui():
     # Gradio elements
     activate = gr.Checkbox(value=params['activate'], label='Activate character bias')
-    string = gr.Textbox(value=params["bias string"], label='Character bias')
+    dropdown_string = gr.Dropdown(choices=bias_options, value=params["bias string"], label='Character bias', info='To edit the options in this dropdown edit the "bias_options.txt" file')
+    use_custom_string = gr.Checkbox(value=False, label='Use custom bias textbox instead of dropdown')
+    custom_string = gr.Textbox(value="", placeholder="Enter custom bias string", label="Custom Character Bias", info='To use this textbox activate the checkbox above')
 
     # Event functions to update the parameters in the backend
-    string.change(lambda x: params.update({"bias string": x}), string, None)
+    def update_bias_string(x):
+        if x:
+            params.update({"bias string": x})
+        else:
+            params.update({"bias string": dropdown_string.get()})
+        return x
+
+    def update_custom_string(x):
+        params.update({"custom string": x})
+
+    dropdown_string.change(update_bias_string, dropdown_string, None)
+    custom_string.change(update_custom_string, custom_string, None)
     activate.change(lambda x: params.update({"activate": x}), activate, None)
+    use_custom_string.change(lambda x: params.update({"use custom string": x}), use_custom_string, None)
+
+    # Group elements together depending on the selected option
+    def bias_string_group():
+        if use_custom_string.value:
+            return gr.Group([use_custom_string, custom_string])
+        else:
+            return dropdown_string

+ 6 - 5
modules/GPTQ_loader.py

@@ -100,10 +100,10 @@ def load_quantized(model_name):
     found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
 
-    if len(found_pts) == 1:
-        pt_path = found_pts[0]
-    elif len(found_safetensors) == 1:
-        pt_path = found_safetensors[0]
+    if len(found_pts) > 0:
+        pt_path = found_pts[-1]
+    elif len(found_safetensors) > 0:
+        pt_path = found_safetensors[-1]
     else:
         if path_to_model.name.lower().startswith('llama-7b'):
             pt_model = f'llama-7b-{shared.args.wbits}bit'
@@ -119,13 +119,14 @@ def load_quantized(model_name):
         # Try to find the .safetensors or .pt both in the model dir and in the subfolder
         for path in [Path(p + ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
             if path.exists():
-                print(f"Found {path}")
                 pt_path = path
                 break
 
     if not pt_path:
         print("Could not find the quantized model in .pt or .safetensors format, exiting...")
         exit()
+    else:
+        print(f"Found the following quantized model: {pt_path}")
 
     # qwopqwop200's offload
     if model_type == 'llama' and shared.args.pre_layer:

+ 49 - 13
modules/chat.py

@@ -22,6 +22,7 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
     end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
     impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    _continue = kwargs['_continue'] if '_continue' in kwargs else False
     also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
     rows = [f"{context.strip()}\n"]
 
@@ -39,7 +40,10 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
 
     i = len(shared.history['internal']) - 1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
-        rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+        if _continue and i == len(shared.history['internal']) - 1:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
+        else:
+            rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
         string = shared.history['internal'][i][0]
         if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
             rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
@@ -48,6 +52,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat
     if impersonate:
         rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
         limit = 2
+    elif _continue:
+        limit = 3
     else:
         # Adding the user message
         user_input = fix_newlines(user_input)
@@ -99,7 +105,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     return reply, next_character_found
 
 
-def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False, _continue=False):
     if mode == 'instruct':
         stopping_strings = [f"\n{name1}", f"\n{name2}"]
     else:
@@ -107,6 +113,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     # Defining some variables
     cumulative_reply = ''
+    last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
     just_started = True
     name1_original = name1
     visible_text = custom_generate_chat_prompt = None
@@ -124,17 +131,22 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
 
     if visible_text is None:
         visible_text = text
-    text = apply_extensions(text, "input")
+    if not _continue:
+        text = apply_extensions(text, "input")
 
     # Generating the prompt
-    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
+    kwargs = {
+        'end_of_turn': end_of_turn,
+        'is_instruct': mode == 'instruct',
+        '_continue': _continue
+    }
     if custom_generate_chat_prompt is None:
         prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
         prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
 
     # Yield *Is typing...*
-    if not regenerate:
+    if not any((regenerate, _continue)):
         yield shared.history['visible'] + [[visible_text, shared.processing_message]]
 
     # Generate
@@ -154,11 +166,17 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu
                 return shared.history['visible']
             if just_started:
                 just_started = False
-                shared.history['internal'].append(['', ''])
-                shared.history['visible'].append(['', ''])
-
-            shared.history['internal'][-1] = [text, reply]
-            shared.history['visible'][-1] = [visible_text, visible_reply]
+                if not _continue:
+                    shared.history['internal'].append(['', ''])
+                    shared.history['visible'].append(['', ''])
+
+            if _continue:
+                sep = list(map(lambda x : ' ' if x[-1] != ' ' else '', last_reply))
+                shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
+                shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
+            else:
+                shared.history['internal'][-1] = [text, reply]
+                shared.history['visible'][-1] = [visible_text, visible_reply]
             if not shared.args.no_stream:
                 yield shared.history['visible']
             if next_character_found:
@@ -220,6 +238,16 @@ def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 
+def continue_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+    else:
+        # Yield ' ...'
+        yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], name1, name2, mode)
+        for history in chatbot_wrapper(shared.history['internal'][-1][0], generate_state, name1, name2, context, mode, end_of_turn, _continue=True):
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
+
+
 def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
@@ -256,6 +284,9 @@ def clear_chat_log(name1, name2, greeting, mode):
     if greeting != '':
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+    
+    # Save cleared logs
+    save_history(timestamp=False)
 
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
@@ -406,9 +437,14 @@ def load_character(character, name1, name2, mode):
 
     if Path(f'logs/{shared.character}_persistent.json').exists():
         load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
-    elif greeting != "":
-        shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
-        shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+    else:
+        # Insert greeting if it exists
+        if greeting != "":
+            shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
+            shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+        
+        # Create .json log files since they don't already exist
+        save_history(timestamp=False)
 
     return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
 

+ 176 - 0
modules/llama_attn_hijack.py

@@ -0,0 +1,176 @@
+import math
+import sys
+import torch
+import torch.nn as nn
+import transformers.models.llama.modeling_llama
+
+from typing import Optional
+from typing import Tuple
+
+import modules.shared as shared
+
+
+if shared.args.xformers:
+    try:
+        import xformers.ops
+    except Exception:
+        print("🔴 xformers not found! Please install it before trying to use it.", file=sys.stderr)
+
+
+def hijack_llama_attention():
+    if shared.args.xformers:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
+        print("Replaced attention with xformers_attention")
+    elif shared.args.sdp_attention:
+        transformers.models.llama.modeling_llama.LlamaAttention.forward = sdp_attention_forward
+        print("Replaced attention with sdp_attention")
+
+
+def xformers_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    #We only apply xformers optimizations if we don't need to output the whole attention matrix
+    if not output_attentions:
+        dtype = query_states.dtype
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+        
+        #This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
+        #We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
+        if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=None)
+        else:
+            # input and output should be of form (bsz, q_len, num_heads, head_dim)
+            attn_output = xformers.ops.memory_efficient_attention(query_states, key_states, value_states, attn_bias=xformers.ops.LowerTriangularMask())
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.transpose(1, 2)
+
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value
+
+
+def sdp_attention_forward(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[Tuple[torch.Tensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    bsz, q_len, _ = hidden_states.size()
+
+    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+    kv_seq_len = key_states.shape[-2]
+    if past_key_value is not None:
+        kv_seq_len += past_key_value[0].shape[-2]
+    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+    query_states, key_states = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+    # [bsz, nh, t, hd]
+
+    if past_key_value is not None:
+        # reuse k, v, self_attention
+        key_states = torch.cat([past_key_value[0], key_states], dim=2)
+        value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+    past_key_value = (key_states, value_states) if use_cache else None
+
+    #We only apply sdp attention if we don't need to output the whole attention matrix
+    if not output_attentions:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states, attn_mask=attention_mask, is_causal=False)
+        attn_weights = None
+    else:
+        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights + attention_mask
+            attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+        # upcast attention to fp32
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    attn_output = self.o_proj(attn_output)
+
+    return attn_output, attn_weights, past_key_value

+ 13 - 0
modules/models.py

@@ -14,6 +14,7 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
+from modules import llama_attn_hijack
 
 transformers.logging.set_verbosity_error()
 
@@ -169,11 +170,23 @@ def load_model(model_name):
 
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
+    # Hijack attention with xformers
+    if any((shared.args.xformers, shared.args.sdp_attention)):
+        llama_attn_hijack.hijack_llama_attention()
+
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     elif type(model) is transformers.LlamaForCausalLM:
         tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
+        # Leaving this here until the LLaMA tokenizer gets figured out.
+        # For some people this fixes things, for others it causes an error.
+        try:
+            tokenizer.eos_token_id = 2
+            tokenizer.bos_token_id = 1
+            tokenizer.pad_token_id = 0
+        except:
+            pass
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'

+ 2 - 0
modules/shared.py

@@ -98,6 +98,8 @@ parser.add_argument('--disk-cache-dir', type=str, default="cache", help='Directo
 parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
 parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
 parser.add_argument('--no-cache', action='store_true', help='Set use_cache to False while generating text. This reduces the VRAM usage a bit at a performance cost.')
+parser.add_argument('--xformers', action='store_true', help="Use xformer's memory efficient attention. This should increase your tokens/s.")
+parser.add_argument('--sdp-attention', action='store_true', help="Use torch 2.0's sdp attention.")
 
 # llama.cpp
 parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')

+ 3 - 1
requirements.txt

@@ -1,5 +1,4 @@
 accelerate==0.18.0
-bitsandbytes==0.37.2
 datasets
 flexgen==0.1.7
 gradio==3.24.1
@@ -14,3 +13,6 @@ sentencepiece
 pyyaml
 tqdm
 git+https://github.com/huggingface/transformers
+bitsandbytes==0.37.2; platform_system != "Windows"
+llama-cpp-python==0.1.30; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.30/llama_cpp_python-0.1.30-cp310-cp310-win_amd64.whl; platform_system == "Windows"

+ 23 - 18
server.py

@@ -394,8 +394,9 @@ def create_interface():
                     shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
                     shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
                 with gr.Row():
-                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                     shared.gradio['Regenerate'] = gr.Button('Regenerate')
+                    shared.gradio['Continue'] = gr.Button('Continue')
+                    shared.gradio['Impersonate'] = gr.Button('Impersonate')
                 with gr.Row():
                     shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
                     shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
@@ -467,53 +468,57 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['textbox'].submit(
                 lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
                 chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             gen_events.append(shared.gradio['Regenerate'].click(
                 chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
-                lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
+            )
+
+            gen_events.append(shared.gradio['Continue'].click(
+                chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
             )
 
             shared.gradio['Replace last reply'].click(
                 chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
                 lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
 
             shared.gradio['Clear history-confirm'].click(
                 lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
                 chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
-                lambda: chat.save_history(timestamp=False), [], [], show_progress=False)
+                lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
 
             shared.gradio['Stop'].click(
-                stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['Chat mode'].change(
                 lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['Instruction templates'].change(
                 lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             shared.gradio['upload_chat_history'].upload(
-                chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []).then(
-                chat.redraw_html, reload_inputs, [shared.gradio['display']])
+                chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
+                chat.redraw_html, reload_inputs, shared.gradio['display'])
 
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
-            shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
+            shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
             shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
             shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
             shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
-            shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
+            shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']])
             shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
             shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
@@ -521,7 +526,7 @@ def create_interface():
 
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
             shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
-            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
+            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
 
         elif shared.args.notebook:
             with gr.Tab("Text generation", elem_id="main"):
@@ -555,7 +560,7 @@ def create_interface():
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
+            shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         else:
@@ -589,7 +594,7 @@ def create_interface():
             gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
-            shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
+            shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
         with gr.Tab("Model", elem_id="model-tab"):