Bladeren bron

Refactor several function calls and the API

oobabooga 2 jaren geleden
bovenliggende
commit
3f3e42e26c
8 gewijzigde bestanden met toevoegingen van 147 en 118 verwijderingen
  1. 2 16
      api-example-stream.py
  2. 5 16
      api-example.py
  3. 19 16
      extensions/api/script.py
  4. 2 3
      extensions/send_pictures/script.py
  5. 38 0
      modules/api.py
  6. 24 19
      modules/chat.py
  7. 23 31
      modules/text_generation.py
  8. 34 17
      server.py

+ 2 - 16
api-example-stream.py

@@ -36,6 +36,7 @@ async def run(context):
         'early_stopping': False,
         'seed': -1,
     }
+    payload = json.dumps([context, params])
     session = random_hash()
 
     async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
@@ -54,22 +55,7 @@ async def run(context):
                         "session_hash": session,
                         "fn_index": 12,
                         "data": [
-                            context,
-                            params['max_new_tokens'],
-                            params['do_sample'],
-                            params['temperature'],
-                            params['top_p'],
-                            params['typical_p'],
-                            params['repetition_penalty'],
-                            params['encoder_repetition_penalty'],
-                            params['top_k'],
-                            params['min_length'],
-                            params['no_repeat_ngram_size'],
-                            params['num_beams'],
-                            params['penalty_alpha'],
-                            params['length_penalty'],
-                            params['early_stopping'],
-                            params['seed'],
+                            payload
                         ]
                     }))
                 case "process_starts":

+ 5 - 16
api-example.py

