server.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554
  1. import gc
  2. import io
  3. import json
  4. import os
  5. import re
  6. import sys
  7. import time
  8. import zipfile
  9. from pathlib import Path
  10. import gradio as gr
  11. import numpy as np
  12. import torch
  13. import transformers
  14. from PIL import Image
  15. from transformers import AutoConfig
  16. from transformers import AutoModelForCausalLM
  17. from transformers import AutoTokenizer
  18. import modules.chat as chat
  19. import modules.extensions as extensions_module
  20. import modules.shared as shared
  21. from modules.extensions import extension_state
  22. from modules.extensions import load_extensions
  23. from modules.extensions import update_extensions_parameters
  24. from modules.html_generator import *
  25. from modules.prompt import generate_reply
  26. from modules.ui import *
  27. transformers.logging.set_verbosity_error()
  28. if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
  29. print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
  30. settings = {
  31. 'max_new_tokens': 200,
  32. 'max_new_tokens_min': 1,
  33. 'max_new_tokens_max': 2000,
  34. 'preset': 'NovelAI-Sphinx Moth',
  35. 'name1': 'Person 1',
  36. 'name2': 'Person 2',
  37. 'context': 'This is a conversation between two people.',
  38. 'prompt': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
  39. 'prompt_gpt4chan': '-----\n--- 865467536\nInput text\n--- 865467537\n',
  40. 'stop_at_newline': True,
  41. 'chat_prompt_size': 2048,
  42. 'chat_prompt_size_min': 0,
  43. 'chat_prompt_size_max': 2048,
  44. 'preset_pygmalion': 'Pygmalion',
  45. 'name1_pygmalion': 'You',
  46. 'name2_pygmalion': 'Kawaii',
  47. 'context_pygmalion': "Kawaii's persona: Kawaii is a cheerful person who loves to make others smile. She is an optimist who loves to spread happiness and positivity wherever she goes.\n<START>",
  48. 'stop_at_newline_pygmalion': False,
  49. }
  50. if shared.args.settings is not None and Path(shared.args.settings).exists():
  51. new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
  52. for item in new_settings:
  53. settings[item] = new_settings[item]
# Optional back-ends. Each import is only attempted when the matching option
# is set, so these packages are not needed otherwise.
if shared.args.flexgen:
    # Names used by the FlexGen branch of load_model().
    from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, Task, get_opt_config)

if shared.args.deepspeed:
    import deepspeed
    from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    # The local rank comes from shared.args.local_rank when given, otherwise from
    # the LOCAL_RANK environment variable (default 0). WORLD_SIZE defaults to 1;
    # presumably both are set by the distributed launcher — confirm.
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed()
    # NOTE(review): the second argument (1 * world_size) looks like a global batch
    # size; verify against modules/deepspeed_parameters.generate_ds_config.
    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration

