
Merge pull request #366 from oobabooga/lora

Add LoRA support
oobabooga 2 years ago
parent commit 3bda907727
10 changed files with 82 additions and 12 deletions
  1. css/main.css (+10 -1)
  2. download-model.py (+11 -6)
  3. loras/place-your-loras-here.txt (+0 -0)
  4. modules/LoRA.py (+17 -0)
  5. modules/callbacks.py (+1 -0)
  6. modules/chat.py (+2 -1)
  7. modules/shared.py (+7 -1)
  8. requirements.txt (+1 -0)
  9. server.py (+28 -1)
 10. settings-template.json (+5 -2)

+ 10 - 1
css/main.css

@@ -1,12 +1,15 @@
 .tabs.svelte-710i53 {
     margin-top: 0
 }
+
 .py-6 {
     padding-top: 2.5rem
 }
+
 .dark #refresh-button {
     background-color: #ffffff1f;
 }
+
 #refresh-button {
   flex: none;
   margin: 0;
@@ -17,22 +20,28 @@
   border-radius: 10px;
   background-color: #0000000d;
 }
+
 #download-label, #upload-label {
   min-height: 0
 }
+
 #accordion {
 }
+
 .dark svg {
   fill: white;
 }
+
 svg {
   display: unset !important;
   vertical-align: middle !important;
   margin: 5px;
 }
+
 ol li p, ul li p {
     display: inline-block;
 }
-#main, #parameters, #chat-settings, #interface-mode {
+
+#main, #parameters, #chat-settings, #interface-mode, #lora {
   border: 0;
 }

+ 11 - 6
download-model.py

@@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
     classifications = []
     has_pytorch = False
     has_safetensors = False
+    is_lora = False
     while True:
         content = requests.get(f"{base}{page}{cursor.decode()}").content

@@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
 
 
         for i in range(len(dict)):
             fname = dict[i]['path']
+            if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
+                is_lora = True

-            is_pytorch = re.match("pytorch_model.*\.bin", fname)
+            is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
             is_safetensors = re.match("model.*\.safetensors", fname)
             is_tokenizer = re.match("tokenizer.*\.model", fname)
             is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
                         has_pytorch = True
                         classifications.append('pytorch')

+
         cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
         cursor = base64.b64encode(cursor)
         cursor = cursor.replace(b'=', b'%3D')
@@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
             if classifications[i] == 'pytorch':
                 links.pop(i)

-    return links
+    return links, is_lora

 if __name__ == '__main__':
     model = args.MODEL
@@ -159,15 +163,16 @@ if __name__ == '__main__':
             except ValueError as err_branch:
                 print(f"Error: {err_branch}")
                 sys.exit()
+
+    links, is_lora = get_download_links_from_huggingface(model, branch)
+    base_folder = 'models' if not is_lora else 'loras'
     if branch != 'main':
-        output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
+        output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
     else:
-        output_folder = Path("models") / model.split('/')[-1]
+        output_folder = Path(base_folder) / model.split('/')[-1]
     if not output_folder.exists():
         output_folder.mkdir()

-    links = get_download_links_from_huggingface(model, branch)
-
     # Downloading the files
     print(f"Downloading the model to {output_folder}")
     pool = multiprocessing.Pool(processes=args.threads)
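
For reference, a minimal, self-contained sketch of the detection logic introduced above: the presence of adapter_config.json or adapter_model.bin marks the repository as a LoRA, and the download target switches from models/ to loras/. The file list below is hypothetical; the real script builds it from the Hugging Face API pages.

    import re
    from pathlib import Path

    def classify(filenames):
        # Mirrors the new is_lora flag: any PEFT adapter file marks the repo as a LoRA.
        is_lora = any(f.endswith(('adapter_config.json', 'adapter_model.bin')) for f in filenames)
        # The weight regex now accepts both pytorch_model*.bin and adapter_model*.bin.
        weights = [f for f in filenames if re.match(r"(pytorch|adapter)_model.*\.bin", f)]
        return weights, is_lora

    weights, is_lora = classify(['adapter_config.json', 'adapter_model.bin'])
    base_folder = 'models' if not is_lora else 'loras'
    print(Path(base_folder))  # -> loras, so the adapter lands next to place-your-loras-here.txt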

+ 0 - 0
loras/place-your-loras-here.txt


+ 17 - 0
modules/LoRA.py

@@ -0,0 +1,17 @@
+from pathlib import Path
+
+from peft import PeftModel
+
+import modules.shared as shared
+from modules.models import load_model
+
+
+def add_lora_to_model(lora_name):
+
+    # Is there a more efficient way of returning to the base model?
+    if lora_name == "None":
+        print("Reloading the model to remove the LoRA...")
+        shared.model, shared.tokenizer = load_model(shared.model_name)
+    else:
+        print(f"Adding the LoRA {lora_name} to the model...")
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"))
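
A rough usage sketch of the new module (it assumes a model has already been loaded into shared.model/shared.tokenizer and that a folder such as loras/alpaca-lora-7b exists; the adapter name is only an example):

    import modules.shared as shared
    from modules.LoRA import add_lora_to_model

    # Wraps the loaded model in a peft.PeftModel using the adapter under loras/<name>.
    add_lora_to_model('alpaca-lora-7b')

    # Passing "None" reloads the base model from scratch to drop the adapter, since the
    # module does not attempt to unload the PeftModel wrapper in place.
    add_lora_to_model('None')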

+ 1 - 0
modules/callbacks.py

@@ -7,6 +7,7 @@ import transformers
 
 
 import modules.shared as shared

+
 # Copied from https://github.com/PygmalionAI/gradio-ui/
 class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):


+ 2 - 1
modules/chat.py

@@ -12,7 +12,8 @@ import modules.extensions as extensions_module
 import modules.shared as shared
 from modules.extensions import apply_extensions
 from modules.html_generator import generate_chat_html
-from modules.text_generation import encode, generate_reply, get_max_prompt_length
+from modules.text_generation import (encode, generate_reply,
+                                     get_max_prompt_length)


 # This gets the new line characters right.

+ 7 - 1
modules/shared.py

@@ -2,7 +2,8 @@ import argparse
 
 
 model = None
 tokenizer = None
-model_name = ""
+model_name = "None"
+lora_name = "None"
 soft_prompt_tensor = None
 soft_prompt = False
 is_RWKV = False
@@ -52,6 +53,10 @@ settings = {
         '^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
         '(rosey|chip|joi)_.*_instruct.*': 'User: \n',
         'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
+    },
+    'lora_prompts': {
+        'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
+        'alpaca-lora-7b': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }

@@ -67,6 +72,7 @@ def str2bool(v):
 
 
 parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
+parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
 parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
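
The new lora_prompts dictionary is consumed in server.py with the same regex-keyed lookup already used for presets and prompts: the first key that matches the LoRA name as a case-insensitive regex wins, otherwise 'default'. A standalone sketch of that idiom, with an invented LoRA name for the fallback case:

    import re

    lora_prompts = {
        'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
        'alpaca-lora-7b': 'Below is an instruction that describes a task. ...',
    }

    def pick_prompt(lora_name):
        key = next((k for k in lora_prompts if re.match(k.lower(), lora_name.lower())), 'default')
        return lora_prompts[key]

    print(pick_prompt('alpaca-lora-7b'))  # matches the specific key
    print(pick_prompt('my-own-lora'))     # nothing matches, falls back to 'default'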

+ 1 - 0
requirements.txt

@@ -4,6 +4,7 @@ flexgen==0.1.7
 gradio==3.18.0
 markdown
 numpy
+peft==0.2.0
 requests
 rwkv==0.4.2
 safetensors==0.3.0
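
The peft pin provides the PeftModel class imported by modules/LoRA.py; a quick sanity check after installing the requirements (assuming peft exposes __version__, as its releases do):

    import peft
    from peft import PeftModel  # class used by modules/LoRA.py

    print(peft.__version__)  # expected to print 0.2.0 with the pinned version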

+ 28 - 1
server.py

@@ -15,6 +15,7 @@ import modules.extensions as extensions_module
 import modules.shared as shared
 import modules.ui as ui
 from modules.html_generator import generate_chat_html
+from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import generate_reply

@@ -48,6 +49,9 @@ def get_available_extensions():
 def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

+def get_available_loras():
+    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+
 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
         shared.model_name = selected_model
@@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):
 
 
     return selected_model

+def load_lora_wrapper(selected_lora):
+    shared.lora_name = selected_lora
+    default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+
+    if not shared.args.cpu:
+        gc.collect()
+        torch.cuda.empty_cache()
+    add_lora_to_model(selected_lora)
+
+    return selected_lora, default_text
+
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
         'do_sample': True,
@@ -145,6 +160,10 @@ def create_settings_menus(default_preset):
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

+    with gr.Row():
+        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+
     with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
             shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
@@ -156,6 +175,7 @@ def create_settings_menus(default_preset):
 
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
     shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
+    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])

@@ -181,6 +201,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_softprompts = get_available_softprompts()
+available_loras = get_available_loras()

 # Default extensions
 extensions_module.available_extensions = get_available_extensions()
@@ -213,10 +234,16 @@ else:
         print()
     shared.model_name = available_models[i]
 shared.model, shared.tokenizer = load_model(shared.model_name)
+if shared.args.lora:
+    print(shared.args.lora)
+    shared.lora_name = shared.args.lora
+    add_lora_to_model(shared.lora_name)

 # Default UI settings
 default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
-default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
+default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
+if default_text == '':
+    default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
 suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
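
Taken together: download-model.py now routes adapters into loras/, server.py applies one at startup via the new --lora flag (or later from the LoRA dropdown), and the default textbox prompt is picked from 'lora_prompts', falling back to the per-model 'prompts' table when the entry is empty. An illustrative invocation (repository and folder names are examples, not prescribed by this PR):

    # python download-model.py tloen/alpaca-lora-7b        -> saved under loras/alpaca-lora-7b
    # python server.py --model llama-7b --lora alpaca-lora-7b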

+ 5 - 2
settings-template.json

@@ -23,13 +23,16 @@
     "presets": {
         "default": "NovelAI-Sphinx Moth",
         "pygmalion-*": "Pygmalion",
-        "RWKV-*": "Naive",
-        "(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
+        "RWKV-*": "Naive"
     },
     "prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
         "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+    },
+    "lora_prompts": {
+        "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
+        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }