Ver Fonte

Add softprompt support (for real this time)

Is this too much voodoo for our purposes?
oobabooga há 3 anos atrás
pai
commit
3277b751f5

+ 2 - 2
extensions/softprompt/script.py → extensions/character_bias/script.py

@@ -1,5 +1,5 @@
 params = {
-    "soft prompt": " *I speak in an annoyingly cute way*",
+    "bias string": " *I speak in an annoyingly cute way*",
 }
 
 def input_modifier(string):
@@ -24,4 +24,4 @@ def bot_prefix_modifier(string):
     behavior.
     """
 
-    return string + params["soft prompt"]
+    return string + params["bias string"]

+ 3 - 2
requirements.txt

@@ -1,6 +1,7 @@
 accelerate==0.15.0
+beautifulsoup4
 bitsandbytes==0.37.0
 gradio==3.15.0
-transformers==4.25.1
+numpy
 safetensors==0.2.8
-beautifulsoup4
+git+https://github.com/huggingface/transformers

+ 86 - 9
server.py

@@ -10,10 +10,12 @@ import re
 import sys
 import time
 import warnings
+import zipfile
 from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
+import numpy as np
 import torch
 import transformers
 from PIL import Image
@@ -157,6 +159,37 @@ def load_model(model_name):
     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
     return model, tokenizer
 
+def load_soft_prompt(name):
+    global soft_prompt, soft_prompt_tensor
+
+    if name == 'None':
+        soft_prompt = False
+        soft_prompt_tensor = None
+    else:
+        with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
+            zf.extract('tensor.npy')
+            tensor = np.load('tensor.npy')
+        tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
+        tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
+
+        soft_prompt = True
+        soft_prompt_tensor = tensor
+
+    return name
+
+def upload_softprompt_event(file):
+    with zipfile.ZipFile(io.BytesIO(file)) as zf:
+        zf.extract('meta.json')
+        j = json.loads(open('meta.json', 'r').read())
+        name = j['name']
+
+    with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
+        f.write(file)
+
+    load_soft_prompt(name)
+
+    return name
+
 def load_model_wrapper(selected_model):
     global model_name, model, tokenizer
 
@@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
         return reply
 
 def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
-    global model_name, model, tokenizer
+    global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
 
     original_question = question
     if not (args.chat or args.cai_chat):
@@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
     else:
         generate_params.append(f"max_new_tokens=8")
 
+    if soft_prompt:
+        inputs_embeds = model.transformer.wte(input_ids)
+        inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+        filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+        filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+        generate_params.insert(0, "inputs_embeds=inputs_embeds")
+        generate_params.insert(0, "filler_input_ids")
+    else:
+        filler_input_ids = None
+        generate_params.insert(0, "input_ids")
+
     # Generate the entire reply at once
     if args.no_stream:
         t0 = time.time()
         with torch.no_grad():
-            output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-        reply = decode(output[0])
+            output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+        if soft_prompt:
+            output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+        else:
+            output = output[0]
+        reply = decode(output)
         t1 = time.time()
-        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
+        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
         if not (args.chat or args.cai_chat):
             reply = original_question + apply_extensions(reply[len(question):], "output")
         yield formatted_outputs(reply, model_name)
@@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
         yield formatted_outputs(original_question, model_name)
         for i in tqdm(range(tokens//8+1)):
             with torch.no_grad():
-                output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
-            reply = decode(output[0])
+                output = eval(f"model.generate({','.join(generate_params)}){cuda}")
+
+            if soft_prompt:
+                output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
+            else:
+                output = output[0]
+
+            reply = decode(output)
             if not (args.chat or args.cai_chat):
                 reply = original_question + apply_extensions(reply[len(question):], "output")
             yield formatted_outputs(reply, model_name)
-            input_ids = output
-            if output[0][-1] == n:
+
+            input_ids = torch.reshape(output, (1, output.shape[0]))
+            if soft_prompt:
+                inputs_embeds = model.transformer.wte(input_ids)
+                inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
+                filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
+                filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
+
+            if output[-1] == n:
                 break
 
 def apply_extensions(text, typ):
@@ -353,6 +414,9 @@ def get_available_characters():
 def get_available_extensions():
     return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
 
+def get_available_softprompts():
+    return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
+
 def create_extensions_block():
     extensions_ui_elements = []
     default_values = []
@@ -410,8 +474,19 @@ def create_settings_menus():
                 min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                 early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
+    with gr.Accordion("Soft prompt", open=False):
+        with gr.Row():
+            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
+            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
+
+        gr.Markdown('Upload a soft prompt:')
+        with gr.Row():
+            upload_softprompt = gr.File(type='binary')
+
     model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
+    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
+    upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
 
 # This gets the new line characters right.
@@ -718,6 +793,7 @@ available_models = get_available_models()
 available_presets = get_available_presets()
 available_characters = get_available_characters()
 available_extensions = get_available_extensions()
+available_softprompts = get_available_softprompts()
 extension_state = {}
 if args.extensions is not None:
     for i,ext in enumerate(args.extensions.split(',')):
@@ -746,7 +822,8 @@ else:
         print()
     model_name = available_models[i]
 model, tokenizer = load_model(model_name)
-loaded_preset = None
+loaded_preset = soft_prompt_tensor = None
+soft_prompt = False
 
 # UI settings
 if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):

+ 0 - 0
softprompts/place-your-softprompts-here.txt