|
|
@@ -10,10 +10,12 @@ import re
|
|
|
import sys
|
|
|
import time
|
|
|
import warnings
|
|
|
+import zipfile
|
|
|
from datetime import datetime
|
|
|
from pathlib import Path
|
|
|
|
|
|
import gradio as gr
|
|
|
+import numpy as np
|
|
|
import torch
|
|
|
import transformers
|
|
|
from PIL import Image
|
|
|
@@ -157,6 +159,37 @@ def load_model(model_name):
|
|
|
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
|
|
return model, tokenizer
|
|
|
|
|
|
+def load_soft_prompt(name):
|
|
|
+ global soft_prompt, soft_prompt_tensor
|
|
|
+
|
|
|
+ if name == 'None':
|
|
|
+ soft_prompt = False
|
|
|
+ soft_prompt_tensor = None
|
|
|
+ else:
|
|
|
+ with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
|
|
|
+ zf.extract('tensor.npy')
|
|
|
+ tensor = np.load('tensor.npy')
|
|
|
+ tensor = torch.Tensor(tensor).to(device=model.device, dtype=model.dtype)
|
|
|
+ tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
|
|
|
+
|
|
|
+ soft_prompt = True
|
|
|
+ soft_prompt_tensor = tensor
|
|
|
+
|
|
|
+ return name
|
|
|
+
|
|
|
+def upload_softprompt_event(file):
|
|
|
+ with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
|
|
+ zf.extract('meta.json')
|
|
|
+ j = json.loads(open('meta.json', 'r').read())
|
|
|
+ name = j['name']
|
|
|
+
|
|
|
+ with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
|
|
|
+ f.write(file)
|
|
|
+
|
|
|
+ load_soft_prompt(name)
|
|
|
+
|
|
|
+ return name
|
|
|
+
|
|
|
def load_model_wrapper(selected_model):
|
|
|
global model_name, model, tokenizer
|
|
|
|
|
|
@@ -244,7 +277,7 @@ def formatted_outputs(reply, model_name):
|
|
|
return reply
|
|
|
|
|
|
def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
|
|
- global model_name, model, tokenizer
|
|
|
+ global model_name, model, tokenizer, soft_prompt, soft_prompt_tensor
|
|
|
|
|
|
original_question = question
|
|
|
if not (args.chat or args.cai_chat):
|
|
|
@@ -292,14 +325,29 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
|
|
else:
|
|
|
generate_params.append(f"max_new_tokens=8")
|
|
|
|
|
|
+ if soft_prompt:
|
|
|
+ inputs_embeds = model.transformer.wte(input_ids)
|
|
|
+ inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
|
|
+ filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
|
|
+ filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
|
|
+ generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
|
|
+ generate_params.insert(0, "filler_input_ids")
|
|
|
+ else:
|
|
|
+ filler_input_ids = None
|
|
|
+ generate_params.insert(0, "input_ids")
|
|
|
+
|
|
|
# Generate the entire reply at once
|
|
|
if args.no_stream:
|
|
|
t0 = time.time()
|
|
|
with torch.no_grad():
|
|
|
- output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
|
|
|
- reply = decode(output[0])
|
|
|
+ output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
|
|
+ if soft_prompt:
|
|
|
+ output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
|
|
+ else:
|
|
|
+ output = output[0]
|
|
|
+ reply = decode(output)
|
|
|
t1 = time.time()
|
|
|
- print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output[0])-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output[0])-len(input_ids[0])} tokens)")
|
|
|
+ print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
|
|
|
if not (args.chat or args.cai_chat):
|
|
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
|
|
yield formatted_outputs(reply, model_name)
|
|
|
@@ -309,13 +357,26 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
|
|
|
yield formatted_outputs(original_question, model_name)
|
|
|
for i in tqdm(range(tokens//8+1)):
|
|
|
with torch.no_grad():
|
|
|
- output = eval(f"model.generate(input_ids, {','.join(generate_params)}){cuda}")
|
|
|
- reply = decode(output[0])
|
|
|
+ output = eval(f"model.generate({','.join(generate_params)}){cuda}")
|
|
|
+
|
|
|
+ if soft_prompt:
|
|
|
+ output = torch.cat((input_ids[0], output[0][filler_input_ids.shape[1]:]))
|
|
|
+ else:
|
|
|
+ output = output[0]
|
|
|
+
|
|
|
+ reply = decode(output)
|
|
|
if not (args.chat or args.cai_chat):
|
|
|
reply = original_question + apply_extensions(reply[len(question):], "output")
|
|
|
yield formatted_outputs(reply, model_name)
|
|
|
- input_ids = output
|
|
|
- if output[0][-1] == n:
|
|
|
+
|
|
|
+ input_ids = torch.reshape(output, (1, output.shape[0]))
|
|
|
+ if soft_prompt:
|
|
|
+ inputs_embeds = model.transformer.wte(input_ids)
|
|
|
+ inputs_embeds = torch.cat((soft_prompt_tensor, inputs_embeds), dim=1)
|
|
|
+ filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(model.device)
|
|
|
+ filler_input_ids += model.config.bos_token_id # setting dummy input_ids to bos tokens
|
|
|
+
|
|
|
+ if output[-1] == n:
|
|
|
break
|
|
|
|
|
|
def apply_extensions(text, typ):
|
|
|
@@ -353,6 +414,9 @@ def get_available_characters():
|
|
|
def get_available_extensions():
|
|
|
return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
|
|
|
|
|
|
+def get_available_softprompts():
|
|
|
+ return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
|
|
+
|
|
|
def create_extensions_block():
|
|
|
extensions_ui_elements = []
|
|
|
default_values = []
|
|
|
@@ -410,8 +474,19 @@ def create_settings_menus():
|
|
|
min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
|
|
|
early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
|
|
|
|
|
|
+ with gr.Accordion("Soft prompt", open=False):
|
|
|
+ with gr.Row():
|
|
|
+ softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
|
|
|
+ create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
|
|
|
+
|
|
|
+ gr.Markdown('Upload a soft prompt:')
|
|
|
+ with gr.Row():
|
|
|
+ upload_softprompt = gr.File(type='binary')
|
|
|
+
|
|
|
model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
|
|
|
preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
|
|
|
+ softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
|
|
|
+ upload_softprompt.upload(upload_softprompt_event, [upload_softprompt], [softprompts_menu])
|
|
|
return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
|
|
|
|
|
|
# This gets the new line characters right.
|
|
|
@@ -718,6 +793,7 @@ available_models = get_available_models()
|
|
|
available_presets = get_available_presets()
|
|
|
available_characters = get_available_characters()
|
|
|
available_extensions = get_available_extensions()
|
|
|
+available_softprompts = get_available_softprompts()
|
|
|
extension_state = {}
|
|
|
if args.extensions is not None:
|
|
|
for i,ext in enumerate(args.extensions.split(',')):
|
|
|
@@ -746,7 +822,8 @@ else:
|
|
|
print()
|
|
|
model_name = available_models[i]
|
|
|
model, tokenizer = load_model(model_name)
|
|
|
-loaded_preset = None
|
|
|
+loaded_preset = soft_prompt_tensor = None
|
|
|
+soft_prompt = False
|
|
|
|
|
|
# UI settings
|
|
|
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
|