# The bot picture module is only needed when pictures are enabled in a chat mode.
if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
    import modules.bot_picture as bot_picture
  69. def load_model(model_name):
  70. print(f"Loading {model_name}...")
  71. t0 = time.time()
  72. # Default settings
  73. if not (shared.args.cpu or shared.args.load_in_8bit or shared.args.auto_devices or shared.args.disk or shared.args.gpu_memory is not None or shared.args.cpu_memory is not None or shared.args.deepspeed or shared.args.flexgen):
  74. if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
  75. model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
  76. else:
  77. model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
  78. # FlexGen
  79. elif shared.args.flexgen:
  80. gpu = TorchDevice("cuda:0")
  81. cpu = TorchDevice("cpu")
  82. disk = TorchDisk(shared.args.disk_cache_dir)
  83. env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
  84. # Offloading policy
  85. policy = Policy(1, 1,
  86. shared.args.percent[0], shared.args.percent[1],
  87. shared.args.percent[2], shared.args.percent[3],
  88. shared.args.percent[4], shared.args.percent[5],
  89. overlap=True, sep_layer=True, pin_weight=True,
  90. cpu_cache_compute=False, attn_sparsity=1.0,
  91. compress_weight=shared.args.compress_weight,
  92. comp_weight_config=CompressionConfig(
  93. num_bits=4, group_size=64,
  94. group_dim=0, symmetric=False),
  95. compress_cache=False,
  96. comp_cache_config=CompressionConfig(
  97. num_bits=4, group_size=64,
  98. group_dim=2, symmetric=False))
  99. opt_config = get_opt_config(f"facebook/{shared.model_name}")
  100. model = OptLM(opt_config, env, "models", policy)
  101. model.init_all_weights()
  102. # DeepSpeed ZeRO-3
  103. elif shared.args.deepspeed:
  104. model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
  105. model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
  106. model.module.eval() # Inference
  107. print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
  108. # Custom
  109. else:
  110. command = "AutoModelForCausalLM.from_pretrained"
  111. params = ["low_cpu_mem_usage=True"]
  112. if not shared.args.cpu and not torch.cuda.is_available():
  113. print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
  114. shared.args.cpu = True
  115. if shared.args.cpu:
  116. params.append("low_cpu_mem_usage=True")
  117. params.append("torch_dtype=torch.float32")
  118. else:
  119. params.append("device_map='auto'")
  120. params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
  121. if shared.args.gpu_memory:
  122. params.append(f"max_memory={{0: '{shared.args.gpu_memory or '99'}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
  123. elif not shared.args.load_in_8bit:
  124. total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
  125. suggestion = round((total_mem-1000)/1000)*1000
  126. if total_mem-suggestion < 800:
  127. suggestion -= 1000
  128. suggestion = int(round(suggestion/1000))
  129. print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
  130. params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
  131. if shared.args.disk:
  132. params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
  133. command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
  134. model = eval(command)
  135. # Loading the tokenizer
  136. if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
  137. tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
  138. else:
  139. tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
  140. tokenizer.truncation_side = 'left'
  141. print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
  142. return model, tokenizer
  143. def load_soft_prompt(name):
  144. if name == 'None':
  145. shared.soft_prompt = False
  146. shared.soft_prompt_tensor = None
  147. else:
  148. with zipfile.ZipFile(Path(f'softprompts/{name}.zip')) as zf:
  149. zf.extract('tensor.npy')
  150. zf.extract('meta.json')
  151. j = json.loads(open('meta.json', 'r').read())
  152. print(f"\nLoading the softprompt \"{name}\".")
  153. for field in j:
  154. if field != 'name':
  155. if type(j[field]) is list:
  156. print(f"{field}: {', '.join(j[field])}")
  157. else:
  158. print(f"{field}: {j[field]}")
  159. print()
  160. tensor = np.load('tensor.npy')
  161. Path('tensor.npy').unlink()
  162. Path('meta.json').unlink()
  163. tensor = torch.Tensor(tensor).to(device=shared.model.device, dtype=shared.model.dtype)
  164. tensor = torch.reshape(tensor, (1, tensor.shape[0], tensor.shape[1]))
  165. shared.soft_prompt = True
  166. shared.soft_prompt_tensor = tensor
  167. return name
  168. def upload_soft_prompt(file):
  169. with zipfile.ZipFile(io.BytesIO(file)) as zf:
  170. zf.extract('meta.json')
  171. j = json.loads(open('meta.json', 'r').read())
  172. name = j['name']
  173. Path('meta.json').unlink()
  174. with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
  175. f.write(file)
  176. return name
  177. def load_model_wrapper(selected_model):
  178. if selected_model != shared.model_name:
  179. shared.model_name = selected_model
  180. model = shared.tokenizer = None
  181. if not shared.args.cpu:
  182. gc.collect()
  183. torch.cuda.empty_cache()
  184. shared.model, shared.tokenizer = load_model(shared.model_name)
  185. return selected_model
  186. def load_preset_values(preset_menu, return_dict=False):
  187. generate_params = {
  188. 'do_sample': True,
  189. 'temperature': 1,
  190. 'top_p': 1,
  191. 'typical_p': 1,
  192. 'repetition_penalty': 1,
  193. 'top_k': 50,
  194. 'num_beams': 1,
  195. 'penalty_alpha': 0,
  196. 'min_length': 0,
  197. 'length_penalty': 1,
  198. 'no_repeat_ngram_size': 0,
  199. 'early_stopping': False,
  200. }
  201. with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
  202. preset = infile.read()
  203. for i in preset.splitlines():
  204. i = i.rstrip(',').strip().split('=')
  205. if len(i) == 2 and i[0].strip() != 'tokens':
  206. generate_params[i[0].strip()] = eval(i[1].strip())
  207. generate_params['temperature'] = min(1.99, generate_params['temperature'])
  208. if return_dict:
  209. return generate_params
  210. else:
  211. return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
  212. def get_available_models():
  213. return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
  214. def get_available_presets():
  215. return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
  216. def get_available_characters():
  217. return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
  218. def get_available_extensions():
  219. return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
  220. def get_available_softprompts():
  221. return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
  222. def create_extensions_block():
  223. extensions_ui_elements = []
  224. default_values = []
  225. if not (shared.args.chat or shared.args.cai_chat):
  226. gr.Markdown('## Extensions parameters')
  227. for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
  228. if extension_state[ext][0] == True:
  229. params = extensions_module.get_params(ext)
  230. for param in params:
  231. _id = f"{ext}-{param}"
  232. default_value = settings[_id] if _id in settings else params[param]
  233. default_values.append(default_value)
  234. if type(params[param]) == str:
  235. extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
  236. elif type(params[param]) in [int, float]:
  237. extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
  238. elif type(params[param]) == bool:
  239. extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
  240. update_extensions_parameters(*default_values)
  241. btn_extensions = gr.Button("Apply")
  242. btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
def create_settings_menus():
    """Build the model/preset selectors, generation-parameter controls and soft-prompt UI.

    Initial control values come from the preset named by
    ``settings['preset' + suffix]``, or from the 'Naive' preset when FlexGen is
    enabled. Reads the module-level globals ``suffix``, ``available_models``,
    ``available_presets`` and ``available_softprompts``, so it must be called
    after they are defined.

    Returns:
        (preset_menu, do_sample, temperature, top_p, typical_p,
        repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams,
        penalty_alpha, length_penalty, early_stopping). These are the Gradio
        components the caller feeds into the generation functions.
    """
    generate_params = load_preset_values(settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True)

    with gr.Row():
        with gr.Column():
            with gr.Row():
                model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
                # create_refresh_button presumably comes from modules.ui (star import); it re-reads the choices on click.
                create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
        with gr.Column():
            with gr.Row():
                preset_menu = gr.Dropdown(choices=available_presets, value=settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
                create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")

    with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
        with gr.Row():
            do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
            # Upper bound 1.99 matches the cap applied in load_preset_values().
            temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
        with gr.Row():
            top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
            top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
        with gr.Row():
            repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
            no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
        with gr.Row():
            typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
            # min_length is only editable with no_stream; when streaming it is pinned to 0.
            min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream)

        gr.Markdown("Contrastive search:")
        penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha")

        gr.Markdown("Beam search (uses a lot of VRAM):")
        with gr.Row():
            num_beams = gr.Slider(1, 20, step=1, value=generate_params["num_beams"], label="num_beams")
            length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
        early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")

    with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
        with gr.Row():
            softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
            create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")

        gr.Markdown('Upload a soft prompt (.zip format):')
        with gr.Row():
            upload_softprompt = gr.File(type='binary', file_types=[".zip"])

    # Event wiring. The preset outputs must follow the tuple order returned by
    # load_preset_values(); an upload selects the newly stored soft prompt,
    # which in turn triggers load_soft_prompt via softprompts_menu.change.
    model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
    preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
    softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
    upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
    return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
  286. # Global variables
  287. available_models = get_available_models()
  288. available_presets = get_available_presets()
  289. available_characters = get_available_characters()
  290. extensions_module.available_extensions = get_available_extensions()
  291. available_softprompts = get_available_softprompts()
  292. if shared.args.extensions is not None:
  293. load_extensions()
  294. # Choosing the default model
  295. if shared.args.model is not None:
  296. shared.model_name = shared.args.model
  297. else:
  298. if len(available_models) == 0:
  299. print("No models are available! Please download at least one.")
  300. sys.exit(0)
  301. elif len(available_models) == 1:
  302. i = 0
  303. else:
  304. print("The following models are available:\n")
  305. for i,model in enumerate(available_models):
  306. print(f"{i+1}. {model}")
  307. print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
  308. i = int(input())-1
  309. print()
  310. shared.model_name = available_models[i]
  311. shared.model, shared.tokenizer = load_model(shared.model_name)
  312. loaded_preset = None
  313. # UI settings
  314. if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
  315. default_text = settings['prompt_gpt4chan']
  316. elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
  317. default_text = 'User: \n'
  318. else:
  319. default_text = settings['prompt']
  320. description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
  321. suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
  322. buttons = {}
  323. gen_events = []
  324. if shared.args.chat or shared.args.cai_chat:
  325. if Path(f'logs/persistent.json').exists():
  326. chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), settings[f'name1{suffix}'], settings[f'name2{suffix}'])
  327. with gr.Blocks(css=css+chat_css, analytics_enabled=False) as interface:
  328. if shared.args.cai_chat:
  329. display = gr.HTML(value=generate_chat_html(chat.history['visible'], settings[f'name1{suffix}'], settings[f'name2{suffix}'], chat.character))
  330. else:
  331. display = gr.Chatbot(value=chat.history['visible'])
  332. textbox = gr.Textbox(label='Input')
  333. with gr.Row():
  334. buttons["Stop"] = gr.Button("Stop")
  335. buttons["Generate"] = gr.Button("Generate")
  336. buttons["Regenerate"] = gr.Button("Regenerate")
  337. with gr.Row():
  338. buttons["Impersonate"] = gr.Button("Impersonate")
  339. buttons["Remove last"] = gr.Button("Remove last")
  340. buttons["Clear history"] = gr.Button("Clear history")
  341. with gr.Row():
  342. buttons["Send last reply to input"] = gr.Button("Send last reply to input")
  343. buttons["Replace last reply"] = gr.Button("Replace last reply")
  344. if shared.args.picture:
  345. with gr.Row():
  346. picture_select = gr.Image(label="Send a picture", type='pil')
  347. with gr.Tab("Chat settings"):
  348. name1 = gr.Textbox(value=settings[f'name1{suffix}'], lines=1, label='Your name')
  349. name2 = gr.Textbox(value=settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
  350. context = gr.Textbox(value=settings[f'context{suffix}'], lines=2, label='Context')
  351. with gr.Row():
  352. character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
  353. create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
  354. with gr.Row():
  355. check = gr.Checkbox(value=settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
  356. with gr.Row():
  357. with gr.Tab('Chat history'):
  358. with gr.Row():
  359. with gr.Column():
  360. gr.Markdown('Upload')
  361. upload_chat_history = gr.File(type='binary', file_types=[".json", ".txt"])
  362. with gr.Column():
  363. gr.Markdown('Download')
  364. download = gr.File()
  365. buttons["Download"] = gr.Button(value="Click me")
  366. with gr.Tab('Upload character'):
  367. with gr.Row():
  368. with gr.Column():
  369. gr.Markdown('1. Select the JSON file')
  370. upload_char = gr.File(type='binary', file_types=[".json"])
  371. with gr.Column():
  372. gr.Markdown('2. Select your character\'s profile picture (optional)')
  373. upload_img = gr.File(type='binary', file_types=["image"])
  374. buttons["Upload character"] = gr.Button(value="Submit")
  375. with gr.Tab('Upload your profile picture'):
  376. upload_img_me = gr.File(type='binary', file_types=["image"])
  377. with gr.Tab('Upload TavernAI Character Card'):
  378. upload_img_tavern = gr.File(type='binary', file_types=["image"])
  379. with gr.Tab("Generation settings"):
  380. with gr.Row():
  381. with gr.Column():
  382. max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
  383. with gr.Column():
  384. chat_prompt_size_slider = gr.Slider(minimum=settings['chat_prompt_size_min'], maximum=settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=settings['chat_prompt_size'])
  385. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  386. if shared.args.extensions is not None:
  387. with gr.Tab("Extensions"):
  388. create_extensions_block()
  389. input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
  390. if shared.args.picture:
  391. input_params.append(picture_select)
  392. function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper"
  393. gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen"))
  394. gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream))
  395. if shared.args.picture:
  396. picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream)
  397. gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream))
  398. gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream))
  399. buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events)
  400. buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
  401. buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
  402. buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display)
  403. buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
  404. buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
  405. buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])
  406. # Clearing stuff and saving the history
  407. for i in ["Generate", "Regenerate", "Replace last reply"]:
  408. buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
  409. buttons[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  410. buttons["Clear history"].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  411. textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
  412. textbox.submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  413. character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display])
  414. upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], [])
  415. upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
  416. upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], [])
  417. if shared.args.picture:
  418. picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
  419. if shared.args.cai_chat:
  420. upload_chat_history.upload(chat.redraw_html, [name1, name2], [display])
  421. upload_img_me.upload(chat.redraw_html, [name1, name2], [display])
  422. else:
  423. upload_chat_history.upload(lambda : chat.history['visible'], [], [display])
  424. upload_img_me.upload(lambda : chat.history['visible'], [], [display])
  425. elif shared.args.notebook:
  426. with gr.Blocks(css=css, analytics_enabled=False) as interface:
  427. gr.Markdown(description)
  428. with gr.Tab('Raw'):
  429. textbox = gr.Textbox(value=default_text, lines=23)
  430. with gr.Tab('Markdown'):
  431. markdown = gr.Markdown()
  432. with gr.Tab('HTML'):
  433. html = gr.HTML()
  434. buttons["Generate"] = gr.Button("Generate")
  435. buttons["Stop"] = gr.Button("Stop")
  436. max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
  437. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  438. if shared.args.extensions is not None:
  439. create_extensions_block()
  440. gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
  441. gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream))
  442. buttons["Stop"].click(None, None, None, cancels=gen_events)
  443. else:
  444. with gr.Blocks(css=css, analytics_enabled=False) as interface:
  445. gr.Markdown(description)
  446. with gr.Row():
  447. with gr.Column():
  448. textbox = gr.Textbox(value=default_text, lines=15, label='Input')
  449. max_new_tokens = gr.Slider(minimum=settings['max_new_tokens_min'], maximum=settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=settings['max_new_tokens'])
  450. buttons["Generate"] = gr.Button("Generate")
  451. with gr.Row():
  452. with gr.Column():
  453. buttons["Continue"] = gr.Button("Continue")
  454. with gr.Column():
  455. buttons["Stop"] = gr.Button("Stop")
  456. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  457. if shared.args.extensions is not None:
  458. create_extensions_block()
  459. with gr.Column():
  460. with gr.Tab('Raw'):
  461. output_textbox = gr.Textbox(lines=15, label='Output')
  462. with gr.Tab('Markdown'):
  463. markdown = gr.Markdown()
  464. with gr.Tab('HTML'):
  465. html = gr.HTML()
  466. gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
  467. gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
  468. gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
  469. buttons["Stop"].click(None, None, None, cancels=gen_events)
  470. interface.queue()
  471. if shared.args.listen:
  472. interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port)
  473. else:
  474. interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
  475. # I think that I will need this later
  476. while True:
  477. time.sleep(0.5)