@@ -10,6 +10,8 @@ Optionally, you can also add the --share flag to generate a public gradio URL,
 allowing you to use the API remotely.
 
 '''
+import json
+
 import requests
 
 # Server address
@@ -38,24 +40,11 @@ params = {
 # Input prompt
 prompt = "What I would like to say is the following: "
 
+payload = json.dumps([prompt, params])
+
 response = requests.post(f"http://{server}:7860/run/textgen", json={
     "data": [
-        prompt,
-        params['max_new_tokens'],
-        params['do_sample'],
-        params['temperature'],
-        params['top_p'],
-        params['typical_p'],
-        params['repetition_penalty'],
-        params['encoder_repetition_penalty'],
-        params['top_k'],
-        params['min_length'],
-        params['no_repeat_ngram_size'],
-        params['num_beams'],
-        params['penalty_alpha'],
-        params['length_penalty'],
-        params['early_stopping'],
-        params['seed'],
+        payload
     ]
 }).json()
 

+ 19 - 16
extensions/api/script.py

@@ -40,24 +40,27 @@ class Handler(BaseHTTPRequestHandler):
                 prompt_lines.pop(0)
 
             prompt = '\n'.join(prompt_lines)
+            generate_params =  {
+                'max_new_tokens': int(body.get('max_length', 200)), 
+                'do_sample': bool(body.get('do_sample', True)),
+                'temperature': float(body.get('temperature', 0.5)), 
+                'top_p': float(body.get('top_p', 1)), 
+                'typical_p': float(body.get('typical', 1)), 
+                'repetition_penalty': float(body.get('rep_pen', 1.1)), 
+                'encoder_repetition_penalty': 1,
+                'top_k': int(body.get('top_k', 0)), 
+                'min_length': int(body.get('min_length', 0)),
+                'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size',0)),
+                'num_beams': int(body.get('num_beams',1)),
+                'penalty_alpha': float(body.get('penalty_alpha', 0)),
+                'length_penalty': float(body.get('length_penalty', 1)),
+                'early_stopping': bool(body.get('early_stopping', False)),
+                'seed': int(body.get('seed', -1)),
+            }
 
             generator = generate_reply(
-                question = prompt, 
-                max_new_tokens = int(body.get('max_length', 200)), 
-                do_sample=bool(body.get('do_sample', True)),
-                temperature=float(body.get('temperature', 0.5)), 
-                top_p=float(body.get('top_p', 1)), 
-                typical_p=float(body.get('typical', 1)), 
-                repetition_penalty=float(body.get('rep_pen', 1.1)), 
-                encoder_repetition_penalty=1, 
-                top_k=int(body.get('top_k', 0)), 
-                min_length=int(body.get('min_length', 0)),
-                no_repeat_ngram_size=int(body.get('no_repeat_ngram_size',0)),
-                num_beams=int(body.get('num_beams',1)),
-                penalty_alpha=float(body.get('penalty_alpha', 0)),
-                length_penalty=float(body.get('length_penalty', 1)),
-                early_stopping=bool(body.get('early_stopping', False)),
-                seed=int(body.get('seed', -1)),
+                prompt, 
+                generate_params,
                 stopping_strings=body.get('stopping_strings', []),
             )
 

+ 2 - 3
extensions/send_pictures/script.py

@@ -2,12 +2,11 @@ import base64
 from io import BytesIO
 
 import gradio as gr
-import modules.chat as chat
-import modules.shared as shared
 import torch
-from PIL import Image
 from transformers import BlipForConditionalGeneration, BlipProcessor
 
+from modules import chat, shared
+
 # If 'state' is True, will hijack the next chat generation with
 # custom input text given by 'value' in the format [text, visible_text]
 input_hijack = {

+ 38 - 0
modules/api.py

@@ -0,0 +1,38 @@
+import json
+
+import gradio as gr
+
+from modules import shared
+from modules.text_generation import generate_reply
+
+
+def generate_reply_wrapper(string):
+    generate_params = {
+        'do_sample': True,
+        'temperature': 1,
+        'top_p': 1,
+        'typical_p': 1,
+        'repetition_penalty': 1,
+        'encoder_repetition_penalty': 1,
+        'top_k': 50,
+        'num_beams': 1,
+        'penalty_alpha': 0,
+        'min_length': 0,
+        'length_penalty': 1,
+        'no_repeat_ngram_size': 0,
+        'early_stopping': False,
+    }
+    params = json.loads(string)
+    for k in params[1]:
+        generate_params[k] = params[1][k]
+    for i in generate_reply(params[0], generate_params):
+        yield i
+
+def create_apis():
+    t1 = gr.Textbox(visible=False)
+    t2 = gr.Textbox(visible=False)
+    dummy = gr.Button(visible=False)
+
+    input_params = [t1]
+    output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
+    dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')

+ 24 - 19
modules/chat.py

@@ -18,7 +18,12 @@ from modules.text_generation import (encode, generate_reply,
                                      get_max_prompt_length)
 
 
-def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn="", impersonate=False, also_return_rows=False):
+def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, **kwargs):
+    is_instruct = kwargs['is_instruct'] if 'is_instruct' in kwargs else False
+    end_of_turn = kwargs['end_of_turn'] if 'end_of_turn' in kwargs else ''
+    impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
+    also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
+
     user_input = fix_newlines(user_input)
     rows = [f"{context.strip()}\n"]
 
@@ -91,9 +96,9 @@ def extract_message_from_reply(reply, name1, name2, stop_at_newline):
     reply = fix_newlines(reply)
     return reply, next_character_found
 
-def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False, mode="cai-chat", end_of_turn=""):
+def chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
     just_started = True
-    eos_token = '\n' if stop_at_newline else None
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
     name1_original = name1
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
@@ -112,11 +117,11 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
         visible_text = text
     text = apply_extensions(text, "input")
 
-    is_instruct = mode == 'instruct'
+    kwargs = {'end_of_turn': end_of_turn, 'is_instruct': mode == 'instruct'}
     if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
+        prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
     else:
-        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, is_instruct, end_of_turn=end_of_turn)
+        prompt = custom_generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], **kwargs)
 
     # Yield *Is typing...*
     if not regenerate:
@@ -124,13 +129,13 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     # Generate
     cumulative_reply = ''
-    for i in range(chat_generation_attempts):
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
 
             # Extracting the reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
             visible_reply = apply_extensions(visible_reply, "output")
 
@@ -155,23 +160,23 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
 
     yield shared.history['visible']
 
-def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
-    eos_token = '\n' if stop_at_newline else None
+def impersonate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    eos_token = '\n' if generate_state['stop_at_newline'] else None
 
     if 'pygmalion' in shared.model_name.lower():
         name1 = "You"
 
-    prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True, end_of_turn=end_of_turn)
+    prompt = generate_chat_prompt(text, generate_state['max_new_tokens'], name1, name2, context, generate_state['chat_prompt_size'], impersonate=True, end_of_turn=end_of_turn)
 
     # Yield *Is typing...*
     yield shared.processing_message
 
     cumulative_reply = ''
-    for i in range(chat_generation_attempts):
+    for i in range(generate_state['chat_generation_attempts']):
         reply = None
-        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
+        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", generate_state, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
             reply = cumulative_reply + reply
-            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
+            reply, next_character_found = extract_message_from_reply(reply, name1, name2, generate_state['stop_at_newline'])
             yield reply
             if next_character_found:
                 break
@@ -181,11 +186,11 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
 
     yield reply
 
-def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
-    for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=False, mode=mode, end_of_turn=end_of_turn):
+def cai_chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
+    for history in chatbot_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn, regenerate=False):
         yield chat_html_wrapper(history, name1, name2, mode)
 
-def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, mode="cai-chat", end_of_turn=""):
+def regenerate_wrapper(text, generate_state, name1, name2, context, mode, end_of_turn):
     if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
         yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
     else:
@@ -193,7 +198,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
         last_internal = shared.history['internal'].pop()
         # Yield '*Is typing...*'
         yield chat_html_wrapper(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, mode)
-        for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True, mode=mode, end_of_turn=end_of_turn):
+        for history in chatbot_wrapper(last_internal[0], generate_state, name1, name2, context, mode, end_of_turn, regenerate=True):
             shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
             yield chat_html_wrapper(shared.history['visible'], name1, name2, mode)
 

+ 23 - 31
modules/text_generation.py

@@ -102,10 +102,11 @@ def set_manual_seed(seed):
 def stop_everything_event():
     shared.stop_everything = True
 
-def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=None, stopping_strings=[]):
+def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
-    set_manual_seed(seed)
+    set_manual_seed(generate_state['seed'])
     shared.stop_everything = False
+    generate_params = {}
     t0 = time.time()
 
     original_question = question
@@ -117,9 +118,12 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
     # These models are not part of Hugging Face, so we handle them
     # separately and terminate the function call earlier
     if any((shared.is_RWKV, shared.is_llamacpp)):
+        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+            generate_params[k] = generate_state[k]
+        generate_params["token_count"] = generate_state["max_new_tokens"]
         try:
             if shared.args.no_stream:
-                reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+                reply = shared.model.generate(context=question, **generate_params)
                 output = original_question+reply
                 if not shared.is_chat():
                     reply = original_question + apply_extensions(reply, "output")
@@ -130,7 +134,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
                 # RWKV has proper streaming, which is very nice.
                 # No need to generate 8 tokens at a time.
-                for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty):
+                for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                     output = original_question+reply
                     if not shared.is_chat():
                         reply = original_question + apply_extensions(reply, "output")
@@ -145,7 +149,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
             print(f"Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens})")
             return
 
-    input_ids = encode(question, max_new_tokens)
+    input_ids = encode(question, generate_state['max_new_tokens'])
     original_input_ids = input_ids
     output = input_ids[0]
 
@@ -158,33 +162,21 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
         t = [encode(string, 0, add_special_tokens=False) for string in stopping_strings]
         stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
 
-    generate_params = {}
+    generate_params["max_new_tokens"] = generate_state['max_new_tokens']
     if not shared.args.flexgen:
-        generate_params.update({
-            "max_new_tokens": max_new_tokens,
-            "eos_token_id": eos_token_ids,
-            "stopping_criteria": stopping_criteria_list,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "top_p": top_p,
-            "typical_p": typical_p,
-            "repetition_penalty": repetition_penalty,
-            "encoder_repetition_penalty": encoder_repetition_penalty,
-            "top_k": top_k,
-            "min_length": min_length if shared.args.no_stream else 0,
-            "no_repeat_ngram_size": no_repeat_ngram_size,
-            "num_beams": num_beams,
-            "penalty_alpha": penalty_alpha,
-            "length_penalty": length_penalty,
-            "early_stopping": early_stopping,
-        })
+        for k in ["do_sample", "temperature", "top_p", "typical_p", "repetition_penalty", "encoder_repetition_penalty", "top_k", "min_length", "no_repeat_ngram_size", "num_beams", "penalty_alpha", "length_penalty", "early_stopping"]:
+            generate_params[k] = generate_state[k]
+        generate_params["eos_token_id"] = eos_token_ids
+        generate_params["stopping_criteria"] = stopping_criteria_list
+        if shared.args.no_stream:
+            generate_params["min_length"] = 0
     else:
-        generate_params.update({
-            "max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
-            "do_sample": do_sample,
-            "temperature": temperature,
-            "stop": eos_token_ids[-1],
-        })
+        for k in ["do_sample", "temperature"]:
+            generate_params[k] = generate_state[k]
+        generate_params["stop"] = generate_state["eos_token_ids"][-1]
+        if not shared.args.no_stream:
+            generate_params["max_new_tokens"] = 8
+
     if shared.args.no_cache:
         generate_params.update({"use_cache": False})
     if shared.args.deepspeed:
@@ -244,7 +236,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
 
         # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
         else:
-            for i in range(max_new_tokens//8+1):
+            for i in range(generate_state['max_new_tokens']//8+1):
                 clear_torch_cache()
                 with torch.no_grad():
                     output = shared.model.generate(**generate_params)[0]

+ 34 - 17
server.py

@@ -15,7 +15,7 @@ import gradio as gr
 from PIL import Image
 
 import modules.extensions as extensions_module
-from modules import chat, shared, training, ui
+from modules import chat, shared, training, ui, api
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt
@@ -85,7 +85,7 @@ def load_lora_wrapper(selected_lora):
     add_lora_to_model(selected_lora)
     return selected_lora
 
-def load_preset_values(preset_menu, return_dict=False):
+def load_preset_values(preset_menu, state, return_dict=False):
     generate_params = {
         'do_sample': True,
         'temperature': 1,
@@ -107,13 +107,13 @@ def load_preset_values(preset_menu, return_dict=False):
         i = i.rstrip(',').strip().split('=')
         if len(i) == 2 and i[0].strip() != 'tokens':
             generate_params[i[0].strip()] = eval(i[1].strip())
-
     generate_params['temperature'] = min(1.99, generate_params['temperature'])
 
     if return_dict:
         return generate_params
     else:
-        return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
+        state.update(generate_params)
+        return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
 
 def upload_soft_prompt(file):
     with zipfile.ZipFile(io.BytesIO(file)) as zf:
@@ -170,7 +170,10 @@ def create_prompt_menus():
     shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
 
 def create_settings_menus(default_preset):
-    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
+    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
+    for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
+        generate_params[k] = shared.settings[k]
+    shared.gradio['generate_state'] = gr.State(generate_params)
 
     with gr.Row():
         with gr.Column():
@@ -221,17 +224,16 @@ def create_settings_menus(default_preset):
         with gr.Row():
             shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
 
-    shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
-    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
-    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu']], show_progress=True)
-    shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
-    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
+    shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
+    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
+    shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
+    shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
+    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
 
 def set_interface_arguments(interface_mode, extensions, bool_active):
     modes = ["default", "notebook", "chat", "cai_chat"]
     cmd_list = vars(shared.args)
     bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
-    #int_list = [k for k in cmd_list if type(k) is int]
 
     shared.args.extensions = extensions
     for k in modes[1:]:
@@ -372,11 +374,11 @@ def create_interface():
                             shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
                         with gr.Column():
                             shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
-                            shared.gradio['check'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
+                            shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
 
                 create_settings_menus(default_preset)
 
-            shared.input_params = [shared.gradio[k] for k in ['Chat input', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts', 'Chat mode', 'end_of_turn']]
+            shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
 
             def set_chat_input(textbox):
                 return textbox, ""
@@ -456,9 +458,9 @@ def create_interface():
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)
 
-            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
             output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
-            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
@@ -489,9 +491,9 @@ def create_interface():
             with gr.Tab("Parameters", elem_id="parameters"):
                 create_settings_menus(default_preset)
 
-            shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
+            shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
             output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
-            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
+            gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
             gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
             shared.gradio['Stop'].click(stop_everything_event, [], [], queue=False, cancels=gen_events if shared.args.no_stream else None)
@@ -524,6 +526,21 @@ def create_interface():
         if shared.args.extensions is not None:
             extensions_module.create_extensions_block()
 
+        def change_dict_value(d, key, value):
+            d[key] = value
+            return d
+
+        for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
+            if k not in shared.gradio:
+                continue
+            if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
+                shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+            else:
+                shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
+
+        if not shared.is_chat():
+            api.create_apis()
+
     # Authentication
     auth = None
     if shared.args.gradio_auth_path is not None: