
Merge branch 'main' into da3dsoul-main

oobabooga 2 years ago
parent commit 7e31bc485c

+ 19 - 8
README.md

@@ -15,6 +15,7 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * Dropdown menu for switching between models
 * Notebook mode that resembles OpenAI's playground
 * Chat mode for conversation and role playing
+* Instruct mode compatible with Alpaca and Open Assistant formats **\*NEW!\***
 * Nice HTML output for GPT-4chan
 * Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX rendering
 * [Custom chat characters](https://github.com/oobabooga/text-generation-webui/wiki/Custom-chat-characters)
@@ -26,11 +27,11 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
 * CPU mode
 * [FlexGen](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen)
 * [DeepSpeed ZeRO-3](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed)
-* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
+* API [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-stream.py) streaming and [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming
 * [LLaMA model, including 4-bit GPTQ](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model)
 * [llama.cpp](https://github.com/oobabooga/text-generation-webui/wiki/llama.cpp-models) **\*NEW!\***
 * [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model)
-* [LoRa (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
+* [LoRA (loading and training)](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs)
 * Softprompts
 * [Extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions)
 * [Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab)
@@ -62,7 +63,7 @@ Recommended if you have some experience with the command-line.
 
 On Windows, I additionally recommend carrying out the installation on WSL instead of the base system: [WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/WSL-installation-guide).
 
-0. Install Conda
+#### 0. Install Conda
 
 https://docs.conda.io/en/latest/miniconda.html
 
@@ -75,14 +76,14 @@ bash Miniconda3.sh
 
 Source: https://educe-ubc.github.io/conda.html
 
-1. Create a new conda environment
+#### 1. Create a new conda environment
 
 ```
 conda create -n textgen python=3.10.9
 conda activate textgen
 ```
 
-2. Install Pytorch
+#### 2. Install Pytorch
 
 | System | GPU | Command |
 |--------|---------|---------|
@@ -92,10 +93,12 @@ conda activate textgen
 
 The up to date commands can be found here: https://pytorch.org/get-started/locally/. 
 
-MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
+#### 2.1 Special instructions
 
+* MacOS users: https://github.com/oobabooga/text-generation-webui/pull/393
+* AMD users: https://rentry.org/eq3hg
 
-3. Install the web UI
+#### 3. Install the web UI
 
 ```
 git clone https://github.com/oobabooga/text-generation-webui
@@ -116,6 +119,15 @@ As an alternative to the recommended WSL method, you can install the web UI nati
 
 https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
 
+### Updating the requirements
+
+From time to time, the `requirements.txt` changes. To update, use this command:
+
+```
+conda activate textgen
+cd text-generation-webui
+pip install -r requirements.txt --upgrade
+```
 ## Downloading models
 
 Models should be placed inside the `models` folder.
@@ -175,7 +187,6 @@ Optionally, you can use the following command-line flags:
 | `-h`, `--help`   | show this help message and exit |
 | `--notebook`     | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
 | `--chat`         | Launch the web UI in chat mode.|
-| `--cai-chat`     | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
 | `--model MODEL`  | Name of the model to load by default. |
 | `--lora LORA`    | Name of the LoRA to apply to the model by default. |
 |  `--model-dir MODEL_DIR`                    | Path to directory with all the models |

+ 2 - 16
api-example-stream.py

@@ -36,6 +36,7 @@ async def run(context):
         'early_stopping': False,
         'seed': -1,
     }
+    payload = json.dumps([context, params])
     session = random_hash()
 
     async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +55,7 @@ async def run(context):
                         "session_hash": session,
                         "fn_index": 12,
                         "data": [
-                            context,
-                            params['max_new_tokens'],
-                            params['do_sample'],
-                            params['temperature'],
-                            params['top_p'],
-                            params['typical_p'],
-                            params['repetition_penalty'],
-                            params['encoder_repetition_penalty'],
-                            params['top_k'],
-                            params['min_length'],
-                            params['no_repeat_ngram_size'],
-                            params['num_beams'],
-                            params['penalty_alpha'],
-                            params['length_penalty'],
-                            params['early_stopping'],
-                            params['seed'],
+                            payload
                         ]
                     }))
                 case "process_starts":

+ 5 - 16
api-example.py

@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
 allowing you to use the API remotely.
 
 '''
+import json
+
 import requests
 
 # Server address
@@ -38,24 +40,11 @@ params = {
 # Input prompt
 prompt = "What I would like to say is the following: "
 
+payload = json.dumps([prompt, params])
+
 response = requests.post(f"http://{server}:7860/run/textgen", json={
     "data": [
-        prompt,
-        params['max_new_tokens'],
-        params['do_sample'],
-        params['temperature'],
-        params['top_p'],
-        params['typical_p'],
-        params['repetition_penalty'],
-        params['encoder_repetition_penalty'],
-        params['top_k'],
-        params['min_length'],
-        params['no_repeat_ngram_size'],
-        params['num_beams'],
-        params['penalty_alpha'],
-        params['length_penalty'],
-        params['early_stopping'],
-        params['seed'],
+        payload
     ]
 }).json()
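
Both API examples now share the same convention: the prompt and every sampling parameter are bundled into a single JSON string, `json.dumps([prompt, params])`, which the new `generate_reply_wrapper` in `modules/api.py` (further down in this commit) decodes again on the server side. A minimal round-trip sketch with illustrative parameter values:

```
import json

params = {'max_new_tokens': 200, 'temperature': 0.7, 'top_p': 0.9}
payload = json.dumps(["Tell me a joke.", params])   # what the client sends

prompt, overrides = json.loads(payload)             # what the server unpacks
print(prompt, overrides['temperature'])
```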
 

+ 14 - 30
characters/Example.yaml

@@ -1,32 +1,16 @@
-name: Chiharu Yamada
-context: 'Chiharu Yamada''s Persona: Chiharu Yamada is a young, computer engineer-nerd
-  with a knack for problem solving and a passion for technology.'
-greeting: '*Chiharu strides into the room with a smile, her eyes lighting up
-  when she sees you. She''s wearing a light blue t-shirt and jeans, her laptop bag
-  slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in
-  the air*
-
-  Hey! I''m so excited to finally meet you. I''ve heard so many great things about
-  you and I''m eager to pick your brain about computers. I''m sure you have a wealth
-  of knowledge that I can learn from. *She grins, eyes twinkling with excitement*
-  Let''s get started!'
-example_dialogue: '{{user}}: So how did you get into computer engineering?
-
-  {{char}}: I''ve always loved tinkering with technology since I was a kid.
-
-  {{user}}: That''s really impressive!
-
+name: "Chiharu Yamada"
+context: "Chiharu Yamada's Persona: Chiharu Yamada is a young, computer engineer-nerd with a knack for problem solving and a passion for technology."
+greeting: |-
+  *Chiharu strides into the room with a smile, her eyes lighting up when she sees you. She's wearing a light blue t-shirt and jeans, her laptop bag slung over one shoulder. She takes a seat next to you, her enthusiasm palpable in the air*
+  Hey! I'm so excited to finally meet you. I've heard so many great things about you and I'm eager to pick your brain about computers. I'm sure you have a wealth of knowledge that I can learn from. *She grins, eyes twinkling with excitement* Let's get started!
+example_dialogue: |-
+  {{user}}: So how did you get into computer engineering?
+  {{char}}: I've always loved tinkering with technology since I was a kid.
+  {{user}}: That's really impressive!
   {{char}}: *She chuckles bashfully* Thanks!
-
-  {{user}}: So what do you do when you''re not working on computers?
-
-  {{char}}: I love exploring, going out with friends, watching movies, and playing
-  video games.
-
-  {{user}}: What''s your favorite type of computer hardware to work with?
-
-  {{char}}: Motherboards, they''re like puzzles and the backbone of any system.
-
+  {{user}}: So what do you do when you're not working on computers?
+  {{char}}: I love exploring, going out with friends, watching movies, and playing video games.
+  {{user}}: What's your favorite type of computer hardware to work with?
+  {{char}}: Motherboards, they're like puzzles and the backbone of any system.
   {{user}}: That sounds great!
-
-  {{char}}: Yeah, it''s really fun. I''m lucky to be able to do this as a job.'
+  {{char}}: Yeah, it's really fun. I'm lucky to be able to do this as a job.

+ 3 - 0
characters/instruction-following/Alpaca.yaml

@@ -0,0 +1,3 @@
+name: "### Response:"
+your_name: "### Instruction:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."

+ 3 - 0
characters/instruction-following/Open Assistant.yaml

@@ -0,0 +1,3 @@
+name: "<|assistant|>"
+your_name: "<|prompter|>"
+end_of_turn: "<|endoftext|>"

+ 3 - 0
characters/instruction-following/Vicuna.yaml

@@ -0,0 +1,3 @@
+name: "### Assistant:"
+your_name: "### Human:"
+context: "Below is an instruction that describes a task. Write a response that appropriately completes the request."
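
These three templates drive the new instruct mode. As a rough illustration only (not the project's code), the fields map onto a prompt roughly the way the `is_instruct` branch of `generate_chat_prompt` in `modules/chat.py` assembles it later in this commit:

```
# Illustrative sketch: turning an instruction-following template into a prompt.
def build_instruct_prompt(user_input, template):
    context = template.get('context', '')
    end_of_turn = template.get('end_of_turn', '')
    prefix_user = f"{template['your_name']}\n"   # e.g. "### Instruction:\n"
    prefix_bot = f"{template['name']}\n"         # e.g. "### Response:\n"

    rows = [f"{context.strip()}\n\n"] if context else []
    rows.append(f"{prefix_user}{user_input.strip()}{end_of_turn}\n")
    rows.append(prefix_bot)                      # the model continues from here
    return ''.join(rows)

# With the Alpaca template above, "Summarize X" becomes:
# Below is an instruction that describes a task. ...
#
# ### Instruction:
# Summarize X
# ### Response:
```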

+ 9 - 0
css/html_cai_style.css

@@ -64,6 +64,15 @@
     line-height: 1.428571429 !important;
 }
 
+.message-body li {
+    margin-top: 0.5em !important;
+    margin-bottom: 0.5em !important;
+}
+
+.message-body li > p {
+    display: inline !important;
+}
+
 .dark .message-body p em {
     color: rgb(138, 138, 138) !important;
 }

+ 65 - 0
css/html_instruct_style.css

@@ -0,0 +1,65 @@
+.chat {
+    margin-left: auto;
+    margin-right: auto;
+    max-width: 800px;
+    height: 66.67vh;
+    overflow-y: auto;
+    padding-right: 20px;
+    display: flex;
+    flex-direction: column-reverse;
+}
+
+.message {
+    display: grid;
+    grid-template-columns: 60px 1fr;
+    padding-bottom: 25px;
+    font-size: 15px;
+    font-family: Helvetica, Arial, sans-serif;
+    line-height: 1.428571429;
+}
+
+.username {
+    display: none;
+}
+
+.message-body {}
+
+.message-body p {
+    margin-bottom: 0 !important;
+    font-size: 15px !important;
+    line-height: 1.428571429 !important;
+}
+
+.message-body li {
+    margin-top: 0.5em !important;
+    margin-bottom: 0.5em !important;
+}
+
+.message-body li > p {
+    display: inline !important;
+}
+
+.dark .message-body p em {
+    color: rgb(138, 138, 138) !important;
+}
+
+.message-body p em {
+    color: rgb(110, 110, 110) !important;
+}
+
+.gradio-container .chat .assistant-message {
+  padding: 15px;
+  border-radius: 20px;
+  background-color: #0000000f;
+  margin-bottom: 17.5px;
+}
+
+.gradio-container .chat .user-message {
+  padding: 15px;
+  border-radius: 20px;
+  margin-bottom: 17.5px !important;
+}
+
+.dark .chat .assistant-message {
+  background-color: #ffffff21;
+}

+ 5 - 1
css/main.css

@@ -41,7 +41,7 @@ ol li p, ul li p {
     display: inline-block;
 }
 
-#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab {
+#main, #parameters, #chat-settings, #interface-mode, #lora, #training-tab, #model-tab {
   border: 0;
 }
 
@@ -63,3 +63,7 @@ span.math.inline {
   font-size: 27px;
   vertical-align: baseline !important;
 }
+
+div.svelte-15lo0d8 > *, div.svelte-15lo0d8 > .form > * {
+  flex-wrap: nowrap;
+}

+ 19 - 16
extensions/api/script.py

@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
                 prompt_lines.pop(0)
 
             prompt = '\n'.join(prompt_lines)
+            generate_params =  {
+                'max_new_tokens': int(body.get('max_length', 200)), 
+                'do_sample': bool(body.get('do_sample', True)),
+                'temperature': float(body.get('temperature', 0.5)), 
+                'top_p': float(body.get('top_p', 1)), 
+                'typical_p': float(body.get('typical', 1)), 
+                'repetition_penalty': float(body.get('rep_pen', 1.1)), 
+                'encoder_repetition_penalty': 1,
+                'top_k': int(body.get('top_k', 0)), 
+                'min_length': int(body.get('min_length', 0)),
+                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
+                'num_beams': int(body.get('num_beams',1)),
+                'penalty_alpha': float(body.get('penalty_alpha', 0)),
+                'length_penalty': float(body.get('length_penalty', 1)),
+                'early_stopping': bool(body.get('early_stopping', False)),
+                'seed': int(body.get('seed', -1)),
+            }
 
             generator = generate_reply(
-                question = prompt, 
-                max_new_tokens = int(body.get('max_length', 200)), 
-                do_sample=bool(body.get('do_sample', True)),
-                temperature=float(body.get('temperature', 0.5)), 
-                top_p=float(body.get('top_p', 1)), 
-                typical_p=float(body.get('typical', 1)), 
-                repetition_penalty=float(body.get('rep_pen', 1.1)), 
-                encoder_repetition_penalty=1, 
-                top_k=int(body.get('top_k', 0)), 
-                min_length=int(body.get('min_length', 0)),
-                no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
-                num_beams=int(body.get('num_beams',1)),
-                penalty_alpha=float(body.get('penalty_alpha', 0)),
-                length_penalty=float(body.get('length_penalty', 1)),
-                early_stopping=bool(body.get('early_stopping', False)),
-                seed=int(body.get('seed', -1)),
+                prompt, 
+                generate_params,
                 stopping_strings=body.get('stopping_strings', []),
             )
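
For reference, here is a Python dictionary mirroring the JSON body that the handler above reads. Only the keys accessed via `body.get(...)` in this hunk are listed, with the handler's own defaults as values; the endpoint path and the prompt field are not shown in this hunk, so they are left out:

```
# Illustrative request body for the extension's API handler above.
body = {
    "max_length": 200,
    "do_sample": True,
    "temperature": 0.5,
    "top_p": 1,
    "typical": 1,
    "rep_pen": 1.1,
    "top_k": 0,
    "min_length": 0,
    "no_repeat_ngram_size": 0,
    "num_beams": 1,
    "penalty_alpha": 0,
    "length_penalty": 1,
    "early_stopping": False,
    "seed": -1,
    "stopping_strings": [],
}
```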
 

+ 5 - 15
extensions/gallery/script.py

@@ -2,9 +2,8 @@ from pathlib import Path
 
 import gradio as gr
 
-from modules.chat import load_character
 from modules.html_generator import get_image_cache
-from modules.shared import gradio, settings
+from modules.shared import gradio
 
 
 def generate_css():
@@ -64,22 +63,13 @@ def generate_html():
     for file in sorted(Path("characters").glob("*")):
         if file.suffix in [".json", ".yml", ".yaml"]:
             character = file.stem
-            container_html = f'<div class="character-container">'
+            container_html = '<div class="character-container">'
             image_html = "<div class='placeholder'></div>"
 
-            for i in [
-                    f"characters/{character}.png",
-                    f"characters/{character}.jpg",
-                    f"characters/{character}.jpeg",
-                    ]:
-
-                path = Path(i)
+            for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
                 if path.exists():
-                    try:
-                        image_html = f'<img src="file/{get_image_cache(path)}">'
-                        break
-                    except:
-                        continue
+                    image_html = f'<img src="file/{get_image_cache(path)}">'
+                    break
 
             container_html += f'{image_html} <span class="character-name">{character}</span>'
             container_html += "</div>"

+ 1 - 1
extensions/sd_api_pictures/script.py

@@ -176,4 +176,4 @@ def ui():
 
     force_btn.click(force_pic)
     generate_now_btn.click(force_pic)
-    generate_now_btn.click(eval('chat.cai_chatbot_wrapper'), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    generate_now_btn.click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)

+ 3 - 6
extensions/send_pictures/script.py

@@ -2,12 +2,11 @@ import base64
 from io import BytesIO
 
 import gradio as gr
-import modules.chat as chat
-import modules.shared as shared
 import torch
-from PIL import Image
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
+from modules import chat, shared
+
 # If 'state' is True, will hijack the next chat generation with
 # custom input text given by 'value' in the format [text, visible_text]
 input_hijack = {
@@ -36,13 +35,11 @@ def generate_chat_picture(picture, name1, name2):
 def ui():
     picture_select = gr.Image(label='Send a picture', type='pil')
 
-    function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
-
     # Prepare the hijack with custom inputs
     picture_select.upload(lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None)
 
     # Call the generation function
-    picture_select.upload(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
+    picture_select.upload(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)
 
     # Clear the picture from the upload field
     picture_select.upload(lambda : None, [], [picture_select], show_progress=False)

+ 31 - 13
modules/GPTQ_loader.py

@@ -1,3 +1,4 @@
+import inspect
 import re
 import sys
 from pathlib import Path
@@ -19,9 +20,9 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     config = AutoConfig.from_pretrained(model)
     def noop(*args, **kwargs):
         pass
-    torch.nn.init.kaiming_uniform_ = noop 
-    torch.nn.init.uniform_ = noop 
-    torch.nn.init.normal_ = noop 
+    torch.nn.init.kaiming_uniform_ = noop
+    torch.nn.init.uniform_ = noop
+    torch.nn.init.normal_ = noop
 
     torch.set_default_dtype(torch.half)
     transformers.modeling_utils._init_weights = False
@@ -33,16 +34,31 @@ def _load_quant(model, checkpoint, wbits, groupsize=-1, faster_kernel=False, exc
     for name in exclude_layers:
         if name in layers:
             del layers[name]
-    make_quant(model, layers, wbits, groupsize, faster=faster_kernel, kernel_switch_threshold=kernel_switch_threshold)
+    
+    gptq_args = inspect.getfullargspec(make_quant).args
+
+    make_quant_kwargs = {
+        'module': model, 
+        'names': layers,
+        'bits': wbits,
+    }
+    if 'groupsize' in gptq_args:
+        make_quant_kwargs['groupsize'] = groupsize
+    if 'faster' in gptq_args:
+        make_quant_kwargs['faster'] = faster_kernel
+    if 'kernel_switch_threshold' in gptq_args:
+        make_quant_kwargs['kernel_switch_threshold'] = kernel_switch_threshold
+    
+    make_quant(**make_quant_kwargs)
 
     del layers
-    
+
     print('Loading model ...')
     if checkpoint.endswith('.safetensors'):
         from safetensors.torch import load_file as safe_load
-        model.load_state_dict(safe_load(checkpoint))
+        model.load_state_dict(safe_load(checkpoint), strict = False)
     else:
-        model.load_state_dict(torch.load(checkpoint))
+        model.load_state_dict(torch.load(checkpoint), strict = False)
     model.seqlen = 2048
     print('Done.')
 
@@ -52,7 +68,7 @@ def load_quantized(model_name):
     if not shared.args.model_type:
         # Try to determine model type from model name
         name = model_name.lower()
-        if any((k in name for k in ['llama', 'alpaca'])):
+        if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
             model_type = 'llama'
         elif any((k in name for k in ['opt-', 'galactica'])):
             model_type = 'opt'
@@ -65,16 +81,18 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if model_type == 'llama' and shared.args.pre_layer:
+    if shared.args.pre_layer and model_type == 'llama':
         load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
+        if shared.args.pre_layer:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
         load_quant = _load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
         exit()
 
     # Now we are going to try to locate the quantized model file.
-    path_to_model = Path(f'models/{model_name}')
+    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
     found_pts = list(path_to_model.glob("*.pt"))
     found_safetensors = list(path_to_model.glob("*.safetensors"))
     pt_path = None
@@ -95,8 +113,8 @@ def load_quantized(model_name):
         else:
             pt_model = f'{model_name}-{shared.args.wbits}bit'
 
-        # Try to find the .safetensors or .pt both in models/ and in the subfolder
-        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"models/{pt_model}", f"{path_to_model}/{pt_model}"]]:
+        # Try to find the .safetensors or .pt both in the model dir and in the subfolder
+        for path in [Path(p+ext) for ext in ['.safetensors', '.pt'] for p in [f"{shared.args.model_dir}/{pt_model}", f"{path_to_model}/{pt_model}"]]:
             if path.exists():
                 print(f"Found {path}")
                 pt_path = path
@@ -107,7 +125,7 @@ def load_quantized(model_name):
         exit()
 
     # qwopqwop200's offload
-    if shared.args.pre_layer:
+    if model_type == 'llama' and shared.args.pre_layer:
         model = load_quant(str(path_to_model), str(pt_path), shared.args.wbits, shared.args.groupsize, shared.args.pre_layer)
     else:
         threshold = False if model_type == 'gptj' else 128
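
The `inspect.getfullargspec` change above exists because different GPTQ-for-LLaMa revisions ship `make_quant` with different signatures. A minimal sketch of the general pattern, using made-up stand-in functions rather than the real `make_quant`:

```
import inspect

def filter_kwargs(func, **kwargs):
    # Keep only the keyword arguments that this version of `func` accepts
    accepted = inspect.getfullargspec(func).args
    return {k: v for k, v in kwargs.items() if k in accepted}

def make_quant_v1(module, names, bits):              # older signature
    print('v1', module, names, bits)

def make_quant_v2(module, names, bits, groupsize):   # newer signature
    print('v2', module, names, bits, groupsize)

for fn in (make_quant_v1, make_quant_v2):
    fn(**filter_kwargs(fn, module='m', names='n', bits=4, groupsize=128))
```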

+ 38 - 0
modules/api.py

@@ -0,0 +1,38 @@
+import json
+
+import gradio as gr
+
+from modules import shared
+from modules.text_generation import generate_reply
+
+
+def generate_reply_wrapper(string):
+    generate_params = {
+        'do_sample': True,
+        'temperature': 1,
+        'top_p': 1,
+        'typical_p': 1,
+        'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
+        'top_k': 50,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'min_length': 0,
+        'length_penalty': 1,
+        'no_repeat_ngram_size': 0,
+        'early_stopping': False,
+    }
+    params = json.loads(string)
+    for k in params[1]:
+        generate_params[k] = params[1][k]
+    for i in generate_reply(params[0], generate_params):
+        yield i
+
+def create_apis():
+    t1 = gr.Textbox(visible=False)
+    t2 = gr.Textbox(visible=False)
+    dummy = gr.Button(visible=False)
+
+    input_params = [t1]
+    output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
+    dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

+ 122 - 83
modules/chat.py

@@ -12,45 +12,55 @@ from PIL import Image
 import modules.extensions as extensions_module
 import modules.shared as shared
 from modules.extensions import apply_extensions
-from modules.html_generator import fix_newlines, generate_chat_html
+from modules.html_generator import (fix_newlines, chat_html_wrapper,
+                                    make_thumbnail)
 from modules.text_generation import (encode, generate_reply,
                                      get_max_prompt_length)
 
 
-def generate_chat_output(history, name1, name2, character):
-    if shared.args.cai_chat:
-        return generate_chat_html(history, name1, name2, character)
-    else:
-        return history
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
+    is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
+    end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
+    impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
 
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
-    user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
 
+    # Finding the maximum prompt size
     if shared.soft_prompt:
        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
     max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
 
+    if is_instruct:
+        prefix1 = f"{name1}\n"
+        prefix2 = f"{name2}\n"
+    else:
+        prefix1 = f"{name1}: "
+        prefix2 = f"{name2}: "
+
     i = len(shared.history['internal'])-1
     while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
-        rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
-        prev_user_input = shared.history['internal'][i][0]
-        if prev_user_input not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
-            rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
+        rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{end_of_turn}\n")
+        string = shared.history['internal'][i][0]
+        if string not in ['', '<|BEGIN-VISIBLE-CHAT|>']:
+            rows.insert(1, f"{prefix1}{string.strip()}{end_of_turn}\n")
         i -= 1
 
-    if not impersonate:
+    if impersonate:
+        rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
+        limit = 2
+    else:
+        # Adding the user message
+        user_input = fix_newlines(user_input)
         if len(user_input) > 0:
-            rows.append(f"{name1}: {user_input}\n")
-        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
+            rows.append(f"{prefix1}{user_input}{end_of_turn}\n")
+
+        # Adding the Character prefix
+        rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
         limit = 3
-    else:
-        rows.append(f"{name1}:")
-        limit = 2
 
     while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
         rows.pop(1)
-
     prompt = ''.join(rows)
 
     if also_return_rows:
@@ -81,13 +91,20 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
                     if reply[-j:] == string[:j]:
                         reply = reply[:-j]
                         break
+                else:
+                    continue
+                break
 
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
-    just_started = True
-    eos_token = '\n' if stop_at_newline else None
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
+    if mode == 'instruct':
+        stopping_strings = [f"\n{name1}", f"\n{name2}"]
+    else:
+        stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+        
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
     name1_original = name1
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
@@ -104,14 +121,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     if visible_text is None:
         visible_text = text
-    if shared.args.chat:
-        visible_text = visible_text.replace('\n', '<br>')
     text = apply_extensions(text, "input")
 
+    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
     if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+        prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
-        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
+        prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
 
     # Yield *Is typing...*
     if not regenerate:
@@ -119,17 +135,16 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     # Generate
     cumulative_reply = ''
-    for i in range(chat_generation_attempts):
+    just_started = True
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
             reply = cumulative_reply + reply
 
             # Extracting the reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
             visible_reply = apply_extensions(visible_reply, "output")
-            if shared.args.chat:
-                visible_reply = visible_reply.replace('\n', '<br>')
 
             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
@@ -152,23 +167,27 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
-    eos_token = '\n' if stop_at_newline else None
-
+def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    if mode == 'instruct':
+        stopping_strings = [f"\n{name1}", f"\n{name2}"]
+    else:
+        stopping_strings = [f"\n{name1}:", f"\n{name2}:"]
+        
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
-    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
+    prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
 
     # Yield *Is typing...*
     yield shared.processing_message
 
     cumulative_reply = ''
-    for i in range(chat_generation_attempts):
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=stopping_strings):
             reply = cumulative_reply + reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             yield reply
             if next_character_found:
                 break
@@ -178,36 +197,30 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
     yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
-    for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
-        yield generate_chat_html(history, name1, name2, shared.character)
+def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+        yield chat_html_wrapper(history, name1, name2, mode)
 
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
+def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
-        yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+        yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     else:
         last_visible = shared.history['visible'].pop()
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
-        yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
-        for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
-            if shared.args.cai_chat:
-                shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
-            else:
-                shared.history['visible'][-1] = (last_visible[0], history[-1][1])
-            yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+        yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
+        for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
+            shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
+            yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def remove_last_message(name1, name2):
+def remove_last_message(name1, name2, mode):
     if len(shared.history['visible']) > 0 and shared.history['internal'][-1][0] != '<|BEGIN-VISIBLE-CHAT|>':
         last = shared.history['visible'].pop()
         shared.history['internal'].pop()
     else:
         last = ['', '']
 
-    if shared.args.cai_chat:
-        return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
-    else:
-        return shared.history['visible'], last[0]
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode), last[0]
 
 def send_last_reply_to_input():
     if len(shared.history['internal']) > 0:
@@ -215,20 +228,17 @@ def send_last_reply_to_input():
     else:
         return ''
 
-def replace_last_reply(text, name1, name2):
+def replace_last_reply(text, name1, name2, mode):
     if len(shared.history['visible']) > 0:
-        if shared.args.cai_chat:
-            shared.history['visible'][-1][1] = text
-        else:
-            shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
+        shared.history['visible'][-1][1] = text
         shared.history['internal'][-1][1] = apply_extensions(text, "input")
 
-    return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
 def clear_html():
-    return generate_chat_html([], "", "", shared.character)
+    return chat_html_wrapper([], "", "")
 
-def clear_chat_log(name1, name2, greeting):
+def clear_chat_log(name1, name2, greeting, mode):
     shared.history['visible'] = []
     shared.history['internal'] = []
 
@@ -236,12 +246,12 @@ def clear_chat_log(name1, name2, greeting):
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
-    return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def redraw_html(name1, name2):
-    return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
+def redraw_html(name1, name2, mode):
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 
-def tokenize_dialogue(dialogue, name1, name2):
+def tokenize_dialogue(dialogue, name1, name2, mode):
     history = []
 
     dialogue = re.sub('<START>', '', dialogue)
@@ -326,15 +336,35 @@ def build_pygmalion_style_context(data):
     context = f"{context.strip()}\n<START>\n"
     return context
 
-def load_character(character, name1, name2):
+def generate_pfp_cache(character):
+    cache_folder = Path("cache")
+    if not cache_folder.exists():
+        cache_folder.mkdir()
+
+    for path in [Path(f"characters/{character}.{extension}") for extension in ['png', 'jpg', 'jpeg']]:
+        if path.exists():
+            img = make_thumbnail(Image.open(path))
+            img.save(Path('cache/pfp_character.png'), format='PNG')
+            return img
+    return None
+
+def load_character(character, name1, name2, mode):
     shared.character = character
     shared.history['internal'] = []
     shared.history['visible'] = []
-    greeting = ""
+    context = greeting = end_of_turn = ""
+    greeting_field = 'greeting'
+    picture = None
+
+    # Deleting the profile picture cache, if any
+    if Path("cache/pfp_character.png").exists():
+        Path("cache/pfp_character.png").unlink()
 
     if character != 'None':
+        folder = 'characters' if not mode == 'instruct' else 'characters/instruction-following'
+        picture = generate_pfp_cache(character)
         for extension in ["yml", "yaml", "json"]:
-            filepath = Path(f'characters/{character}.{extension}')
+            filepath = Path(f'{folder}/{character}.{extension}')
             if filepath.exists():
                 break
         file_contents = open(filepath, 'r', encoding='utf-8').read()
@@ -350,19 +380,21 @@ def load_character(character, name1, name2):
 
         if 'context' in data:
             context = f"{data['context'].strip()}\n\n"
-            greeting_field = 'greeting'
-        else:
+        elif "char_persona" in data:
             context = build_pygmalion_style_context(data)
             greeting_field = 'char_greeting'
 
-        if 'example_dialogue' in data and data['example_dialogue'] != '':
+        if 'example_dialogue' in data:
             context += f"{data['example_dialogue'].strip()}\n"
-        if greeting_field in data and len(data[greeting_field].strip()) > 0:
+        if greeting_field in data:
             greeting = data[greeting_field]  
+        if 'end_of_turn' in data:
+            end_of_turn = data['end_of_turn']  
     else:
         context = shared.settings['context']
         name2 = shared.settings['name2']
         greeting = shared.settings['greeting'] 
+        end_of_turn = shared.settings['end_of_turn']
 
     if Path(f'logs/{shared.character}_persistent.json').exists():
         load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
@@ -370,13 +402,10 @@ def load_character(character, name1, name2):
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
         shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
 
-    if shared.args.cai_chat:
-        return name1, name2, greeting, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
-    else:
-        return name1, name2, greeting, context, shared.history['visible']
+    return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
 
 def load_default_history(name1, name2):
-    load_character("None", name1, name2)
+    load_character("None", name1, name2, "chat")
 
 def upload_character(json_file, img, tavern=False):
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
@@ -404,7 +433,17 @@ def upload_tavern_character(img, name1, name2):
     _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
     return upload_character(json.dumps(_json), img, tavern=True)
 
-def upload_your_profile_picture(img):
-    img = Image.open(io.BytesIO(img))
-    img.save(Path('img_me.png'))
-    print('Profile picture saved to "img_me.png"')
+def upload_your_profile_picture(img, name1, name2, mode):
+    cache_folder = Path("cache")
+    if not cache_folder.exists():
+        cache_folder.mkdir()
+
+    if img == None:
+        if Path("cache/pfp_me.png").exists():
+            Path("cache/pfp_me.png").unlink()
+    else:
+        img = make_thumbnail(img)
+        img.save(Path('cache/pfp_me.png'))
+        print('Profile picture saved to "cache/pfp_me.png"')
+
+    return chat_html_wrapper(shared.history['visible'], name1, name2, mode, reset_cache=True)
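
The chat wrappers above now take a single `generate_state` dictionary instead of a long positional argument list. A rough sketch of its contents, inferred from the keys read in `modules/chat.py`, `modules/api.py`, and `modules/text_generation.py`; the values shown are illustrative, not the UI's defaults:

```
generate_state = {
    'max_new_tokens': 200,
    'seed': -1,
    'do_sample': True,
    'temperature': 0.7,
    'top_p': 1,
    'top_k': 50,
    'typical_p': 1,
    'repetition_penalty': 1.1,
    'encoder_repetition_penalty': 1,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0,
    'length_penalty': 1,
    'early_stopping': False,
    # chat-mode specific entries
    'stop_at_newline': False,
    'chat_prompt_size': 2048,
    'chat_generation_attempts': 1,
}
```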

+ 61 - 12
modules/html_generator.py

@@ -6,10 +6,11 @@ This is a library for formatting text outputs as nice HTML.
 
 import os
 import re
+import time
 from pathlib import Path
 
 import markdown
-from PIL import Image
+from PIL import Image, ImageOps
 
 # This is to store the paths to the thumbnails of the profile pictures
 image_cache = {}
@@ -20,6 +21,8 @@ with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r')
     _4chan_css = css_f.read()
 with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
     cai_css = f.read()
+with open(Path(__file__).resolve().parent / '../css/html_instruct_style.css', 'r') as f:
+    instruct_css = f.read()
 
 def fix_newlines(string):
     string = string.replace('\n', '\n\n')
@@ -95,6 +98,13 @@ def generate_4chan_html(f):
 
     return output
 
+def make_thumbnail(image):
+    image = image.resize((350, round(image.size[1]/image.size[0]*350)), Image.Resampling.LANCZOS)
+    if image.size[1] > 470:
+        image = ImageOps.fit(image, (350, 470), Image.ANTIALIAS)
+
+    return image
+
 def get_image_cache(path):
     cache_folder = Path("cache")
     if not cache_folder.exists():
@@ -102,26 +112,52 @@ def get_image_cache(path):
 
     mtime = os.stat(path).st_mtime
     if (path in image_cache and mtime != image_cache[path][0]) or (path not in image_cache):
-        img = Image.open(path)
-        img.thumbnail((200, 200))
+        img = make_thumbnail(Image.open(path))
         output_file = Path(f'cache/{path.name}_cache.png')
         img.convert('RGB').save(output_file, format='PNG')
         image_cache[path] = [mtime, output_file.as_posix()]
 
     return image_cache[path][1]
 
-def load_html_image(paths):
-    for str_path in paths:
-          path = Path(str_path)
-          if path.exists():
-              return f'<img src="file/{get_image_cache(path)}">'
-    return ''
+def generate_instruct_html(history):
+    output = f'<style>{instruct_css}</style><div class="chat" id="chat">'
+    for i,_row in enumerate(history[::-1]):
+        row = [convert_to_markdown(entry) for entry in _row]
+
+        output += f"""
+              <div class="assistant-message">
+                <div class="text">
+                  <div class="message-body">
+                    {row[1]}
+                  </div>
+                </div>
+              </div>
+            """
+
+        if len(row[0]) == 0: # don't display empty user messages
+            continue
+
+        output += f"""
+              <div class="user-message">
+                <div class="text">
+                  <div class="message-body">
+                    {row[0]}
+                  </div>
+                </div>
+              </div>
+            """
+
+    output += "</div>"
+
+    return output
 
-def generate_chat_html(history, name1, name2, character):
+def generate_cai_chat_html(history, name1, name2, reset_cache=False):
     output = f'<style>{cai_css}</style><div class="chat" id="chat">'
 
-    img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
-    img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
+    # The time.time() is to prevent the browser from caching the image

+    suffix = f"?{time.time()}" if reset_cache else f"?{name2}"
+    img_bot = f'<img src="file/cache/pfp_character.png{suffix}">' if Path("cache/pfp_character.png").exists() else ''
+    img_me = f'<img src="file/cache/pfp_me.png{suffix}">' if Path("cache/pfp_me.png").exists() else ''
 
     for i,_row in enumerate(history[::-1]):
         row = [convert_to_markdown(entry) for entry in _row]
@@ -163,3 +199,16 @@ def generate_chat_html(history, name1, name2, character):
 
     output += "</div>"
     return output
+
+def generate_chat_html(history, name1, name2):
+    return generate_cai_chat_html(history, name1, name2)
+
+def chat_html_wrapper(history, name1, name2, mode, reset_cache=False):
+    if mode == "cai-chat":
+        return generate_cai_chat_html(history, name1, name2, reset_cache)
+    elif mode == "chat":
+        return generate_chat_html(history, name1, name2)
+    elif mode == "instruct":
+        return generate_instruct_html(history)
+    else:
+        return ''

+ 65 - 0
modules/llamacpp_model_alternative.py

@@ -0,0 +1,65 @@
+'''
+Based on 
+https://github.com/abetlen/llama-cpp-python
+
+Documentation:
+https://abetlen.github.io/llama-cpp-python/
+'''
+
+import multiprocessing
+
+from llama_cpp import Llama
+
+from modules import shared
+from modules.callbacks import Iteratorize
+
+
+class LlamaCppModel:
+    def __init__(self):
+        self.initialized = False
+
+    @classmethod
+    def from_pretrained(self, path):
+        result = self()
+
+        params = {
+            'model_path': str(path),
+            'n_ctx': 2048,
+            'seed': 0,
+            'n_threads': shared.args.threads or None
+        }
+        self.model = Llama(**params)
+
+        # This is ugly, but the model and the tokenizer are the same object in this library.
+        return result, result 
+
+    def encode(self, string):
+        if type(string) is str:
+            string = string.encode()
+        return self.model.tokenize(string)
+
+    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
+        if type(context) is str:
+            context = context.encode()
+        tokens = self.model.tokenize(context)
+
+        output = b""
+        count = 0
+        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
+            text = self.model.detokenize([token])
+            output += text
+            if callback:
+                callback(text.decode())
+
+            count += 1
+            if count >= token_count or (token == self.model.token_eos()):
+                break
+
+        return output.decode()
+
+    def generate_with_streaming(self, **kwargs):
+        with Iteratorize(self.generate, kwargs, callback=None) as generator:
+            reply = ''
+            for token in generator:
+                reply += token
+                yield reply
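
A hedged usage sketch for the new wrapper (the ggml file path below is a placeholder, not a file shipped with the repo):

```
# Hypothetical usage of LlamaCppModel defined above.
from pathlib import Path

from modules.llamacpp_model_alternative import LlamaCppModel

model, tokenizer = LlamaCppModel.from_pretrained(Path('models/my-model/ggml-model-q4_0.bin'))
print(model.generate(context="Hello, my name is", token_count=32, temperature=0.7))
```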

+ 6 - 4
modules/models.py

@@ -10,7 +10,7 @@ import torch
 import transformers
 from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
-                          BitsAndBytesConfig)
+                          BitsAndBytesConfig, LlamaTokenizer)
 
 import modules.shared as shared
 
@@ -42,7 +42,7 @@ def load_model(model_name):
     t0 = time.time()
 
     shared.is_RWKV = 'rwkv-' in model_name.lower()
-    shared.is_llamacpp = len(list(Path(f'models/{model_name}').glob('ggml*.bin'))) > 0
+    shared.is_llamacpp = len(list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))) > 0
 
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV, shared.is_llamacpp]):
@@ -103,9 +103,9 @@ def load_model(model_name):
 
     # llamacpp model
     elif shared.is_llamacpp:
-        from modules.llamacpp_model import LlamaCppModel
+        from modules.llamacpp_model_alternative import LlamaCppModel
 
-        model_file = list(Path(f'models/{model_name}').glob('ggml*.bin'))[0]
+        model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('ggml*.bin'))[0]
         print(f"llama.cpp weights detected: {model_file}\n")
 
         model, tokenizer = LlamaCppModel.from_pretrained(model_file)
@@ -172,6 +172,8 @@ def load_model(model_name):
     # Loading the tokenizer
     if any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
+    elif type(model) is transformers.LlamaForCausalLM:
+        tokenizer = LlamaTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"), clean_up_tokenization_spaces=True)
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'

+ 11 - 4
modules/shared.py

@@ -33,6 +33,7 @@ settings = {
     'name2': 'Assistant',
     'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
     'greeting': 'Hello there!',
+    'end_of_turn': '',
     'stop_at_newline': False,
     'chat_prompt_size': 2048,
     'chat_prompt_size_min': 0,
@@ -44,6 +45,7 @@ settings = {
     'chat_default_extensions': ["gallery"],
     'presets': {
         'default': 'NovelAI-Sphinx Moth',
+        '.*(alpaca|llama)': "LLaMA-Precise",
         '.*pygmalion': 'NovelAI-Storywriter',
         '.*RWKV': 'Naive',
     },
@@ -73,8 +75,8 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
 
 # Basic settings
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
-parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
-parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
+parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
+parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -131,12 +133,17 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
 
 args = parser.parse_args()
 
-# Provisional, this will be deleted later
+# Deprecation warnings for parameters that have been renamed
 deprecated_dict = {'gptq_bits': ['wbits', 0], 'gptq_model_type': ['model_type', None], 'gptq_pre_layer': ['prelayer', 0]}
 for k in deprecated_dict:
     if eval(f"args.{k}") != deprecated_dict[k][1]:
         print(f"Warning: --{k} is deprecated and will be removed. Use --{deprecated_dict[k][0]} instead.")
         exec(f"args.{deprecated_dict[k][0]} = args.{k}")
 
+# Deprecation warnings for parameters that have been removed
+if args.cai_chat:
+    print("Warning: --cai-chat is deprecated. Use --chat instead.")
+    args.chat = True
+
 def is_chat():
-    return any((args.chat, args.cai_chat))
+    return args.chat

+ 27 - 31
modules/text_generation.py

@@ -28,6 +28,10 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
         return input_ids
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
+
+        if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
+            input_ids = input_ids[:,1:]
+
         if shared.args.cpu:
             return input_ids
         elif shared.args.flexgen:
@@ -102,10 +106,11 @@ def set_manual_seed(seed):
 def stop_everything_event():
     shared.stop_everything = True
 
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
+def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
-    set_manual_seed(seed)
+    set_manual_seed(generate_state['seed'])
     shared.stop_everything = False
+    generate_params = {}
     t0 = time.time()
 
     original_question = question
@@ -117,9 +122,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if any((shared.is_RWKV, shared.is_llamacpp)):
+        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+            generate_params[k] = generate_state[k]
+        generate_params["token_count"] = generate_state["max_new_tokens"]
         try:
             if shared.args.no_stream:
-                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+                reply = shared.model.generate(context=question, **generate_params)
                 output = original_question+reply
                 if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
@@ -130,7 +138,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
-                for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
+                for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                     output = original_question+reply
                     if not shared.is_chat():
                         reply = original_question + apply_extensions(reply, "output")
@@ -145,7 +153,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
             return
 
-    input_ids = encode(question, max_new_tokens)
+    input_ids = encode(question, generate_state['max_new_tokens'])
     original_input_ids = input_ids
     output = input_ids[0]
 
@@ -158,33 +166,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
-    generate_params = {}
+    generate_params["max_new_tokens"] = generate_state['max_new_tokens']
     if not shared.args.flexgen:
-        generate_params.update({
-            "max_new_tokens": max_new_tokens,
-            "eos_token_id": eos_token_ids,
-            "stopping_criteria": stopping_criteria_list,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "top_p": top_p,
-            "typical_p": typical_p,
-            "repetition_penalty": repetition_penalty,
-            "encoder_repetition_penalty": encoder_repetition_penalty,
-            "top_k": top_k,
-            "min_length": min_length if shared.args.no_stream else 0,
-            "no_repeat_ngram_size": no_repeat_ngram_size,
-            "num_beams": num_beams,
-            "penalty_alpha": penalty_alpha,
-            "length_penalty": length_penalty,
-            "early_stopping": early_stopping,
-        })
+        for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
+            generate_params[k] = generate_state[k]
+        generate_params["eos_token_id"] = eos_token_ids
+        generate_params["stopping_criteria"] = stopping_criteria_list
+        if shared.args.no_stream:
+            generate_params["min_length"] = 0
     else:
-        generate_params.update({
-            "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "stop": eos_token_ids[-1],
-        })
+        for k in ["do_sample", "temperature"]:
+            generate_params[k] = generate_state[k]
+        generate_params["stop"] = eos_token_ids[-1]
+        if not shared.args.no_stream:
+            generate_params["max_new_tokens"] = 8
+
     if shared.args.no_cache:
         generate_params.update({"use_cache": False})
     if shared.args.deepspeed:
@@ -244,7 +240,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
-            for i in range(max_new_tokens//8+1):
+            for i in range(generate_state['max_new_tokens']//8+1):
                 clear_torch_cache()
                 with torch.no_grad():
                     output = shared.model.generate(**generate_params)[0]
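
To see the refactor above in isolation: instead of a long positional signature, generate_reply now receives one generate_state dict, and each backend copies out only the keys it understands. The sketch below is a simplified stand-in, not the exact key sets used for every backend.

```python
# Simplified sketch of the generate_state -> generate_params pattern above.
# The state dict is filled from the UI; each backend picks the subset it supports.
generate_state = {
    'max_new_tokens': 200, 'do_sample': True, 'temperature': 0.7,
    'top_p': 0.1, 'top_k': 40, 'repetition_penalty': 1.18, 'seed': -1,
}

def build_generate_params(state, streaming_backend=False):
    params = {}
    if streaming_backend:  # e.g. RWKV / llama.cpp style backends
        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
            params[k] = state[k]
        params['token_count'] = state['max_new_tokens']
    else:  # Hugging Face transformers path (subset of keys shown)
        for k in ['do_sample', 'temperature', 'top_p', 'top_k', 'repetition_penalty']:
            params[k] = state[k]
        params['max_new_tokens'] = state['max_new_tokens']
    return params

print(build_generate_params(generate_state, streaming_backend=True))
```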

+ 78 - 24
modules/training.py

@@ -20,7 +20,7 @@ MAX_STEPS = 0
 CURRENT_GRADIENT_ACCUM = 1
 
 def get_dataset(path: str, ext: str):
-    return ['None'] + sorted(set((k.stem for k in Path(path).glob(f'*.{ext}'))), key=str.lower)
+    return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
 
 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
@@ -45,22 +45,26 @@ def create_train_interface():
             with gr.Row():
                 dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
                 ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
-                eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The dataset file used to evaluate the model after training.')
+                eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
                 ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
                 format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
                 ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
+
         with gr.Tab(label="Raw Text File"):
             with gr.Row():
                 raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
                 ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
-                overlap_len = gr.Slider(label='Overlap Length', minimum=0,maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length above). Setting overlap to exactly half the cutoff length may be ideal.')
+            with gr.Row():
+                overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - i.e. how many tokens from the prior chunk of text to include in the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
+                newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
 
         with gr.Row():
             start_button = gr.Button("Start LoRA Training")
             stop_button = gr.Button("Interrupt")
 
         output = gr.Markdown(value="Ready")
-        start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len], [output])
+        start_button.click(do_train, [lora_name, micro_batch_size, batch_size, epochs, learning_rate, lora_rank, lora_alpha, lora_dropout,
+                                      cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
         stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
 
 def do_interrupt():
@@ -91,8 +95,8 @@ def clean_path(base_path: str, path: str):
         return path
     return f'{Path(base_path).absolute()}/{path}'
 
-def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int,
-             lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int):
+def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
+             cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
     global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
     WANT_INTERRUPT = False
     CURRENT_STEPS = 0
@@ -103,6 +107,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
     lora_name = f"{shared.args.lora_dir}/{clean_path(None, lora_name)}"
     actual_lr = float(learning_rate)
 
+    model_type = type(shared.model).__name__
+    if model_type != "LlamaForCausalLM":
+        if model_type == "PeftModelForCausalLM":
+            yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+            print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+        else:
+            yield "LoRA training has only currently been validated for LLaMA models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+            print(f"Warning: LoRA training has only currently been validated for LLaMA models. (Found model type: {model_type})")
+        time.sleep(5)
+
+    if shared.args.wbits > 0 or shared.args.gptq_bits > 0:
+        yield "LoRA training does not yet support 4bit. Please use `--load-in-8bit` for now."
+        return
+
+    elif not shared.args.load_in_8bit:
+        yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
+        print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
+        time.sleep(2) # Give it a moment for the message to show in UI before continuing
+
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
         yield "Cannot input zeroes."
         return
@@ -126,15 +149,20 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
             raw_text = file.read()
         tokens = shared.tokenizer.encode(raw_text)
         del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
+
         tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
         for i in range(1, len(tokens)):
             tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
         text_chunks = [shared.tokenizer.decode(x) for x in tokens]
         del tokens
-        data = Dataset.from_list([tokenize(x) for x in text_chunks])
-        train_data = data.shuffle()
-        eval_data = None
+
+        if newline_favor_len > 0:
+            text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
+
+        train_data = Dataset.from_list([tokenize(x) for x in text_chunks])
         del text_chunks
+        train_data = train_data.shuffle()
+        eval_data = None
 
     else:
         if dataset in ['None', '']:
@@ -232,33 +260,37 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
     # TODO: save/load checkpoints to resume from?
     print("Starting training...")
     yield "Starting..."
+    if WANT_INTERRUPT:
+        yield "Interrupted before start."
+        return
 
-    def threadedRun():
+    def threaded_run():
         trainer.train()
 
-    thread = threading.Thread(target=threadedRun)
+    thread = threading.Thread(target=threaded_run)
     thread.start()
-    lastStep = 0
-    startTime = time.perf_counter()
+    last_step = 0
+    start_time = time.perf_counter()
 
     while thread.is_alive():
         time.sleep(0.5)
         if WANT_INTERRUPT:
             yield "Interrupting, please wait... *(Run will stop after the current training step completes.)*"
-        elif CURRENT_STEPS != lastStep:
-            lastStep = CURRENT_STEPS
-            timeElapsed = time.perf_counter() - startTime
-            if timeElapsed <= 0:
-                timerInfo = ""
-                totalTimeEstimate = 999
+
+        elif CURRENT_STEPS != last_step:
+            last_step = CURRENT_STEPS
+            time_elapsed = time.perf_counter() - start_time
+            if time_elapsed <= 0:
+                timer_info = ""
+                total_time_estimate = 999
             else:
-                its = CURRENT_STEPS / timeElapsed
+                its = CURRENT_STEPS / time_elapsed
                 if its > 1:
-                    timerInfo = f"`{its:.2f}` it/s"
+                    timer_info = f"`{its:.2f}` it/s"
                 else:
-                    timerInfo = f"`{1.0/its:.2f}` s/it"
-                totalTimeEstimate = (1.0/its) * (MAX_STEPS)
-            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
+                    timer_info = f"`{1.0/its:.2f}` s/it"
+                total_time_estimate = (1.0/its) * (MAX_STEPS)
+            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
 
     print("Training complete, saving...")
     lora_model.save_pretrained(lora_name)
@@ -273,3 +305,25 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
 def split_chunks(arr, step):
     for i in range(0, len(arr), step):
         yield arr[i:i + step]
+
+def cut_chunk_for_newline(chunk: str, max_length: int):
+    if '\n' not in chunk:
+        return chunk
+    first_newline = chunk.index('\n')
+    if first_newline < max_length:
+        chunk = chunk[first_newline + 1:]
+    if '\n' not in chunk:
+        return chunk
+    last_newline = chunk.rindex('\n')
+    if len(chunk) - last_newline < max_length:
+        chunk = chunk[:last_newline]
+    return chunk
+
+def format_time(seconds: float):
+    if seconds < 120:
+        return f"`{seconds:.0f}` seconds"
+    minutes = seconds / 60
+    if minutes < 120:
+        return f"`{minutes:.0f}` minutes"
+    hours = minutes / 60
+    return f"`{hours:.0f}` hours"
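
A tokenizer-free toy version of the raw-text path above may help: the text is split into chunks of cutoff_len - overlap_len tokens, each chunk is prefixed with the tail of the previous one, and (when newline_favor_len > 0) cuts are then nudged toward newlines. Characters stand in for tokens here.

```python
# Toy illustration of the chunk/overlap logic above, with characters as "tokens".
def split_chunks(arr, step):
    for i in range(0, len(arr), step):
        yield arr[i:i + step]

cutoff_len, overlap_len = 16, 4
raw_text = "line one\nline two\nline three\nline four\n"
tokens = list(raw_text)  # stand-in for shared.tokenizer.encode(raw_text)

chunks = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(chunks)):
    # prepend the last overlap_len "tokens" of the previous chunk
    chunks[i] = chunks[i - 1][-overlap_len:] + chunks[i]

text_chunks = [''.join(c) for c in chunks]  # stand-in for tokenizer.decode
print(text_chunks)
```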

+ 6 - 0
presets/LLaMA-Precise.txt

@@ -0,0 +1,6 @@
+do_sample=True
+top_p=0.1
+top_k=40
+temperature=0.7
+repetition_penalty=1.18
+typical_p=1.0
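
Preset files like the one above are plain key=value lines. The web UI reads them into a dict of generation parameters (see the load_preset_values changes in server.py further down); the sketch below shows an equivalent parse, substituting ast.literal_eval for the eval call used in the original.

```python
import ast

# Sketch: turn key=value preset lines into a generation-parameter dict,
# skipping the legacy 'tokens' key and clamping temperature as the UI does.
preset_text = """do_sample=True
top_p=0.1
top_k=40
temperature=0.7
repetition_penalty=1.18
typical_p=1.0"""

generate_params = {}
for line in preset_text.splitlines():
    key, sep, value = line.rstrip(',').strip().partition('=')
    if sep and key.strip() != 'tokens':
        generate_params[key.strip()] = ast.literal_eval(value.strip())

generate_params['temperature'] = min(1.99, generate_params['temperature'])
print(generate_params)
```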

+ 1 - 2
requirements.txt

@@ -3,12 +3,11 @@ bitsandbytes==0.37.2
 datasets
 flexgen==0.1.7
 gradio==3.24.1
-llamacpp==0.1.11
 markdown
 numpy
 peft==0.2.0
 requests
-rwkv==0.7.2
+rwkv==0.7.3
 safetensors==0.3.0
 sentencepiece
 pyyaml

+ 97 - 61
server.py

@@ -1,3 +1,7 @@
+import os
+
+os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
+
 import io
 import json
 import re
@@ -8,10 +12,11 @@ from datetime import datetime
 from pathlib import Path
 
 import gradio as gr
+from PIL import Image
 
 import modules.extensions as extensions_module
-from modules import chat, shared, training, ui
-from modules.html_generator import generate_chat_html
+from modules import chat, shared, training, ui, api
+from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
 from modules.text_generation import (clear_torch_cache, generate_reply,
@@ -47,6 +52,13 @@ def get_available_prompts():
 
 def get_available_characters():
     paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
+    return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
+
+def get_available_instruction_templates():
+    path = "characters/instruction-following"
+    paths = []
+    if os.path.exists(path):
+        paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
     return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
 
 def get_available_extensions():
@@ -76,7 +88,7 @@ def load_lora_wrapper(selected_lora):
     add_lora_to_model(selected_lora)
     return selected_lora
 
-def load_preset_values(preset_menu, return_dict=False):
+def load_preset_values(preset_menu, state, return_dict=False):
     generate_params = {
         'do_sample': True,
         'temperature': 1,
@@ -98,13 +110,13 @@ def load_preset_values(preset_menu, return_dict=False):
         i = i.rstrip(',').strip().split('=')
         if len(i) == 2 and i[0].strip() != 'tokens':
             generate_params[i[0].strip()] = eval(i[1].strip())
-
     generate_params['temperature'] = min(1.99, generate_params['temperature'])
 
     if return_dict:
         return generate_params
     else:
-        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        state.update(generate_params)
+        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
 
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -118,19 +130,8 @@ def upload_soft_prompt(file):
 
     return name
 
-def create_model_and_preset_menus():
-    with gr.Row():
-        with gr.Column():
-            with gr.Row():
-                shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
-                ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
-        with gr.Column():
-            with gr.Row():
-                shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
-                ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
-
 def save_prompt(text):
-    fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
+    fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
     with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
         f.write(text)
     return f"Saved to prompts/{fname}"
@@ -144,7 +145,7 @@ def load_prompt(fname):
             if text[-1] == '\n':
                 text = text[:-1]
             return text
-        
+
 def create_prompt_menus():
     with gr.Row():
         with gr.Column():
@@ -160,12 +161,31 @@ def create_prompt_menus():
     shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
+def create_model_menus():
+    with gr.Row():
+        with gr.Column():
+            with gr.Row():
+                shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
+                ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
+        with gr.Column():
+            with gr.Row():
+                shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
+                ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
+
+    shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
+    shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+
 def create_settings_menus(default_preset):
-    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+    for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
+        generate_params[k] = shared.settings[k]
+    shared.gradio['generate_state'] = gr.State(generate_params)
 
     with gr.Row():
         with gr.Column():
-            create_model_and_preset_menus()
+            with gr.Row():
+                shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
+                ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
         with gr.Column():
             shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
 
@@ -199,9 +219,6 @@ def create_settings_menus(default_preset):
                         shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                 shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
 
-    with gr.Row():
-        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
-        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
 
     with gr.Accordion('Soft prompt', open=False):
         with gr.Row():
@@ -212,17 +229,14 @@ def create_settings_menus(default_preset):
         with gr.Row():
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
-    shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
-    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
-    shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
-    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
 def set_interface_arguments(interface_mode, extensions, bool_active):
     modes = ["default", "notebook", "chat", "cai_chat"]
     cmd_list = vars(shared.args)
     bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
-    #int_list = [k for k in cmd_list if type(k) is int]
 
     shared.args.extensions = extensions
     for k in modes[1:]:
@@ -295,10 +309,7 @@ def create_interface():
         if shared.is_chat():
             shared.gradio['Chat input'] = gr.State()
             with gr.Tab("Text generation", elem_id="main"):
-                if shared.args.cai_chat:
-                    shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], shared.character))
-                else:
-                    shared.gradio['display'] = gr.Chatbot(value=shared.history['visible'], elem_id="gradio-chatbot")
+                shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
                 shared.gradio['textbox'] = gr.Textbox(label='Input')
                 with gr.Row():
                     shared.gradio['Generate'] = gr.Button('Generate')
@@ -315,11 +326,20 @@ def create_interface():
                     shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
                     shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
 
+                shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
+                shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
+
             with gr.Tab("Character", elem_id="chat-settings"):
-                shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
-                shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
-                shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=2, label='Greeting')
-                shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=8, label='Context')
+                with gr.Row():
+                    with gr.Column(scale=8):
+                        shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
+                        shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
+                        shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
+                        shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
+                        shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
+                    with gr.Column(scale=1):
+                        shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
+                        shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
                 with gr.Row():
                     shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
                     ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
@@ -347,8 +367,6 @@ def create_interface():
 
                         gr.Markdown("# TavernAI PNG format")
                         shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
-                    with gr.Tab('Upload your profile picture'):
-                        shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
 
             with gr.Tab("Parameters", elem_id="parameters"):
                 with gr.Box():
@@ -359,35 +377,35 @@ def create_interface():
                             shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
                         with gr.Column():
                             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-                            shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
+                            shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
 
                 create_settings_menus(default_preset)
 
-            function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
-            shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
+            shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
 
             def set_chat_input(textbox):
                 return textbox, ""
 
             gen_events.append(shared.gradio['Generate'].click(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
-            gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['Generate'].click(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(set_chat_input, shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False))
-            gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
+            gen_events.append(shared.gradio['textbox'].submit(chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
 
             shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
-            shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
+            shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream)
 
             # Clear history with confirmation
             clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
             shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
             shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
-            shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2'], shared.gradio['greeting']], shared.gradio['display'])
+            shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display'])
             shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
+            shared.gradio['Chat mode'].change(lambda x : gr.update(visible= x=='instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates'])
 
-            shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
+            shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
             shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
             shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
 
@@ -399,20 +417,21 @@ def create_interface():
             shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
             shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
 
-            shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2']], [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'context', 'display']])
-            shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
+            shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
+            shared.gradio['Instruction templates'].change(lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
+            shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], [])
             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
-            shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
+            shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
 
-            reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
-            reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
-            shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
-            shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
-            shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
+            reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
+            shared.gradio['upload_chat_history'].upload(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+            shared.gradio['Stop'].click(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+            shared.gradio['Instruction templates'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
+            shared.gradio['Chat mode'].change(chat.redraw_html, reload_inputs, [shared.gradio['display']])
 
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
             shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings['name1'], shared.settings['name2']), None, None)
-            shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
+            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, [shared.gradio['display']], show_progress=True)
 
         elif shared.args.notebook:
             with gr.Tab("Text generation", elem_id="main"):
@@ -442,9 +461,9 @@ def create_interface():
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)
 
-            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
-            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -475,14 +494,17 @@ def create_interface():
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)
 
-            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
             output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
-            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
 
+        with gr.Tab("Model", elem_id="model-tab"):
+            create_model_menus()
+
         with gr.Tab("Training", elem_id="training-tab"):
             training.create_train_interface()
 
@@ -496,7 +518,6 @@ def create_interface():
             cmd_list = vars(shared.args)
             bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
             bool_active = [k for k in bool_list if vars(shared.args)[k]]
-            #int_list = [k for k in cmd_list if type(k) is int]
 
             gr.Markdown("*Experimental*")
             shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
@@ -510,6 +531,21 @@ def create_interface():
         if shared.args.extensions is not None:
             extensions_module.create_extensions_block()
 
+        def change_dict_value(d, key, value):
+            d[key] = value
+            return d
+
+        for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
+            if k not in shared.gradio:
+                continue
+            if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
+                shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+            else:
+                shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+
+        if not shared.is_chat():
+            api.create_apis()
+
     # Authentication
     auth = None
     if shared.args.gradio_auth_path is not None:
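
One detail worth calling out in the server.py changes above is the generate_state wiring: a single dict lives in a gr.State, and every slider or checkbox registers a callback that writes its value back into that dict. The `copy=k` default argument freezes the loop variable per callback; without it, Python's late-binding closures would make every handler update the last key. A gradio-free sketch of that pattern:

```python
# Gradio-free sketch of the shared-state update pattern above: one dict,
# one callback per control. `copy=k` binds the key at definition time;
# a lambda that closed over k directly would always update the last key.
state = {'temperature': 1.0, 'top_p': 1.0, 'top_k': 0}

def change_dict_value(d, key, value):
    d[key] = value
    return d

callbacks = {k: (lambda value, copy=k: change_dict_value(state, copy, value)) for k in state}

callbacks['temperature'](0.7)
callbacks['top_k'](40)
print(state)  # {'temperature': 0.7, 'top_p': 1.0, 'top_k': 40}
```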