Sfoglia il codice sorgente

Adapt to the new model names

oobabooga 2 anni fa
parent
commit
1cb9246160
6 ha cambiato i file con 18 aggiunte e 25 eliminazioni
  1. 4 3
      modules/GPTQ_loader.py
  2. 2 2
      modules/models.py
  3. 0 4
      modules/shared.py
  4. 3 3
      modules/text_generation.py
  5. 6 7
      server.py
  6. 3 6
      settings-template.json

+ 4 - 3
modules/GPTQ_loader.py

@@ -51,11 +51,12 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
 def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
-        if model_name.lower().startswith(('llama', 'alpaca')):
+        name = model_name.lower()
+        if any((k in name for k in ['llama', 'alpaca'])):
             model_type = 'llama'
-        elif model_name.lower().startswith(('opt', 'galactica')):
+        elif any((k in name for k in ['opt-', 'galactica'])):
             model_type = 'opt'
-        elif model_name.lower().startswith(('gpt-j', 'pygmalion-6b')):
+        elif any((k in name for k in ['gpt-j', 'pygmalion-6b'])):
             model_type = 'gptj'
         else:
             print("Can't determine model type from model name. Please specify it manually using --model_type "

+ 2 - 2
modules/models.py

@@ -41,7 +41,7 @@ def load_model(model_name):
     print(f"Loading {model_name}...")
     t0 = time.time()
 
-    shared.is_RWKV = model_name.lower().startswith('rwkv-')
+    shared.is_RWKV = 'rwkv-' in model_name.lower()
 
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
@@ -159,7 +159,7 @@ def load_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
 
     # Loading the tokenizer
-    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+    if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))

+ 0 - 4
modules/shared.py

@@ -37,10 +37,6 @@ settings = {
     'chat_generation_attempts': 1,
     'chat_generation_attempts_min': 1,
     'chat_generation_attempts_max': 5,
-    'name1_pygmalion': 'You',
-    'name2_pygmalion': 'Kawaii',
-    'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
-    'stop_at_newline_pygmalion': False,
     'default_extensions': [],
     'chat_default_extensions': ["gallery"],
     'presets': {

+ 3 - 3
modules/text_generation.py

@@ -42,7 +42,7 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
 
 def decode(output_ids):
     # Open Assistant relies on special tokens like <|endoftext|>
-    if re.match('(oasst|galactica)-*', shared.model_name.lower()):
+    if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
         return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
     else:
         reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -77,10 +77,10 @@ def fix_galactica(s):
 
 def formatted_outputs(reply, model_name):
     if not (shared.args.chat or shared.args.cai_chat):
-        if model_name.lower().startswith('galactica'):
+        if 'galactica' in model_name.lower():
             reply = fix_galactica(reply)
             return reply, reply, generate_basic_html(reply)
-        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
+        elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
             reply = fix_gpt4chan(reply)
             return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
         else:

+ 6 - 7
server.py

@@ -282,7 +282,6 @@ else:
     default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
 title ='Text generation web UI'
 description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
-suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
 
 def create_interface():
 
@@ -294,7 +293,7 @@ def create_interface():
         if shared.args.chat or shared.args.cai_chat:
             with gr.Tab("Text generation", elem_id="main"):
                 if shared.args.cai_chat:
-                    shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
+                    shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
                 else:
                     shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
                 shared.gradio['textbox'] = gr.Textbox(label='Input')
@@ -314,9 +313,9 @@ def create_interface():
                     shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
 
             with gr.Tab("Character", elem_id="chat-settings"):
-                shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
-                shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
-                shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
+                shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+                shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Bot\'s name')
+                shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=5, label='Context')
                 with gr.Row():
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                     ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -354,7 +353,7 @@ def create_interface():
                             shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
                         with gr.Column():
                             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-                            shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
+                            shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
 
                 create_settings_menus(default_preset)
 
@@ -401,7 +400,7 @@ def create_interface():
             shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
 
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
-            shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
+            shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
             shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
 
         elif shared.args.notebook:

+ 3 - 6
settings-template.json

@@ -12,10 +12,6 @@
     "chat_generation_attempts": 1,
     "chat_generation_attempts_min": 1,
     "chat_generation_attempts_max": 5,
-    "name1_pygmalion": "You",
-    "name2_pygmalion": "Kawaii",
-    "context_pygmalion": "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
-    "stop_at_newline_pygmalion": false,
     "default_extensions": [],
     "chat_default_extensions": [
         "gallery"
@@ -29,10 +25,11 @@
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
         "^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
         "(rosey|chip|joi)_.*_instruct.*": "User: \n",
-        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
+        "oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>",
+        "alpaca-*": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     },
     "lora_prompts": {
         "default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
-        "alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
+        "(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
     }
 }