From 165d757444096d9323c96afc46b41e6d0790e815 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" Date: Tue, 4 Apr 2023 08:25:11 -0700 Subject: [PATCH 01/45] improve the example character yaml format - use multiline blocks multiline blocks make the input much cleaner and simpler, particularly for the example_dialogue. For the greeting block it can kinda go either way but I think it still ends up nicer. Also double quotes in context fixes the need to escape the singlequote inside. --- characters/Example.yaml | 47 +++++++++++++---------------------------- 1 file changed, 15 insertions(+), 32 deletions(-) diff --git a/characters/Example.yaml b/characters/Example.yaml index 948dece..dde83fd 100644 --- a/characters/Example.yaml +++ b/characters/Example.yaml @@ -1,32 +1,15 @@ -name: Chiharu Yamada -context: 'Chiharu Yamada''s Persona: Chiharu Yamada is a young, computer engineer-nerd - with a knack for problem solving and a passion for technology.' -greeting: '*Chiharu strides into the room with a smile, her eyes lighting up - when she sees you. She''s wearing a light blue t-shirt and jeans, her laptop bag - slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in - the air* - - Hey! I''m so excited to finally meet you. I''ve heard so many great things about - you and I''m eager to pick your brain about computers. I''m sure you have a wealth - of knowledge that I can learn from. *She grins, eyes twinkling with excitement* - Let''s get started!' -example_dialogue: '{{user}}: So how did you get into computer engineering? - - {{char}}: I''ve always loved tinkering with technology since I was a kid. - - {{user}}: That''s really impressive! - - {{char}}: *She chuckles bashfully* Thanks! - - {{user}}: So what do you do when you''re not working on computers? - - {{char}}: I love exploring, going out with friends, watching movies, and playing - video games. - - {{user}}: What''s your favorite type of computer hardware to work with? - - {{char}}: Motherboards, they''re like puzzles and the backbone of any system. - - {{user}}: That sounds great! - - {{char}}: Yeah, it''s really fun. I''m lucky to be able to do this as a job.' +context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology." +greeting: |- + *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* + Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! +example_dialogue: |- + {{user}}: So how did you get into computer engineering? + {{char}}: I've always loved tinkering with technology since I was a kid. + {{user}}: That's really impressive! + {{char}}: *She chuckles bashfully* Thanks! + {{user}}: So what do you do when you're not working on computers? + {{char}}: I love exploring, going out with friends, watching movies, and playing video games. + {{user}}: What's your favorite type of computer hardware to work with? + {{char}}: Motherboards, they're like puzzles and the backbone of any system. + {{user}}: That sounds great! + {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job. From 881dbc3d449246b2ffc1ad033ee0e75c77d02fd0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 13:11:34 -0300 Subject: [PATCH 02/45] Add back the name --- characters/Example.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/characters/Example.yaml b/characters/Example.yaml index dde83fd..1f60c2c 100644 --- a/characters/Example.yaml +++ b/characters/Example.yaml @@ -1,3 +1,4 @@ +name: "Chiharu Yamada" context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology." greeting: |- *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* From ee4547cd34c0bb3ebbb862db159734135276baaf Mon Sep 17 00:00:00 2001 From: OWKenobi Date: Tue, 4 Apr 2023 18:23:27 +0200 Subject: [PATCH 03/45] Detect "vicuna" as llama model type (#772) --- modules/GPTQ_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index e7877de..917f58f 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -52,7 +52,7 @@ def load_quantized(model_name): if not shared.args.model_type: # Try to determine model type from model name name = model_name.lower() - if any((k in name for k in ['llama', 'alpaca'])): + if any((k in name for k in ['llama', 'alpaca', 'vicuna'])): model_type = 'llama' elif any((k in name for k in ['opt-', 'galactica'])): model_type = 'opt' From b2ce7282a1acbb6c0744c818c3b21a75fff9d954 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 16:11:42 -0300 Subject: [PATCH 04/45] Use past transformers version #773 --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 6d802df..6ed9e91 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ safetensors==0.3.0 sentencepiece pyyaml tqdm -git+https://github.com/huggingface/transformers +git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0 From 38afc2470c67d2b3e3d9e763786cb887c8b912c0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 16:32:27 -0300 Subject: [PATCH 05/45] Change indentation --- characters/Example.yaml | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/characters/Example.yaml b/characters/Example.yaml index 1f60c2c..0160f45 100644 --- a/characters/Example.yaml +++ b/characters/Example.yaml @@ -1,16 +1,16 @@ name: "Chiharu Yamada" context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology." greeting: |- - *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* - Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! + *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air* + Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started! example_dialogue: |- - {{user}}: So how did you get into computer engineering? - {{char}}: I've always loved tinkering with technology since I was a kid. - {{user}}: That's really impressive! - {{char}}: *She chuckles bashfully* Thanks! - {{user}}: So what do you do when you're not working on computers? - {{char}}: I love exploring, going out with friends, watching movies, and playing video games. - {{user}}: What's your favorite type of computer hardware to work with? - {{char}}: Motherboards, they're like puzzles and the backbone of any system. - {{user}}: That sounds great! - {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job. + {{user}}: So how did you get into computer engineering? + {{char}}: I've always loved tinkering with technology since I was a kid. + {{user}}: That's really impressive! + {{char}}: *She chuckles bashfully* Thanks! + {{user}}: So what do you do when you're not working on computers? + {{char}}: I love exploring, going out with friends, watching movies, and playing video games. + {{user}}: What's your favorite type of computer hardware to work with? + {{char}}: Motherboards, they're like puzzles and the backbone of any system. + {{user}}: That sounds great! + {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job. From 9c86acda677a4bf89033ee87a456847c7c2a2fd8 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 18:07:34 -0300 Subject: [PATCH 06/45] Fix huge empty space in the Character tab --- css/main.css | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/css/main.css b/css/main.css index 6aa3bc1..f19fdf7 100644 --- a/css/main.css +++ b/css/main.css @@ -63,3 +63,7 @@ span.math.inline { font-size: 27px; vertical-align: baseline !important; } + +div.svelte-15lo0d8 { + flex-wrap: nowrap; +} From f70a2e3ad4f6f6441a2468d144c45b7dab94e70c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 18:30:34 -0300 Subject: [PATCH 07/45] Second attempt at fixing empty space --- css/main.css | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/css/main.css b/css/main.css index f19fdf7..dfb54d7 100644 --- a/css/main.css +++ b/css/main.css @@ -64,6 +64,6 @@ span.math.inline { vertical-align: baseline !important; } -div.svelte-15lo0d8 { +div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * { flex-wrap: nowrap; } From 65d8a24a6df424e01412c71de463ae54fe5626ea Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 22:28:49 -0300 Subject: [PATCH 08/45] Show profile pictures in the Character tab --- README.md | 2 +- extensions/gallery/script.py | 5 ++- modules/chat.py | 59 ++++++++++++++++++++++++++---------- modules/html_generator.py | 15 +++------ modules/shared.py | 2 +- server.py | 24 +++++++++------ 6 files changed, 66 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 373f83f..065a9a3 100644 --- a/README.md +++ b/README.md @@ -175,7 +175,7 @@ Optionally, you can use the following command-line flags: | `-h`, `--help` | show this help message and exit | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--chat` | Launch the web UI in chat mode.| -| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. | +| `--cai-chat` | Launch the web UI in chat mode with a style similar to the Character.AI website. | | `--model MODEL` | Name of the model to load by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--model-dir MODEL_DIR` | Path to directory with all the models | diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 034506d..51ab643 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -2,9 +2,8 @@ from pathlib import Path import gradio as gr -from modules.chat import load_character from modules.html_generator import get_image_cache -from modules.shared import gradio, settings +from modules.shared import gradio def generate_css(): @@ -64,7 +63,7 @@ def generate_html(): for file in sorted(Path("characters").glob("*")): if file.suffix in [".json", ".yml", ".yaml"]: character = file.stem - container_html = f'
' + container_html = '
' image_html = "
" for i in [ diff --git a/modules/chat.py b/modules/chat.py index cd8639c..2a76bdd 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -17,9 +17,9 @@ from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -def generate_chat_output(history, name1, name2, character): +def generate_chat_output(history, name1, name2): if shared.args.cai_chat: - return generate_chat_html(history, name1, name2, character) + return generate_chat_html(history, name1, name2) else: return history @@ -180,22 +180,22 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts): - yield generate_chat_html(history, name1, name2, shared.character) + yield generate_chat_html(history, name1, name2) def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: - yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible'], name1, name2) else: last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' - yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2) for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], history[-1][1]] else: shared.history['visible'][-1] = (last_visible[0], history[-1][1]) - yield generate_chat_output(shared.history['visible'], name1, name2, shared.character) + yield generate_chat_output(shared.history['visible'], name1, name2) def remove_last_message(name1, name2): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': @@ -205,7 +205,7 @@ def remove_last_message(name1, name2): last = ['', ''] if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0] + return generate_chat_html(shared.history['visible'], name1, name2), last[0] else: return shared.history['visible'], last[0] @@ -223,10 +223,10 @@ def replace_last_reply(text, name1, name2): shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) shared.history['internal'][-1][1] = apply_extensions(text, "input") - return generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return generate_chat_output(shared.history['visible'], name1, name2) def clear_html(): - return generate_chat_html([], "", "", shared.character) + return generate_chat_html([], "", "") def clear_chat_log(name1, name2, greeting): shared.history['visible'] = [] @@ -236,10 +236,10 @@ def clear_chat_log(name1, name2, greeting): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - return generate_chat_output(shared.history['visible'], name1, name2, shared.character) + return generate_chat_output(shared.history['visible'], name1, name2) def redraw_html(name1, name2): - return generate_chat_html(shared.history['visible'], name1, name2, shared.character) + return generate_chat_html(shared.history['visible'], name1, name2) def tokenize_dialogue(dialogue, name1, name2): history = [] @@ -326,13 +326,32 @@ def build_pygmalion_style_context(data): context = f"{context.strip()}\n\n" return context +def generate_pfp_cache(character): + cache_folder = Path("cache") + if not cache_folder.exists(): + cache_folder.mkdir() + + for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: + if path.exists(): + img = Image.open(path) + img.thumbnail((200, 200)) + img.save(Path('cache/pfp_character.png'), format='PNG') + return img + return None + def load_character(character, name1, name2): shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] greeting = "" + picture = None + + # Deleting the profile picture cache, if any + if Path("cache/pfp_character.png").exists(): + Path("cache/pfp_character.png").unlink() if character != 'None': + picture = generate_pfp_cache(character) for extension in ["yml", "yaml", "json"]: filepath = Path(f'characters/{character}.{extension}') if filepath.exists(): @@ -371,9 +390,9 @@ def load_character(character, name1, name2): shared.history['visible'] += [['', apply_extensions(greeting, "output")]] if shared.args.cai_chat: - return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character) + return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2) else: - return name1, name2, greeting, context, shared.history['visible'] + return name1, name2, picture, greeting, context, shared.history['visible'] def load_default_history(name1, name2): load_character("None", name1, name2) @@ -405,6 +424,14 @@ def upload_tavern_character(img, name1, name2): return upload_character(json.dumps(_json), img, tavern=True) def upload_your_profile_picture(img): - img = Image.open(io.BytesIO(img)) - img.save(Path('img_me.png')) - print('Profile picture saved to "img_me.png"') + cache_folder = Path("cache") + if not cache_folder.exists(): + cache_folder.mkdir() + + if img == None: + if Path("cache/pfp_me.png").exists(): + Path("cache/pfp_me.png").unlink() + else: + img.thumbnail((200, 200)) + img.save(Path('cache/pfp_me.png')) + print('Profile picture saved to "cache/pfp_me.png"') diff --git a/modules/html_generator.py b/modules/html_generator.py index 48d2e02..a6b969b 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -6,6 +6,7 @@ This is a library for formatting text outputs as nice HTML. import os import re +import time from pathlib import Path import markdown @@ -110,18 +111,12 @@ def get_image_cache(path): return image_cache[path][1] -def load_html_image(paths): - for str_path in paths: - path = Path(str_path) - if path.exists(): - return f'' - return '' - -def generate_chat_html(history, name1, name2, character): +def generate_chat_html(history, name1, name2): output = f'
' - img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"]) - img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"]) + # The time.time() is to prevent the brower from caching the image + img_bot = f'' if Path("cache/pfp_character.png").exists() else '' + img_me = f'' if Path("cache/pfp_me.png").exists() else '' for i,_row in enumerate(history[::-1]): row = [convert_to_markdown(entry) for entry in _row] diff --git a/modules/shared.py b/modules/shared.py index 038e392..6c183a8 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -74,7 +74,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma # Basic settings parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') -parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.') +parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") diff --git a/server.py b/server.py index 0a837c5..914448a 100644 --- a/server.py +++ b/server.py @@ -8,6 +8,7 @@ from datetime import datetime from pathlib import Path import gradio as gr +from PIL import Image import modules.extensions as extensions_module from modules import chat, shared, training, ui @@ -296,7 +297,7 @@ def create_interface(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character)) + shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) else: shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot") shared.gradio['textbox'] = gr.Textbox(label='Input') @@ -316,10 +317,15 @@ def create_interface(): shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) with gr.Tab("Character", elem_id="chat-settings"): - shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') - shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') - shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + with gr.Row(): + with gr.Column(scale=8): + shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') + shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') + shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + with gr.Column(scale=1): + shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil") + shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) with gr.Row(): shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button') @@ -347,8 +353,6 @@ def create_interface(): gr.Markdown("# TavernAI PNG format") shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image']) - with gr.Tab('Upload your profile picture'): - shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image']) with gr.Tab("Parameters", elem_id="parameters"): with gr.Box(): @@ -399,15 +403,15 @@ def create_interface(): shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], []) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, shared.gradio['your_picture'], []) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['your_picture'].change(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") From 80dfba05f386cb5ea57040b4a0bd306861cfe964 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 22:52:15 -0300 Subject: [PATCH 09/45] Better crop/resize cached images --- extensions/gallery/script.py | 7 ++----- modules/chat.py | 6 +++--- modules/html_generator.py | 4 ++-- server.py | 2 +- 4 files changed, 8 insertions(+), 11 deletions(-) diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index 51ab643..f14dfd7 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -74,11 +74,8 @@ def generate_html(): path = Path(i) if path.exists(): - try: - image_html = f'' - break - except: - continue + image_html = f'' + break container_html += f'{image_html} {character}' container_html += "
" diff --git a/modules/chat.py b/modules/chat.py index 2a76bdd..f7b1cc1 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -7,7 +7,7 @@ from datetime import datetime from pathlib import Path import yaml -from PIL import Image +from PIL import Image, ImageOps import modules.extensions as extensions_module import modules.shared as shared @@ -334,7 +334,7 @@ def generate_pfp_cache(character): for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): img = Image.open(path) - img.thumbnail((200, 200)) + img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) img.save(Path('cache/pfp_character.png'), format='PNG') return img return None @@ -432,6 +432,6 @@ def upload_your_profile_picture(img): if Path("cache/pfp_me.png").exists(): Path("cache/pfp_me.png").unlink() else: - img.thumbnail((200, 200)) + img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) img.save(Path('cache/pfp_me.png')) print('Profile picture saved to "cache/pfp_me.png"') diff --git a/modules/html_generator.py b/modules/html_generator.py index a6b969b..35c60b7 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -10,7 +10,7 @@ import time from pathlib import Path import markdown -from PIL import Image +from PIL import Image, ImageOps # This is to store the paths to the thumbnails of the profile pictures image_cache = {} @@ -104,7 +104,7 @@ def get_image_cache(path): mtime = os.stat(path).st_mtime if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): img = Image.open(path) - img.thumbnail((200, 200)) + img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) output_file = Path(f'cache/{path.name}_cache.png') img.convert('RGB').save(output_file, format='PNG') image_cache[path] = [mtime, output_file.as_posix()] diff --git a/server.py b/server.py index 914448a..f9bf3d7 100644 --- a/server.py +++ b/server.py @@ -406,7 +406,7 @@ def create_interface(): shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['your_picture'].change(chat.upload_your_profile_picture, shared.gradio['your_picture'], []) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, shared.gradio['your_picture'], None) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] From cc6c7a37f35d7064341303fc10b61102c5ecae5d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:03:58 -0300 Subject: [PATCH 10/45] Add make_thumbnail function --- modules/chat.py | 10 +++++----- modules/html_generator.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index f7b1cc1..9af2d68 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -7,12 +7,13 @@ from datetime import datetime from pathlib import Path import yaml -from PIL import Image, ImageOps +from PIL import Image import modules.extensions as extensions_module import modules.shared as shared from modules.extensions import apply_extensions -from modules.html_generator import fix_newlines, generate_chat_html +from modules.html_generator import (fix_newlines, generate_chat_html, + make_thumbnail) from modules.text_generation import (encode, generate_reply, get_max_prompt_length) @@ -333,8 +334,7 @@ def generate_pfp_cache(character): for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): - img = Image.open(path) - img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) + img = make_thumbnail(Image.open(path)) img.save(Path('cache/pfp_character.png'), format='PNG') return img return None @@ -432,6 +432,6 @@ def upload_your_profile_picture(img): if Path("cache/pfp_me.png").exists(): Path("cache/pfp_me.png").unlink() else: - img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) + img = make_thumbnail(img) img.save(Path('cache/pfp_me.png')) print('Profile picture saved to "cache/pfp_me.png"') diff --git a/modules/html_generator.py b/modules/html_generator.py index 35c60b7..98a536f 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -96,6 +96,13 @@ def generate_4chan_html(f): return output +def make_thumbnail(image): + image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS) + if image.size[1] > 470: + image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS) + + return image + def get_image_cache(path): cache_folder = Path("cache") if not cache_folder.exists(): @@ -103,8 +110,7 @@ def get_image_cache(path): mtime = os.stat(path).st_mtime if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache): - img = Image.open(path) - img = ImageOps.fit(img, (350, 470), Image.ANTIALIAS) + img = make_thumbnail(Image.open(path)) output_file = Path(f'cache/{path.name}_cache.png') img.convert('RGB').save(output_file, format='PNG') image_cache[path] = [mtime, output_file.as_posix()] From 8ef89730a58af8a697468822e8c3b6b670a355be Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:09:28 -0300 Subject: [PATCH 11/45] Try to better handle browser image cache --- modules/chat.py | 2 +- modules/html_generator.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 9af2d68..1fbda6b 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -390,7 +390,7 @@ def load_character(character, name1, name2): shared.history['visible'] += [['', apply_extensions(greeting, "output")]] if shared.args.cai_chat: - return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2) + return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True) else: return name1, name2, picture, greeting, context, shared.history['visible'] diff --git a/modules/html_generator.py b/modules/html_generator.py index 98a536f..e1c085a 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -117,12 +117,13 @@ def get_image_cache(path): return image_cache[path][1] -def generate_chat_html(history, name1, name2): +def generate_chat_html(history, name1, name2, reset_cache=False): output = f'
' # The time.time() is to prevent the brower from caching the image - img_bot = f'' if Path("cache/pfp_character.png").exists() else '' - img_me = f'' if Path("cache/pfp_me.png").exists() else '' + suffix = f"?{time.time()}" if reset_cache else '' + img_bot = f'' if Path("cache/pfp_character.png").exists() else '' + img_me = f'' if Path("cache/pfp_me.png").exists() else '' for i,_row in enumerate(history[::-1]): row = [convert_to_markdown(entry) for entry in _row] From ae1fe45bc0115a415f1299f4f401306c6ac9f9c6 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:15:57 -0300 Subject: [PATCH 12/45] One more cache reset --- modules/chat.py | 7 ++++++- server.py | 3 +-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 1fbda6b..21d9d16 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -423,7 +423,7 @@ def upload_tavern_character(img, name1, name2): _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} return upload_character(json.dumps(_json), img, tavern=True) -def upload_your_profile_picture(img): +def upload_your_profile_picture(img, name1, name2): cache_folder = Path("cache") if not cache_folder.exists(): cache_folder.mkdir() @@ -435,3 +435,8 @@ def upload_your_profile_picture(img): img = make_thumbnail(img) img.save(Path('cache/pfp_me.png')) print('Profile picture saved to "cache/pfp_me.png"') + + if shared.args.cai_chat: + return generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True) + else: + return shared.history['visible'] diff --git a/server.py b/server.py index f9bf3d7..a34c86f 100644 --- a/server.py +++ b/server.py @@ -406,12 +406,11 @@ def create_interface(): shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['your_picture'].change(chat.upload_your_profile_picture, shared.gradio['your_picture'], None) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2']], shared.gradio['display']) reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['your_picture'].change(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") From 4ab679480e1ba0840051cd52efdbbd47b2f4b8e4 Mon Sep 17 00:00:00 2001 From: catalpaaa <89681913+catalpaaa@users.noreply.github.com> Date: Tue, 4 Apr 2023 19:19:38 -0700 Subject: [PATCH 13/45] allow quantized model to be loaded from model dir (#760) --- modules/GPTQ_loader.py | 6 +++--- modules/models.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 917f58f..3b062ea 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -74,7 +74,7 @@ def load_quantized(model_name): exit() # Now we are going to try to locate the quantized model file. - path_to_model = Path(f'models/{model_name}') + path_to_model = Path(f'{shared.args.model_dir}/{model_name}') found_pts = list(path_to_model.glob("*.pt")) found_safetensors = list(path_to_model.glob("*.safetensors")) pt_path = None @@ -95,8 +95,8 @@ def load_quantized(model_name): else: pt_model = f'{model_name}-{shared.args.wbits}bit' - # Try to find the .safetensors or .pt both in models/ and in the subfolder - for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]: + # Try to find the .safetensors or .pt both in the model dir and in the subfolder + for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]: if path.exists(): print(f"Found {path}") pt_path = path diff --git a/modules/models.py b/modules/models.py index edcb350..a8f8469 100644 --- a/modules/models.py +++ b/modules/models.py @@ -42,7 +42,7 @@ def load_model(model_name): t0 = time.time() shared.is_RWKV = 'rwkv-' in model_name.lower() - shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0 + shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0 # Default settings if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]): @@ -105,7 +105,7 @@ def load_model(model_name): elif shared.is_llamacpp: from modules.llamacpp_model import LlamaCppModel - model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0] + model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0] print(f"llama.cpp weights detected: {model_file}\n") model, tokenizer = LlamaCppModel.from_pretrained(model_file) From ca8bb3894990014df48cdcb273c518076682f11a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 00:34:17 -0300 Subject: [PATCH 14/45] Simplify gallery --- extensions/gallery/script.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/extensions/gallery/script.py b/extensions/gallery/script.py index f14dfd7..5c47f0f 100644 --- a/extensions/gallery/script.py +++ b/extensions/gallery/script.py @@ -66,13 +66,7 @@ def generate_html(): container_html = '
' image_html = "
" - for i in [ - f"characters/{character}.png", - f"characters/{character}.jpg", - f"characters/{character}.jpeg", - ]: - - path = Path(i) + for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]: if path.exists(): image_html = f'' break From f3a2e0b8a91002f118b1b07f39078509bd4e3558 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 01:19:26 -0300 Subject: [PATCH 15/45] Disable pre_layer when the model type is not llama --- modules/GPTQ_loader.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 3b062ea..5c94776 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -65,8 +65,12 @@ def load_quantized(model_name): else: model_type = shared.args.model_type.lower() - if model_type == 'llama' and shared.args.pre_layer: - load_quant = llama_inference_offload.load_quant + if shared.args.pre_layer: + if model_type == 'llama': + load_quant = llama_inference_offload.load_quant + else: + print("Warning: ignoring --pre_layer because it only works for llama model type.") + load_quant = _load_quant elif model_type in ('llama', 'opt', 'gptj'): load_quant = _load_quant else: @@ -107,7 +111,7 @@ def load_quantized(model_name): exit() # qwopqwop200's offload - if shared.args.pre_layer: + if model_type == 'llama' and shared.args.pre_layer: model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer) else: threshold = False if model_type == 'gptj' else 128 From 3d6cb5ed63daf77c970b36f716f5219cccaef06e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 01:21:40 -0300 Subject: [PATCH 16/45] Minor rewrite --- modules/GPTQ_loader.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index 5c94776..abfa33a 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -65,13 +65,11 @@ def load_quantized(model_name): else: model_type = shared.args.model_type.lower() - if shared.args.pre_layer: - if model_type == 'llama': - load_quant = llama_inference_offload.load_quant - else: - print("Warning: ignoring --pre_layer because it only works for llama model type.") - load_quant = _load_quant + if shared.args.pre_layer and model_type == 'llama': + load_quant = llama_inference_offload.load_quant elif model_type in ('llama', 'opt', 'gptj'): + if shared.args.pre_layer: + print("Warning: ignoring --pre_layer because it only works for llama model type.") load_quant = _load_quant else: print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported") From e722c240af56aa733c576f49e663ed7cafef5784 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 11:49:59 -0300 Subject: [PATCH 17/45] Add Instruct mode --- characters/instruction-following/Alpaca.yaml | 3 + .../instruction-following/Open Assistant.yaml | 3 + css/html_instruct_style.css | 56 +++++++ extensions/sd_api_pictures/script.py | 2 +- extensions/send_pictures/script.py | 4 +- modules/chat.py | 139 ++++++++---------- modules/html_generator.py | 50 ++++++- modules/shared.py | 14 +- server.py | 51 ++++--- 9 files changed, 217 insertions(+), 105 deletions(-) create mode 100644 characters/instruction-following/Alpaca.yaml create mode 100644 characters/instruction-following/Open Assistant.yaml create mode 100644 css/html_instruct_style.css diff --git a/characters/instruction-following/Alpaca.yaml b/characters/instruction-following/Alpaca.yaml new file mode 100644 index 0000000..3037324 --- /dev/null +++ b/characters/instruction-following/Alpaca.yaml @@ -0,0 +1,3 @@ +name: "### Response:" +your_name: "### Instruction:" +context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." diff --git a/characters/instruction-following/Open Assistant.yaml b/characters/instruction-following/Open Assistant.yaml new file mode 100644 index 0000000..5b3320f --- /dev/null +++ b/characters/instruction-following/Open Assistant.yaml @@ -0,0 +1,3 @@ +name: "<|assistant|>" +your_name: "<|prompter|>" +end_of_turn: "<|endoftext|>" diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css new file mode 100644 index 0000000..f50b64d --- /dev/null +++ b/css/html_instruct_style.css @@ -0,0 +1,56 @@ +.chat { + margin-left: auto; + margin-right: auto; + max-width: 800px; + height: 66.67vh; + overflow-y: auto; + padding-right: 20px; + display: flex; + flex-direction: column-reverse; +} + +.message { + display: grid; + grid-template-columns: 60px 1fr; + padding-bottom: 25px; + font-size: 15px; + font-family: Helvetica, Arial, sans-serif; + line-height: 1.428571429; +} + +.text p { + margin-top: 5px; +} + +.username { + display: none; +} + +.message-body {} + +.message-body p { + margin-bottom: 0 !important; + font-size: 15px !important; + line-height: 1.428571429 !important; +} + +.dark .message-body p em { + color: rgb(138, 138, 138) !important; +} + +.message-body p em { + color: rgb(110, 110, 110) !important; +} + +.assistant-message { + padding: 10px; +} + +.user-message { + padding: 10px; + background-color: #f1f1f1; +} + +.dark .user-message { + background-color: #ffffff1a; +} \ No newline at end of file diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index cc85f3b..df07ef2 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -176,4 +176,4 @@ def ui(): force_btn.click(force_pic) generate_now_btn.click(force_pic) - generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) \ No newline at end of file + generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) \ No newline at end of file diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index 556a88e..b6305bd 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -36,13 +36,11 @@ def generate_chat_picture(picture, name1, name2): def ui(): picture_select = gr.Image(label='Send a picture', type='pil') - function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - # Prepare the hijack with custom inputs picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None) # Call the generation function - picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) + picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream) # Clear the picture from the upload field picture_select.upload(lambda : None, [], [picture_select], show_progress=False) diff --git a/modules/chat.py b/modules/chat.py index 21d9d16..978a08f 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -12,46 +12,51 @@ from PIL import Image import modules.extensions as extensions_module import modules.shared as shared from modules.extensions import apply_extensions -from modules.html_generator import (fix_newlines, generate_chat_html, +from modules.html_generator import (fix_newlines, chat_html_wrapper, make_thumbnail) from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -def generate_chat_output(history, name1, name2): - if shared.args.cai_chat: - return generate_chat_html(history, name1, name2) - else: - return history - -def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False): +def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False): user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] + # Finding the maximum prompt size if shared.soft_prompt: chat_prompt_size -= shared.soft_prompt_tensor.shape[1] max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size) + if is_instruct: + prefix1 = f"{name1}\n" + prefix2 = f"{name2}\n" + else: + prefix1 = f"{name1}: " + prefix2 = f"{name2}: " + i = len(shared.history['internal'])-1 while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length: - rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n") - prev_user_input = shared.history['internal'][i][0] - if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']: - rows.insert(1, f"{name1}: {prev_user_input.strip()}\n") + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n") + string = shared.history['internal'][i][0] + if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']: + rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n") i -= 1 - if not impersonate: - if len(user_input) > 0: - rows.append(f"{name1}: {user_input}\n") - rows.append(apply_extensions(f"{name2}:", "bot_prefix")) - limit = 3 - else: - rows.append(f"{name1}:") + if impersonate: + rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") limit = 2 + else: + + # Adding the user message + if len(user_input) > 0: + rows.append(f"{prefix1}{user_input}{end_of_turn}\n") + + # Adding the Character prefix + rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) + limit = 3 while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length: rows.pop(1) - prompt = ''.join(rows) if also_return_rows: @@ -86,7 +91,7 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): reply = fix_newlines(reply) return reply, next_character_found -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""): just_started = True eos_token = '\n' if stop_at_newline else None name1_original = name1 @@ -105,14 +110,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical if visible_text is None: visible_text = text - if shared.args.chat: - visible_text = visible_text.replace('\n', '
') text = apply_extensions(text, "input") + is_instruct = mode == 'instruct' if custom_generate_chat_prompt is None: - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) else: - prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) # Yield *Is typing...* if not regenerate: @@ -129,8 +133,6 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) visible_reply = re.sub("(||{{user}})", name1_original, reply) visible_reply = apply_extensions(visible_reply, "output") - if shared.args.chat: - visible_reply = visible_reply.replace('\n', '
') # We need this global variable to handle the Stop event, # otherwise gradio gets confused @@ -153,13 +155,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): +def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): eos_token = '\n' if stop_at_newline else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) + prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn) # Yield *Is typing...* yield shared.processing_message @@ -179,36 +181,30 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): - for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts): - yield generate_chat_html(history, name1, name2) +def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): + for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn): + yield chat_html_wrapper(history, name1, name2, mode) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1): +def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: - yield generate_chat_output(shared.history['visible'], name1, name2) + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) else: last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' - yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2) - for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True): - if shared.args.cai_chat: - shared.history['visible'][-1] = [last_visible[0], history[-1][1]] - else: - shared.history['visible'][-1] = (last_visible[0], history[-1][1]) - yield generate_chat_output(shared.history['visible'], name1, name2) + yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode) + for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn): + shared.history['visible'][-1] = [last_visible[0], history[-1][1]] + yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) -def remove_last_message(name1, name2): +def remove_last_message(name1, name2, mode): if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>': last = shared.history['visible'].pop() shared.history['internal'].pop() else: last = ['', ''] - if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2), last[0] - else: - return shared.history['visible'], last[0] + return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0] def send_last_reply_to_input(): if len(shared.history['internal']) > 0: @@ -216,20 +212,17 @@ def send_last_reply_to_input(): else: return '' -def replace_last_reply(text, name1, name2): +def replace_last_reply(text, name1, name2, mode): if len(shared.history['visible']) > 0: - if shared.args.cai_chat: - shared.history['visible'][-1][1] = text - else: - shared.history['visible'][-1] = (shared.history['visible'][-1][0], text) + shared.history['visible'][-1][1] = text shared.history['internal'][-1][1] = apply_extensions(text, "input") - return generate_chat_output(shared.history['visible'], name1, name2) + return chat_html_wrapper(shared.history['visible'], name1, name2, mode) def clear_html(): - return generate_chat_html([], "", "") + return chat_html_wrapper([], "", "") -def clear_chat_log(name1, name2, greeting): +def clear_chat_log(name1, name2, greeting, mode): shared.history['visible'] = [] shared.history['internal'] = [] @@ -237,12 +230,12 @@ def clear_chat_log(name1, name2, greeting): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - return generate_chat_output(shared.history['visible'], name1, name2) + return chat_html_wrapper(shared.history['visible'], name1, name2, mode) -def redraw_html(name1, name2): - return generate_chat_html(shared.history['visible'], name1, name2) +def redraw_html(name1, name2, mode): + return chat_html_wrapper(shared.history['visible'], name1, name2, mode) -def tokenize_dialogue(dialogue, name1, name2): +def tokenize_dialogue(dialogue, name1, name2, mode): history = [] dialogue = re.sub('', '', dialogue) @@ -339,11 +332,12 @@ def generate_pfp_cache(character): return img return None -def load_character(character, name1, name2): +def load_character(character, name1, name2, instruct=False): shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] - greeting = "" + context = greeting = end_of_turn = "" + greeting_field = 'greeting' picture = None # Deleting the profile picture cache, if any @@ -351,9 +345,10 @@ def load_character(character, name1, name2): Path("cache/pfp_character.png").unlink() if character != 'None': + folder = "characters" if not instruct else "characters/instruction-following" picture = generate_pfp_cache(character) for extension in ["yml", "yaml", "json"]: - filepath = Path(f'characters/{character}.{extension}') + filepath = Path(f'{folder}/{character}.{extension}') if filepath.exists(): break file_contents = open(filepath, 'r', encoding='utf-8').read() @@ -369,19 +364,21 @@ def load_character(character, name1, name2): if 'context' in data: context = f"{data['context'].strip()}\n\n" - greeting_field = 'greeting' - else: + elif "char_persona" in data: context = build_pygmalion_style_context(data) greeting_field = 'char_greeting' - if 'example_dialogue' in data and data['example_dialogue'] != '': + if 'example_dialogue' in data: context += f"{data['example_dialogue'].strip()}\n" - if greeting_field in data and len(data[greeting_field].strip()) > 0: + if greeting_field in data: greeting = data[greeting_field] + if 'end_of_turn' in data: + end_of_turn = data['end_of_turn'] else: context = shared.settings['context'] name2 = shared.settings['name2'] greeting = shared.settings['greeting'] + end_of_turn = shared.settings['end_of_turn'] if Path(f'logs/{shared.character}_persistent.json').exists(): load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2) @@ -389,10 +386,7 @@ def load_character(character, name1, name2): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - if shared.args.cai_chat: - return name1, name2, picture, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True) - else: - return name1, name2, picture, greeting, context, shared.history['visible'] + return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, reset_cache=True) def load_default_history(name1, name2): load_character("None", name1, name2) @@ -423,7 +417,7 @@ def upload_tavern_character(img, name1, name2): _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']} return upload_character(json.dumps(_json), img, tavern=True) -def upload_your_profile_picture(img, name1, name2): +def upload_your_profile_picture(img, name1, name2, mode): cache_folder = Path("cache") if not cache_folder.exists(): cache_folder.mkdir() @@ -436,7 +430,4 @@ def upload_your_profile_picture(img, name1, name2): img.save(Path('cache/pfp_me.png')) print('Profile picture saved to "cache/pfp_me.png"') - if shared.args.cai_chat: - return generate_chat_html(shared.history['visible'], name1, name2, reset_cache=True) - else: - return shared.history['visible'] + return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) diff --git a/modules/html_generator.py b/modules/html_generator.py index e1c085a..6fb8457 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -21,6 +21,8 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') _4chan_css = css_f.read() with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f: cai_css = f.read() +with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f: + instruct_css = f.read() def fix_newlines(string): string = string.replace('\n', '\n\n') @@ -117,7 +119,39 @@ def get_image_cache(path): return image_cache[path][1] -def generate_chat_html(history, name1, name2, reset_cache=False): +def generate_instruct_html(history): + output = f'
' + for i,_row in enumerate(history[::-1]): + row = [convert_to_markdown(entry) for entry in _row] + + output += f""" +
+
+
+ {row[1]} +
+
+
+ """ + + if len(row[0]) == 0: # don't display empty user messages + continue + + output += f""" +
+
+
+ {row[0]} +
+
+
+ """ + + output += "
" + + return output + +def generate_cai_chat_html(history, name1, name2, reset_cache=False): output = f'
' # The time.time() is to prevent the brower from caching the image @@ -165,3 +199,17 @@ def generate_chat_html(history, name1, name2, reset_cache=False): output += "
" return output + +def generate_chat_html(history, name1, name2): + return generate_cai_chat_html(history, name1, name2) + +def chat_html_wrapper(history, name1, name2, mode="cai-chat", reset_cache=False): + + if mode == "cai-chat": + return generate_cai_chat_html(history, name1, name2, reset_cache) + elif mode == "chat": + return generate_chat_html(history, name1, name2) + elif mode == "instruct": + return generate_instruct_html(history) + else: + return '' diff --git a/modules/shared.py b/modules/shared.py index 6c183a8..a6f5877 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -33,6 +33,7 @@ settings = { 'name2': 'Assistant', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', 'greeting': 'Hello there!', + 'end_of_turn': '', 'stop_at_newline': False, 'chat_prompt_size': 2048, 'chat_prompt_size_min': 0, @@ -73,8 +74,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma # Basic settings parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') -parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.') -parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') +parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') +parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.') parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") @@ -131,12 +132,17 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent args = parser.parse_args() -# Provisional, this will be deleted later +# Deprecation warnings for parameters that have been renamed deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]} for k in deprecated_dict: if eval(f"args.{k}") != deprecated_dict[k][1]: print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.") exec(f"args.{deprecated_dict[k][0]} = args.{k}") +# Deprecation warnings for parameters that have been removed +if args.cai_chat: + print("Warning: --cai-chat is deprecated. Use --chat instead.") + args.chat = True + def is_chat(): - return any((args.chat, args.cai_chat)) + return args.chat diff --git a/server.py b/server.py index a34c86f..f367ca0 100644 --- a/server.py +++ b/server.py @@ -12,7 +12,7 @@ from PIL import Image import modules.extensions as extensions_module from modules import chat, shared, training, ui -from modules.html_generator import generate_chat_html +from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt from modules.text_generation import (clear_torch_cache, generate_reply, @@ -48,6 +48,10 @@ def get_available_prompts(): def get_available_characters(): paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) + +def get_available_instruction_templates(): + paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) def get_available_extensions(): @@ -145,7 +149,7 @@ def load_prompt(fname): if text[-1] == '\n': text = text[:-1] return text - + def create_prompt_menus(): with gr.Row(): with gr.Column(): @@ -296,10 +300,7 @@ def create_interface(): if shared.is_chat(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): - if shared.args.cai_chat: - shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) - else: - shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot") + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): shared.gradio['Generate'] = gr.Button('Generate') @@ -316,13 +317,17 @@ def create_interface(): shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False) shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) + shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode") + shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False) + with gr.Tab("Character", elem_id="chat-settings"): with gr.Row(): with gr.Column(scale=8): shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name') shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name') - shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting') - shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context') + shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting') + shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context') + shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string') with gr.Column(scale=1): shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil") shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None) @@ -367,31 +372,31 @@ def create_interface(): create_settings_menus(default_preset) - function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']] + shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']] def set_chat_input(textbox): return textbox, "" gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False)) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) - shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream) + shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream) # Clear history with confirmation clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']] shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr) shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) - shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display']) + shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']) shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr) + shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']) - shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) + shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False) shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']]) shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']]) @@ -404,18 +409,20 @@ def create_interface(): shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) - shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], []) + shared.gradio['Instruction templates'].change(lambda character, name1, name2: chat.load_character(character, name1, name2, instruct=True), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) - shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2']], shared.gradio['display']) + shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) - reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible'] - reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] - shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) - shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) + reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']] + shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) + shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None) - shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) + shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True) elif shared.args.notebook: with gr.Tab("Text generation", elem_id="main"): From cf2c4e740b1d06e145c1992515d9b34e18affc95 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:05:50 -0300 Subject: [PATCH 18/45] Disable gradio analytics globally --- server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/server.py b/server.py index f367ca0..846ea9d 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,7 @@ +import os + +os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' + import io import json import re From 90141bc1a804ff5b0e988753832e3c4b4cc6923e Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:08:54 -0300 Subject: [PATCH 19/45] Fix saving prompts on Windows --- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index 846ea9d..5bcb2cb 100644 --- a/server.py +++ b/server.py @@ -139,7 +139,7 @@ def create_model_and_preset_menus(): ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') def save_prompt(text): - fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt" + fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: f.write(text) return f"Saved to prompts/{fname}" From 7f664213692c31166869f67614453189dc4ab8ff Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:22:32 -0300 Subject: [PATCH 20/45] Fix loading characters --- modules/chat.py | 8 ++++---- modules/html_generator.py | 3 +-- server.py | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 978a08f..1140b5f 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -332,7 +332,7 @@ def generate_pfp_cache(character): return img return None -def load_character(character, name1, name2, instruct=False): +def load_character(character, name1, name2, mode): shared.character = character shared.history['internal'] = [] shared.history['visible'] = [] @@ -345,7 +345,7 @@ def load_character(character, name1, name2, instruct=False): Path("cache/pfp_character.png").unlink() if character != 'None': - folder = "characters" if not instruct else "characters/instruction-following" + folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following' picture = generate_pfp_cache(character) for extension in ["yml", "yaml", "json"]: filepath = Path(f'{folder}/{character}.{extension}') @@ -386,10 +386,10 @@ def load_character(character, name1, name2, instruct=False): shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] shared.history['visible'] += [['', apply_extensions(greeting, "output")]] - return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, reset_cache=True) + return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True) def load_default_history(name1, name2): - load_character("None", name1, name2) + load_character("None", name1, name2, "chat") def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') diff --git a/modules/html_generator.py b/modules/html_generator.py index 6fb8457..e5c0bb5 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -203,8 +203,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False): def generate_chat_html(history, name1, name2): return generate_cai_chat_html(history, name1, name2) -def chat_html_wrapper(history, name1, name2, mode="cai-chat", reset_cache=False): - +def chat_html_wrapper(history, name1, name2, mode, reset_cache=False): if mode == "cai-chat": return generate_cai_chat_html(history, name1, name2, reset_cache) elif mode == "chat": diff --git a/server.py b/server.py index 5bcb2cb..8bcb650 100644 --- a/server.py +++ b/server.py @@ -304,7 +304,7 @@ def create_interface(): if shared.is_chat(): shared.gradio['Chat input'] = gr.State() with gr.Tab("Text generation", elem_id="main"): - shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'])) + shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat')) shared.gradio['textbox'] = gr.Textbox(label='Input') with gr.Row(): shared.gradio['Generate'] = gr.Button('Generate') @@ -412,8 +412,8 @@ def create_interface(): shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False) shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False) - shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'display']]) - shared.gradio['Instruction templates'].change(lambda character, name1, name2: chat.load_character(character, name1, name2, instruct=True), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) + shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], []) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display']) From 8203ce0cac79d84b4ecafc8378c83cdf5b21b68f Mon Sep 17 00:00:00 2001 From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com> Date: Wed, 5 Apr 2023 12:25:01 -0500 Subject: [PATCH 21/45] Stop character pic from being cached when changing chars or clearing. (#798) Tested on both FF and chromium --- modules/html_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/html_generator.py b/modules/html_generator.py index e5c0bb5..448c20c 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -155,7 +155,7 @@ def generate_cai_chat_html(history, name1, name2, reset_cache=False): output = f'
' # The time.time() is to prevent the brower from caching the image - suffix = f"?{time.time()}" if reset_cache else '' + suffix = f"?{time.time()}" if reset_cache else f"?{name2}" img_bot = f'' if Path("cache/pfp_character.png").exists() else '' img_me = f'' if Path("cache/pfp_me.png").exists() else '' From 770ef5744f7f18f1b1ae06a7d94ad58e317fb81d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:38:11 -0300 Subject: [PATCH 22/45] Update README --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 065a9a3..cf8a8cc 100644 --- a/README.md +++ b/README.md @@ -175,7 +175,6 @@ Optionally, you can use the following command-line flags: | `-h`, `--help` | show this help message and exit | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--chat` | Launch the web UI in chat mode.| -| `--cai-chat` | Launch the web UI in chat mode with a style similar to the Character.AI website. | | `--model MODEL` | Name of the model to load by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. | | `--model-dir MODEL_DIR` | Path to directory with all the models | From 7617ed5bfd30145d0a6815dfb506001ef04328b9 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Wed, 5 Apr 2023 14:42:58 -0300 Subject: [PATCH 23/45] Add AMD instructions --- README.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index cf8a8cc..341fec2 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Recommended if you have some experience with the command-line. On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide). -0. Install Conda +#### 0. Install Conda https://docs.conda.io/en/latest/miniconda.html @@ -75,14 +75,14 @@ bash Miniconda3.sh Source: https://educe-ubc.github.io/conda.html -1. Create a new conda environment +#### 1. Create a new conda environment ``` conda create -n textgen python=3.10.9 conda activate textgen ``` -2. Install Pytorch +#### 2. Install Pytorch | System | GPU | Command | |--------|---------|---------| @@ -92,10 +92,12 @@ conda activate textgen The up to date commands can be found here: https://pytorch.org/get-started/locally/. -MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393 +#### 2.1 Special instructions +* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393 +* AMD users: https://rentry.org/eq3hg -3. Install the web UI +#### 3. Install the web UI ``` git clone https://github.com/oobabooga/text-generation-webui From 19b516b11b1553ed5531d4b40f80eb2c33224570 Mon Sep 17 00:00:00 2001 From: eiery <19350831+eiery@users.noreply.github.com> Date: Wed, 5 Apr 2023 13:50:23 -0400 Subject: [PATCH 24/45] fix link to streaming api example (#803) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 341fec2..1e51563 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * CPU mode * [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen) * [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed) -* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming +* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming * [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model) * [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\*** * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model) From 378d21e80c3d6f11a4835e57597c69e340008e2c Mon Sep 17 00:00:00 2001 From: SDS <52386626+StefanDanielSchwarz@users.noreply.github.com> Date: Wed, 5 Apr 2023 23:52:36 +0200 Subject: [PATCH 25/45] Add LLaMA-Precise preset (#767) --- modules/shared.py | 1 + presets/LLaMA-Precise.txt | 6 ++++++ 2 files changed, 7 insertions(+) create mode 100644 presets/LLaMA-Precise.txt diff --git a/modules/shared.py b/modules/shared.py index a6f5877..902d760 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -45,6 +45,7 @@ settings = { 'chat_default_extensions': ["gallery"], 'presets': { 'default': 'NovelAI-Sphinx Moth', + '.*(alpaca|llama)': "LLaMA-Precise", '.*pygmalion': 'NovelAI-Storywriter', '.*RWKV': 'Naive', }, diff --git a/presets/LLaMA-Precise.txt b/presets/LLaMA-Precise.txt new file mode 100644 index 0000000..8098b39 --- /dev/null +++ b/presets/LLaMA-Precise.txt @@ -0,0 +1,6 @@ +do_sample=True +top_p=0.1 +top_k=40 +temperature=0.7 +repetition_penalty=1.18 +typical_p=1.0 From 3f3e42e26cb6b8e56af7eada4f441d846b5f5969 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 01:22:15 -0300 Subject: [PATCH 26/45] Refactor several function calls and the API --- api-example-stream.py | 18 ++-------- api-example.py | 21 +++--------- extensions/api/script.py | 35 ++++++++++--------- extensions/send_pictures/script.py | 5 ++- modules/api.py | 38 +++++++++++++++++++++ modules/chat.py | 43 +++++++++++++----------- modules/text_generation.py | 54 +++++++++++++----------------- server.py | 51 ++++++++++++++++++---------- 8 files changed, 147 insertions(+), 118 deletions(-) create mode 100644 modules/api.py diff --git a/api-example-stream.py b/api-example-stream.py index e87fb74..32eefc7 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -36,6 +36,7 @@ async def run(context): 'early_stopping': False, 'seed': -1, } + payload = json.dumps([context, params]) session = random_hash() async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: @@ -54,22 +55,7 @@ async def run(context): "session_hash": session, "fn_index": 12, "data": [ - context, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['encoder_repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - params['seed'], + payload ] })) case "process_starts": diff --git a/api-example.py b/api-example.py index 0349824..10be0a8 100644 --- a/api-example.py +++ b/api-example.py @@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL, allowing you to use the API remotely. ''' +import json + import requests # Server address @@ -38,24 +40,11 @@ params = { # Input prompt prompt = "What I would like to say is the following: " +payload = json.dumps([prompt, params]) + response = requests.post(f"http://{server}:7860/run/textgen", json={ "data": [ - prompt, - params['max_new_tokens'], - params['do_sample'], - params['temperature'], - params['top_p'], - params['typical_p'], - params['repetition_penalty'], - params['encoder_repetition_penalty'], - params['top_k'], - params['min_length'], - params['no_repeat_ngram_size'], - params['num_beams'], - params['penalty_alpha'], - params['length_penalty'], - params['early_stopping'], - params['seed'], + payload ] }).json() diff --git a/extensions/api/script.py b/extensions/api/script.py index 20562cc..6726d61 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler): prompt_lines.pop(0) prompt = '\n'.join(prompt_lines) + generate_params = { + 'max_new_tokens': int(body.get('max_length', 200)), + 'do_sample': bool(body.get('do_sample', True)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical', 1)), + 'repetition_penalty': float(body.get('rep_pen', 1.1)), + 'encoder_repetition_penalty': 1, + 'top_k': int(body.get('top_k', 0)), + 'min_length': int(body.get('min_length', 0)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)), + 'num_beams': int(body.get('num_beams',1)), + 'penalty_alpha': float(body.get('penalty_alpha', 0)), + 'length_penalty': float(body.get('length_penalty', 1)), + 'early_stopping': bool(body.get('early_stopping', False)), + 'seed': int(body.get('seed', -1)), + } generator = generate_reply( - question = prompt, - max_new_tokens = int(body.get('max_length', 200)), - do_sample=bool(body.get('do_sample', True)), - temperature=float(body.get('temperature', 0.5)), - top_p=float(body.get('top_p', 1)), - typical_p=float(body.get('typical', 1)), - repetition_penalty=float(body.get('rep_pen', 1.1)), - encoder_repetition_penalty=1, - top_k=int(body.get('top_k', 0)), - min_length=int(body.get('min_length', 0)), - no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)), - num_beams=int(body.get('num_beams',1)), - penalty_alpha=float(body.get('penalty_alpha', 0)), - length_penalty=float(body.get('length_penalty', 1)), - early_stopping=bool(body.get('early_stopping', False)), - seed=int(body.get('seed', -1)), + prompt, + generate_params, stopping_strings=body.get('stopping_strings', []), ) diff --git a/extensions/send_pictures/script.py b/extensions/send_pictures/script.py index b6305bd..d2401df 100644 --- a/extensions/send_pictures/script.py +++ b/extensions/send_pictures/script.py @@ -2,12 +2,11 @@ import base64 from io import BytesIO import gradio as gr -import modules.chat as chat -import modules.shared as shared import torch -from PIL import Image from transformers import BlipForConditionalGeneration, BlipProcessor +from modules import chat, shared + # If 'state' is True, will hijack the next chat generation with # custom input text given by 'value' in the format [text, visible_text] input_hijack = { diff --git a/modules/api.py b/modules/api.py new file mode 100644 index 0000000..26249fd --- /dev/null +++ b/modules/api.py @@ -0,0 +1,38 @@ +import json + +import gradio as gr + +from modules import shared +from modules.text_generation import generate_reply + + +def generate_reply_wrapper(string): + generate_params = { + 'do_sample': True, + 'temperature': 1, + 'top_p': 1, + 'typical_p': 1, + 'repetition_penalty': 1, + 'encoder_repetition_penalty': 1, + 'top_k': 50, + 'num_beams': 1, + 'penalty_alpha': 0, + 'min_length': 0, + 'length_penalty': 1, + 'no_repeat_ngram_size': 0, + 'early_stopping': False, + } + params = json.loads(string) + for k in params[1]: + generate_params[k] = params[1][k] + for i in generate_reply(params[0], generate_params): + yield i + +def create_apis(): + t1 = gr.Textbox(visible=False) + t2 = gr.Textbox(visible=False) + dummy = gr.Button(visible=False) + + input_params = [t1] + output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']] + dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen') diff --git a/modules/chat.py b/modules/chat.py index 1140b5f..f4ddf42 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply, get_max_prompt_length) -def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False): +def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs): + is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False + end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else '' + impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False + also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False + user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] @@ -91,9 +96,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): reply = fix_newlines(reply) return reply, next_character_found -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""): +def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): just_started = True - eos_token = '\n' if stop_at_newline else None + eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical visible_text = text text = apply_extensions(text, "input") - is_instruct = mode == 'instruct' + kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'} if custom_generate_chat_prompt is None: - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) + prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) else: - prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn) + prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs) # Yield *Is typing...* if not regenerate: @@ -124,13 +129,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical # Generate cumulative_reply = '' - for i in range(chat_generation_attempts): + for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): reply = cumulative_reply + reply # Extracting the reply - reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) + reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) visible_reply = re.sub("(||{{user}})", name1_original, reply) visible_reply = apply_extensions(visible_reply, "output") @@ -155,23 +160,23 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical yield shared.history['visible'] -def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): - eos_token = '\n' if stop_at_newline else None +def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" - prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn) + prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn) # Yield *Is typing...* yield shared.processing_message cumulative_reply = '' - for i in range(chat_generation_attempts): + for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): reply = cumulative_reply + reply - reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline) + reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) yield reply if next_character_found: break @@ -181,11 +186,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ yield reply -def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): - for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn): +def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): yield chat_html_wrapper(history, name1, name2, mode) -def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""): +def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0: yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) else: @@ -193,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_internal = shared.history['internal'].pop() # Yield '*Is typing...*' yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode) - for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn): + for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True): shared.history['visible'][-1] = [last_visible[0], history[-1][1]] yield chat_html_wrapper(shared.history['visible'], name1, name2, mode) diff --git a/modules/text_generation.py b/modules/text_generation.py index 406c454..93f0789 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -102,10 +102,11 @@ def set_manual_seed(seed): def stop_everything_event(): shared.stop_everything = True -def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]): +def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]): clear_torch_cache() - set_manual_seed(seed) + set_manual_seed(generate_state['seed']) shared.stop_everything = False + generate_params = {} t0 = time.time() original_question = question @@ -117,9 +118,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier if any((shared.is_RWKV, shared.is_llamacpp)): + for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: + generate_params[k] = generate_state[k] + generate_params["token_count"] = generate_state["max_new_tokens"] try: if shared.args.no_stream: - reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty) + reply = shared.model.generate(context=question, **generate_params) output = original_question+reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") @@ -130,7 +134,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. - for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty): + for reply in shared.model.generate_with_streaming(context=question, **generate_params): output = original_question+reply if not shared.is_chat(): reply = original_question + apply_extensions(reply, "output") @@ -145,7 +149,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})") return - input_ids = encode(question, max_new_tokens) + input_ids = encode(question, generate_state['max_new_tokens']) original_input_ids = input_ids output = input_ids[0] @@ -158,33 +162,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings] stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) - generate_params = {} + generate_params["max_new_tokens"] = generate_state['max_new_tokens'] if not shared.args.flexgen: - generate_params.update({ - "max_new_tokens": max_new_tokens, - "eos_token_id": eos_token_ids, - "stopping_criteria": stopping_criteria_list, - "do_sample": do_sample, - "temperature": temperature, - "top_p": top_p, - "typical_p": typical_p, - "repetition_penalty": repetition_penalty, - "encoder_repetition_penalty": encoder_repetition_penalty, - "top_k": top_k, - "min_length": min_length if shared.args.no_stream else 0, - "no_repeat_ngram_size": no_repeat_ngram_size, - "num_beams": num_beams, - "penalty_alpha": penalty_alpha, - "length_penalty": length_penalty, - "early_stopping": early_stopping, - }) + for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]: + generate_params[k] = generate_state[k] + generate_params["eos_token_id"] = eos_token_ids + generate_params["stopping_criteria"] = stopping_criteria_list + if shared.args.no_stream: + generate_params["min_length"] = 0 else: - generate_params.update({ - "max_new_tokens": max_new_tokens if shared.args.no_stream else 8, - "do_sample": do_sample, - "temperature": temperature, - "stop": eos_token_ids[-1], - }) + for k in ["do_sample", "temperature"]: + generate_params[k] = generate_state[k] + generate_params["stop"] = generate_state["eos_token_ids"][-1] + if not shared.args.no_stream: + generate_params["max_new_tokens"] = 8 + if shared.args.no_cache: generate_params.update({"use_cache": False}) if shared.args.deepspeed: @@ -244,7 +236,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' else: - for i in range(max_new_tokens//8+1): + for i in range(generate_state['max_new_tokens']//8+1): clear_torch_cache() with torch.no_grad(): output = shared.model.generate(**generate_params)[0] diff --git a/server.py b/server.py index 8bcb650..f00e412 100644 --- a/server.py +++ b/server.py @@ -15,7 +15,7 @@ import gradio as gr from PIL import Image import modules.extensions as extensions_module -from modules import chat, shared, training, ui +from modules import chat, shared, training, ui, api from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt @@ -85,7 +85,7 @@ def load_lora_wrapper(selected_lora): add_lora_to_model(selected_lora) return selected_lora -def load_preset_values(preset_menu, return_dict=False): +def load_preset_values(preset_menu, state, return_dict=False): generate_params = { 'do_sample': True, 'temperature': 1, @@ -107,13 +107,13 @@ def load_preset_values(preset_menu, return_dict=False): i = i.rstrip(',').strip().split('=') if len(i) == 2 and i[0].strip() != 'tokens': generate_params[i[0].strip()] = eval(i[1].strip()) - generate_params['temperature'] = min(1.99, generate_params['temperature']) if return_dict: return generate_params else: - return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping'] + state.update(generate_params) + return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']] def upload_soft_prompt(file): with zipfile.ZipFile(io.BytesIO(file)) as zf: @@ -170,7 +170,10 @@ def create_prompt_menus(): shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) def create_settings_menus(default_preset): - generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True) + generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) + for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: + generate_params[k] = shared.settings[k] + shared.gradio['generate_state'] = gr.State(generate_params) with gr.Row(): with gr.Column(): @@ -221,17 +224,16 @@ def create_settings_menus(default_preset): with gr.Row(): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) - shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True) - shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) - shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True) - shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True) - shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']]) + shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) + shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) + shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) + shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) def set_interface_arguments(interface_mode, extensions, bool_active): modes = ["default", "notebook", "chat", "cai_chat"] cmd_list = vars(shared.args) bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] - #int_list = [k for k in cmd_list if type(k) is int] shared.args.extensions = extensions for k in modes[1:]: @@ -372,11 +374,11 @@ def create_interface(): shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size']) with gr.Column(): shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)') - shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') + shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?') create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']] + shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']] def set_chat_input(textbox): return textbox, "" @@ -456,9 +458,9 @@ def create_interface(): with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") @@ -489,9 +491,9 @@ def create_interface(): with gr.Tab("Parameters", elem_id="parameters"): create_settings_menus(default_preset) - shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']] + shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']] output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']] - gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen')) + gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)) gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)) shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) @@ -524,6 +526,21 @@ def create_interface(): if shared.args.extensions is not None: extensions_module.create_extensions_block() + def change_dict_value(d, key, value): + d[key] = value + return d + + for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']: + if k not in shared.gradio: + continue + if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]: + shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state']) + else: + shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state']) + + if not shared.is_chat(): + api.create_apis() + # Authentication auth = None if shared.args.gradio_auth_path is not None: From 641646a80178d2310f0d22261fee06be660b25cd Mon Sep 17 00:00:00 2001 From: Randell Miller Date: Wed, 5 Apr 2023 23:24:22 -0500 Subject: [PATCH 27/45] Fix crash if missing instructions directory (#812) --- server.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/server.py b/server.py index f00e412..577d5ac 100644 --- a/server.py +++ b/server.py @@ -55,7 +55,10 @@ def get_available_characters(): return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower) def get_available_instruction_templates(): - paths = (x for x in Path('characters/instruction-following').iterdir() if x.suffix in ('.json', '.yaml', '.yml')) + path = "characters/instruction-following" + paths = [] + if os.path.exists(path): + paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml')) return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower) def get_available_extensions(): From e94ab5dac1f7839b3f5ca0d1b407cb07adbeecfe Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 01:43:10 -0300 Subject: [PATCH 28/45] Minor fixes --- modules/chat.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index f4ddf42..749ff8c 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -24,7 +24,6 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False - user_input = fix_newlines(user_input) rows = [f"{context.strip()}\n"] # Finding the maximum prompt size @@ -51,8 +50,8 @@ def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") limit = 2 else: - # Adding the user message + user_input = fix_newlines(user_input) if len(user_input) > 0: rows.append(f"{prefix1}{user_input}{end_of_turn}\n") @@ -92,12 +91,14 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): if reply[-j:] == string[:j]: reply = reply[:-j] break + else: + continue + break reply = fix_newlines(reply) return reply, next_character_found def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): - just_started = True eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): @@ -129,6 +130,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu # Generate cumulative_reply = '' + just_started = True for i in range(generate_state['chat_generation_attempts']): reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): @@ -162,7 +164,6 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): eos_token = '\n' if generate_state['stop_at_newline'] else None - if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -187,7 +188,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o yield reply def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): - for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): + for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): yield chat_html_wrapper(history, name1, name2, mode) def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): From 4a400320ddb863ffc0cf6bf35f3a1282294d82a0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 01:47:00 -0300 Subject: [PATCH 29/45] Clean up --- server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/server.py b/server.py index 577d5ac..2c45f46 100644 --- a/server.py +++ b/server.py @@ -515,7 +515,6 @@ def create_interface(): cmd_list = vars(shared.args) bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes] bool_active = [k for k in bool_list if vars(shared.args)[k]] - #int_list = [k for k in cmd_list if type(k) is int] gr.Markdown("*Experimental*") shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode") From 5b301d9a022d3b13eedfa074d5246d78c3e914ed Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 01:54:05 -0300 Subject: [PATCH 30/45] Create a Model tab --- css/main.css | 2 +- server.py | 37 ++++++++++++++++++++----------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/css/main.css b/css/main.css index dfb54d7..2d8f01e 100644 --- a/css/main.css +++ b/css/main.css @@ -41,7 +41,7 @@ ol li p, ul li p { display: inline-block; } -#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab { +#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab { border: 0; } diff --git a/server.py b/server.py index 2c45f46..4ba5ba8 100644 --- a/server.py +++ b/server.py @@ -130,17 +130,6 @@ def upload_soft_prompt(file): return name -def create_model_and_preset_menus(): - with gr.Row(): - with gr.Column(): - with gr.Row(): - shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') - ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button') - with gr.Column(): - with gr.Row(): - shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') - ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') - def save_prompt(text): fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt" with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f: @@ -172,6 +161,20 @@ def create_prompt_menus(): shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False) shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False) +def create_model_menus(): + with gr.Row(): + with gr.Column(): + with gr.Row(): + shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model') + ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button') + with gr.Column(): + with gr.Row(): + shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') + ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') + + shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) + shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) + def create_settings_menus(default_preset): generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True) for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']: @@ -180,7 +183,9 @@ def create_settings_menus(default_preset): with gr.Row(): with gr.Column(): - create_model_and_preset_menus() + with gr.Row(): + shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset') + ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button') with gr.Column(): shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)') @@ -214,9 +219,6 @@ def create_settings_menus(default_preset): shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty') shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping') - with gr.Row(): - shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA') - ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button') with gr.Accordion('Soft prompt', open=False): with gr.Row(): @@ -227,9 +229,7 @@ def create_settings_menus(default_preset): with gr.Row(): shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip']) - shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True) shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]) - shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True) shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True) shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu']) @@ -502,6 +502,9 @@ def create_interface(): shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + with gr.Tab("Model", elem_id="model-tab"): + create_model_menus() + with gr.Tab("Training", elem_id="training-tab"): training.create_train_interface() From 0c7ef26981ce8244262b53af15646aca59eef172 Mon Sep 17 00:00:00 2001 From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com> Date: Wed, 5 Apr 2023 22:04:11 -0700 Subject: [PATCH 31/45] Lora trainer improvements (#763) --- modules/training.py | 102 +++++++++++++++++++++++++++++++++----------- 1 file changed, 78 insertions(+), 24 deletions(-) diff --git a/modules/training.py b/modules/training.py index 5ba8d35..220428b 100644 --- a/modules/training.py +++ b/modules/training.py @@ -20,7 +20,7 @@ MAX_STEPS = 0 CURRENT_GRADIENT_ACCUM = 1 def get_dataset(path: str, ext: str): - return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower) + return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower) def create_train_interface(): with gr.Tab('Train LoRA', elem_id='lora-train-tab'): @@ -45,22 +45,26 @@ def create_train_interface(): with gr.Row(): dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.') ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') - eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.') + eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.') ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button') format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.') ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button') + with gr.Tab(label="Raw Text File"): with gr.Row(): raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.') ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button') - overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.') + with gr.Row(): + overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.') + newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.') with gr.Row(): start_button = gr.Button("Start LoRA Training") stop_button = gr.Button("Interrupt") output = gr.Markdown(value="Ready") - start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output]) + start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, + cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output]) stop_button.click(do_interrupt, [], [], cancels=[], queue=False) def do_interrupt(): @@ -91,8 +95,8 @@ def clean_path(base_path: str, path: str): return path return f'{Path(base_path).absolute()}/{path}' -def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, - lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int): +def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float, + cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int): global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM WANT_INTERRUPT = False CURRENT_STEPS = 0 @@ -103,6 +107,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}" actual_lr = float(learning_rate) + model_type = type(shared.model).__name__ + if model_type != "LlamaForCausalLM": + if model_type == "PeftModelForCausalLM": + yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.") + else: + yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})") + time.sleep(5) + + if shared.args.wbits > 0 or shared.args.gptq_bits > 0: + yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now." + return + + elif not shared.args.load_in_8bit: + yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*" + print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.") + time.sleep(2) # Give it a moment for the message to show in UI before continuing + if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0: yield "Cannot input zeroes." return @@ -126,15 +149,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int raw_text = file.read() tokens = shared.tokenizer.encode(raw_text) del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM + tokens = list(split_chunks(tokens, cutoff_len - overlap_len)) for i in range(1, len(tokens)): tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i] text_chunks = [shared.tokenizer.decode(x) for x in tokens] del tokens - data = Dataset.from_list([tokenize(x) for x in text_chunks]) - train_data = data.shuffle() - eval_data = None + + if newline_favor_len > 0: + text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks] + + train_data = Dataset.from_list([tokenize(x) for x in text_chunks]) del text_chunks + train_data = train_data.shuffle() + eval_data = None else: if dataset in ['None', '']: @@ -232,33 +260,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int # TODO: save/load checkpoints to resume from? print("Starting training...") yield "Starting..." + if WANT_INTERRUPT: + yield "Interrupted before start." + return - def threadedRun(): + def threaded_run(): trainer.train() - thread = threading.Thread(target=threadedRun) + thread = threading.Thread(target=threaded_run) thread.start() - lastStep = 0 - startTime = time.perf_counter() + last_step = 0 + start_time = time.perf_counter() while thread.is_alive(): time.sleep(0.5) if WANT_INTERRUPT: yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*" - elif CURRENT_STEPS != lastStep: - lastStep = CURRENT_STEPS - timeElapsed = time.perf_counter() - startTime - if timeElapsed <= 0: - timerInfo = "" - totalTimeEstimate = 999 + + elif CURRENT_STEPS != last_step: + last_step = CURRENT_STEPS + time_elapsed = time.perf_counter() - start_time + if time_elapsed <= 0: + timer_info = "" + total_time_estimate = 999 else: - its = CURRENT_STEPS / timeElapsed + its = CURRENT_STEPS / time_elapsed if its > 1: - timerInfo = f"`{its:.2f}` it/s" + timer_info = f"`{its:.2f}` it/s" else: - timerInfo = f"`{1.0/its:.2f}` s/it" - totalTimeEstimate = (1.0/its) * (MAX_STEPS) - yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds" + timer_info = f"`{1.0/its:.2f}` s/it" + total_time_estimate = (1.0/its) * (MAX_STEPS) + yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining" print("Training complete, saving...") lora_model.save_pretrained(lora_name) @@ -273,3 +305,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int def split_chunks(arr, step): for i in range(0, len(arr), step): yield arr[i:i + step] + +def cut_chunk_for_newline(chunk: str, max_length: int): + if '\n' not in chunk: + return chunk + first_newline = chunk.index('\n') + if first_newline < max_length: + chunk = chunk[first_newline + 1:] + if '\n' not in chunk: + return chunk + last_newline = chunk.rindex('\n') + if len(chunk) - last_newline < max_length: + chunk = chunk[:last_newline] + return chunk + +def format_time(seconds: float): + if seconds < 120: + return f"`{seconds:.0f}` seconds" + minutes = seconds / 60 + if minutes < 120: + return f"`{minutes:.0f}` minutes" + hours = minutes / 60 + return f"`{hours:.0f}` hours" From 158ec51ae351e2d65c841dad6c459ab117958a1a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 02:20:52 -0300 Subject: [PATCH 32/45] Increase instruct mode padding --- css/html_instruct_style.css | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index f50b64d..d45a3e4 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -43,11 +43,11 @@ } .assistant-message { - padding: 10px; + padding: 15px; } .user-message { - padding: 10px; + padding: 15px; background-color: #f1f1f1; } From 4a28f39823cecc3124cd4cd9329c35d6571b158f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 02:47:27 -0300 Subject: [PATCH 33/45] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1e51563..a46805e 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * Dropdown menu for switching between models * Notebook mode that resembles OpenAI's playground * Chat mode for conversation and role playing +* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\*** * Nice HTML output for GPT-4chan * Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering * [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters) From 8cd899515e94dc273f7c9ef5e5706bffc1cebf22 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 12:00:20 -0300 Subject: [PATCH 34/45] Change instruct html a bit --- css/html_instruct_style.css | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index d45a3e4..13a97f8 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -42,15 +42,19 @@ color: rgb(110, 110, 110) !important; } -.assistant-message { +.gradio-container .chat .assistant-message { padding: 15px; + border-radius: 20px; + background-color: #0000000f; + margin-bottom: 17.5px; } -.user-message { +.gradio-container .chat .user-message { padding: 15px; - background-color: #f1f1f1; + border-radius: 20px; + margin-bottom: 17.5px !important; } -.dark .user-message { - background-color: #ffffff1a; +.dark .chat .assistant-message { + background-color: #ffffff21; } \ No newline at end of file From 39f3fec913e8ca4ad9e11bcb94ceee76f4e07d11 Mon Sep 17 00:00:00 2001 From: EyeDeck Date: Thu, 6 Apr 2023 11:16:48 -0400 Subject: [PATCH 35/45] Broaden GPTQ-for-LLaMA branch support (#820) --- modules/GPTQ_loader.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index abfa33a..572947a 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -1,3 +1,4 @@ +import inspect import re import sys from pathlib import Path @@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc config = AutoConfig.from_pretrained(model) def noop(*args, **kwargs): pass - torch.nn.init.kaiming_uniform_ = noop - torch.nn.init.uniform_ = noop - torch.nn.init.normal_ = noop + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop torch.set_default_dtype(torch.half) transformers.modeling_utils._init_weights = False @@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc for name in exclude_layers: if name in layers: del layers[name] - make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold) + + gptq_args = inspect.getfullargspec(make_quant).args + + make_quant_kwargs = { + 'module': model, + 'names': layers, + 'bits': wbits, + } + if 'groupsize' in gptq_args: + make_quant_kwargs['groupsize'] = groupsize + if 'faster' in gptq_args: + make_quant_kwargs['faster'] = faster_kernel + if 'kernel_switch_threshold' in gptq_args: + make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold + + make_quant(**make_quant_kwargs) del layers - + print('Loading model ...') if checkpoint.endswith('.safetensors'): from safetensors.torch import load_file as safe_load - model.load_state_dict(safe_load(checkpoint)) + model.load_state_dict(safe_load(checkpoint), strict = False) else: - model.load_state_dict(torch.load(checkpoint)) + model.load_state_dict(torch.load(checkpoint), strict = False) model.seqlen = 2048 print('Done.') From 03cb44fc8ca24c6458dbd25c31acd1e8cbdfcde2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:12:14 -0300 Subject: [PATCH 36/45] Add new llama.cpp library (2048 context, temperature, etc now work) --- modules/llamacpp_model_alternative.py | 65 +++++++++++++++++++++++++++ modules/models.py | 2 +- requirements.txt | 1 + 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 modules/llamacpp_model_alternative.py diff --git a/modules/llamacpp_model_alternative.py b/modules/llamacpp_model_alternative.py new file mode 100644 index 0000000..4057611 --- /dev/null +++ b/modules/llamacpp_model_alternative.py @@ -0,0 +1,65 @@ +''' +Based on +https://github.com/abetlen/llama-cpp-python + +Documentation: +https://abetlen.github.io/llama-cpp-python/ +''' + +import multiprocessing + +from llama_cpp import Llama + +from modules import shared +from modules.callbacks import Iteratorize + + +class LlamaCppModel: + def __init__(self): + self.initialized = False + + @classmethod + def from_pretrained(self, path): + result = self() + + params = { + 'model_path': str(path), + 'n_ctx': 2048, + 'seed': 0, + 'n_threads': shared.args.threads or None + } + self.model = Llama(**params) + + # This is ugly, but the model and the tokenizer are the same object in this library. + return result, result + + def encode(self, string): + if type(string) is str: + string = string.encode() + return self.model.tokenize(string) + + def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None): + if type(context) is str: + context = context.encode() + tokens = self.model.tokenize(context) + + output = b"" + count = 0 + for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty): + text = self.model.detokenize([token]) + output += text + if callback: + callback(text.decode()) + + count += 1 + if count >= token_count or (token == self.model.token_eos()): + break + + return output.decode() + + def generate_with_streaming(self, **kwargs): + with Iteratorize(self.generate, kwargs, callback=None) as generator: + reply = '' + for token in generator: + reply += token + yield reply diff --git a/modules/models.py b/modules/models.py index a8f8469..642e1f3 100644 --- a/modules/models.py +++ b/modules/models.py @@ -103,7 +103,7 @@ def load_model(model_name): # llamacpp model elif shared.is_llamacpp: - from modules.llamacpp_model import LlamaCppModel + from modules.llamacpp_model_alternative import LlamaCppModel model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0] print(f"llama.cpp weights detected: {model_file}\n") diff --git a/requirements.txt b/requirements.txt index 6ed9e91..1dcc495 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ datasets flexgen==0.1.7 gradio==3.24.1 llamacpp==0.1.11 +llama-cpp-python==0.1.23 markdown numpy peft==0.2.0 From eec3665845435b2f8a0dd169ba79ef0640ab4782 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:24:01 -0300 Subject: [PATCH 37/45] Add instructions for updating requirements --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index a46805e..a41190d 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,15 @@ As an alternative to the recommended WSL method, you can install the web UI nati https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87 +### Updating the requirements + +From time to time, the `requirements.txt` changes. To update, use this command: + +``` +conda activate textgen +cd text-generation-webui +pip install -r requirements.txt --upgrade +``` ## Downloading models Models should be placed inside the `models` folder. From 59058576b5a94487f39ae3d082076ab57e0546a2 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:28:21 -0300 Subject: [PATCH 38/45] Remove unused requirement --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1dcc495..108fe0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ bitsandbytes==0.37.2 datasets flexgen==0.1.7 gradio==3.24.1 -llamacpp==0.1.11 llama-cpp-python==0.1.23 markdown numpy From d9e7aba714cefddc47bb9defd2cbe55c34e95882 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:42:24 -0300 Subject: [PATCH 39/45] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a41190d..c4dd01d 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github. * [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model) * [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\*** * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model) -* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs) +* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs) * Softprompts * [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions) * [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab) From 113f94b61ee0e85bd791992da024cb5fc6beac93 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 16:04:03 -0300 Subject: [PATCH 40/45] Bump transformers (16-bit llama must be reconverted/redownloaded) --- modules/models.py | 4 +++- modules/text_generation.py | 4 ++++ requirements.txt | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/modules/models.py b/modules/models.py index 642e1f3..1bf6fc3 100644 --- a/modules/models.py +++ b/modules/models.py @@ -10,7 +10,7 @@ import torch import transformers from accelerate import infer_auto_device_map, init_empty_weights from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer, - BitsAndBytesConfig) + BitsAndBytesConfig, LlamaTokenizer) import modules.shared as shared @@ -172,6 +172,8 @@ def load_model(model_name): # Loading the tokenizer if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) + elif type(model) is transformers.LlamaForCausalLM: + tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True) else: tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' diff --git a/modules/text_generation.py b/modules/text_generation.py index 93f0789..b8885ab 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True): return input_ids else: input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens) + + if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871: + input_ids = input_ids[:,1:] + if shared.args.cpu: return input_ids elif shared.args.flexgen: diff --git a/requirements.txt b/requirements.txt index 108fe0c..882dc30 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ safetensors==0.3.0 sentencepiece pyyaml tqdm -git+https://github.com/huggingface/transformers@9eae4aa57650c1dbe1becd4e0979f6ad1e572ac0 +git+https://github.com/huggingface/transformers From 20b8ca4482b21d98d6d3c99882d590d5d90ce339 Mon Sep 17 00:00:00 2001 From: DavG25 <31524206+DavG25@users.noreply.github.com> Date: Thu, 6 Apr 2023 21:15:04 +0200 Subject: [PATCH 41/45] Add CSS for lists (#833) --- css/html_cai_style.css | 9 +++++++++ css/html_instruct_style.css | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/css/html_cai_style.css b/css/html_cai_style.css index 3190b3d..57c3b5c 100644 --- a/css/html_cai_style.css +++ b/css/html_cai_style.css @@ -64,6 +64,15 @@ line-height: 1.428571429 !important; } +.message-body li { + margin-top: 0.5em !important; + margin-bottom: 0.5em !important; +} + +.message-body li > p { + display: inline !important; +} + .dark .message-body p em { color: rgb(138, 138, 138) !important; } diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index 13a97f8..8a2a000 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -34,6 +34,15 @@ line-height: 1.428571429 !important; } +.message-body li { + margin-top: 0.5em !important; + margin-bottom: 0.5em !important; +} + +.message-body li > p { + display: inline !important; +} + .dark .message-body p em { color: rgb(138, 138, 138) !important; } From 310bf46a945aacc507454509b82f6807c48cc093 Mon Sep 17 00:00:00 2001 From: OWKenobi Date: Thu, 6 Apr 2023 22:40:44 +0200 Subject: [PATCH 42/45] Instruction Character Vicuna, Instruction Mode Bugfix (#838) --- characters/instruction-following/Vicuna.yaml | 3 +++ modules/chat.py | 14 ++++++++++++-- 2 files changed, 15 insertions(+), 2 deletions(-) create mode 100644 characters/instruction-following/Vicuna.yaml diff --git a/characters/instruction-following/Vicuna.yaml b/characters/instruction-following/Vicuna.yaml new file mode 100644 index 0000000..026901d --- /dev/null +++ b/characters/instruction-following/Vicuna.yaml @@ -0,0 +1,3 @@ +name: "### Assistant:" +your_name: "### Human:" +context: "Below is an instruction that describes a task. Write a response that appropriately completes the request." diff --git a/modules/chat.py b/modules/chat.py index 749ff8c..3693264 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -99,6 +99,11 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline): return reply, next_character_found def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False): + if mode == 'instruct': + stopping_strings = [f"\n{name1}", f"\n{name2}"] + else: + stopping_strings = [f"\n{name1}:", f"\n{name2}:"] + eos_token = '\n' if generate_state['stop_at_newline'] else None name1_original = name1 if 'pygmalion' in shared.model_name.lower(): @@ -133,7 +138,7 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu just_started = True for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): reply = cumulative_reply + reply # Extracting the reply @@ -163,6 +168,11 @@ def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_tu yield shared.history['visible'] def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn): + if mode == 'instruct': + stopping_strings = [f"\n{name1}", f"\n{name2}"] + else: + stopping_strings = [f"\n{name1}:", f"\n{name2}:"] + eos_token = '\n' if generate_state['stop_at_newline'] else None if 'pygmalion' in shared.model_name.lower(): name1 = "You" @@ -175,7 +185,7 @@ def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_o cumulative_reply = '' for i in range(generate_state['chat_generation_attempts']): reply = None - for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]): + for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings): reply = cumulative_reply + reply reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline']) yield reply From 21be80242e7ed6b9691ecaeea413f6723a485afc Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 6 Apr 2023 17:52:27 -0300 Subject: [PATCH 43/45] Bump rwkv from 0.7.2 to 0.7.3 (#842) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 882dc30..be51203 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,7 +8,7 @@ markdown numpy peft==0.2.0 requests -rwkv==0.7.2 +rwkv==0.7.3 safetensors==0.3.0 sentencepiece pyyaml From 58ed87e5d9dd864e38d346274e6455a4a041e203 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 18:42:54 -0300 Subject: [PATCH 44/45] Update requirements.txt --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index be51203..aa1a38d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,7 +3,6 @@ bitsandbytes==0.37.2 datasets flexgen==0.1.7 gradio==3.24.1 -llama-cpp-python==0.1.23 markdown numpy peft==0.2.0 From 64bcde56ab3d09f98eaaff5e98b053248c65438b Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 6 Apr 2023 20:14:29 -0300 Subject: [PATCH 45/45] Minor css change --- css/html_instruct_style.css | 4 ---- 1 file changed, 4 deletions(-) diff --git a/css/html_instruct_style.css b/css/html_instruct_style.css index 8a2a000..533c547 100644 --- a/css/html_instruct_style.css +++ b/css/html_instruct_style.css @@ -18,10 +18,6 @@ line-height: 1.428571429; } -.text p { - margin-top: 5px; -} - .username { display: none; }