Bladeren bron

Add prompt loading/saving menus + reorganize interface

oobabooga 2 jaren geleden
bovenliggende
commit
57345b8f30
4 gewijzigde bestanden met 63 toevoegingen en 15 verwijderingen
  1. 6 0
      prompts/Alpaca.txt
  2. 1 0
      prompts/Open Assistant.txt
  3. 4 0
      prompts/QA.txt
  4. 52 15
      server.py

+ 6 - 0
prompts/Alpaca.txt

@@ -0,0 +1,6 @@
+Below is an instruction that describes a task. Write a response that appropriately completes the request.
+### Instruction:
+Write a poem about the transformers Python library. 
+Mention the word "large language models" in that poem.
+### Response:
+

+ 1 - 0
prompts/Open Assistant.txt

@@ -0,0 +1 @@
+<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>

+ 4 - 0
prompts/QA.txt

@@ -0,0 +1,4 @@
+Common sense questions and answers
+
+Question: 
+Factual answer:

+ 52 - 15
server.py

@@ -4,6 +4,7 @@ import re
 import sys
 import time
 import zipfile
+from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
@@ -38,6 +39,13 @@ def get_available_models():
 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
 
+def get_available_prompts():
+    prompts = []
+    prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
+    prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
+    prompts += ['None']
+    return prompts
+
 def get_available_characters():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
 
@@ -98,7 +106,7 @@ def load_preset_values(preset_menu, return_dict=False):
     if return_dict:
         return generate_params
     else:
-        return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
 
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -123,11 +131,45 @@ def create_model_and_preset_menus():
                 shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                 ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
 
+def save_prompt(text):
+    fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
+    with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
+        f.write(text)
+    return f"Saved prompt to prompts/{fname}"
+
+def load_prompt(fname):
+    if fname in ['None', '']:
+        return ''
+    else:
+        with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
+            return f.read()
+        
+def create_prompt_menus():
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
+                ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
+
+        with gr.Column():
+            with gr.Column():
+                shared.gradio['save_prompt'] = gr.Button('Save prompt')
+                shared.gradio['status'] = gr.Markdown('Ready')
+
+    shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
+    shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
+
 def create_settings_menus(default_preset):
     generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
 
     with gr.Row():
         with gr.Column():
+            create_model_and_preset_menus()
+        with gr.Column():
+            shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
+
+    with gr.Row():
+        with gr.Column():
             with gr.Box():
                 gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
                 with gr.Row():
@@ -156,12 +198,6 @@ def create_settings_menus(default_preset):
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
-            shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')
-
-    with gr.Row():
-        shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
-        ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
-
     with gr.Row():
         shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
         ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
@@ -176,8 +212,7 @@ def create_settings_menus(default_preset):
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
     shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
-    shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
     shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
     shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
     shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
@@ -265,8 +300,8 @@ def create_interface():
                     shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
                 shared.gradio['textbox'] = gr.Textbox(label='Input')
                 with gr.Row():
-                    shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
                     shared.gradio['Generate'] = gr.Button('Generate')
+                    shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
                 with gr.Row():
                     shared.gradio['Impersonate'] = gr.Button('Impersonate')
                     shared.gradio['Regenerate'] = gr.Button('Regenerate')
@@ -279,8 +314,6 @@ def create_interface():
                     shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
                     shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
 
-                create_model_and_preset_menus()
-
             with gr.Tab("Character", elem_id="chat-settings"):
                 shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
                 shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
@@ -384,12 +417,15 @@ def create_interface():
                             shared.gradio['html'] = gr.HTML()
 
                         with gr.Row():
-                            shared.gradio['Stop'] = gr.Button('Stop')
                             shared.gradio['Generate'] = gr.Button('Generate')
+                            shared.gradio['Stop'] = gr.Button('Stop')
+
                     with gr.Column(scale=1):
+                        gr.Markdown("\n")
                         shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
 
-                        create_model_and_preset_menus()
+                        create_prompt_menus()
+
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)
 
@@ -413,7 +449,7 @@ def create_interface():
                             with gr.Column():
                                 shared.gradio['Stop'] = gr.Button('Stop')
 
-                        create_model_and_preset_menus()
+                        create_prompt_menus()
 
                     with gr.Column():
                         with gr.Tab('Raw'):
@@ -422,6 +458,7 @@ def create_interface():
                             shared.gradio['markdown'] = gr.Markdown()
                         with gr.Tab('HTML'):
                             shared.gradio['html'] = gr.HTML()
+
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)