server.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. import io
  2. import json
  3. import re
  4. import sys
  5. import time
  6. import zipfile
  7. from datetime import datetime
  8. from pathlib import Path
  9. import gradio as gr
  10. from modules import chat, shared, ui, training
  11. import modules.extensions as extensions_module
  12. from modules.html_generator import generate_chat_html
  13. from modules.LoRA import add_lora_to_model
  14. from modules.models import load_model, load_soft_prompt
  15. from modules.text_generation import clear_torch_cache, generate_reply
  16. # Loading custom settings
  17. settings_file = None
  18. if shared.args.settings is not None and Path(shared.args.settings).exists():
  19. settings_file = Path(shared.args.settings)
  20. elif Path('settings.json').exists():
  21. settings_file = Path('settings.json')
  22. if settings_file is not None:
  23. print(f"Loading settings from {settings_file}...")
  24. new_settings = json.loads(open(settings_file, 'r').read())
  25. for item in new_settings:
  26. shared.settings[item] = new_settings[item]
  27. def get_available_models():
  28. if shared.args.flexgen:
  29. return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
  30. else:
  31. return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
  32. def get_available_presets():
  33. return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
  34. def get_available_prompts():
  35. prompts = []
  36. prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
  37. prompts += sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('prompts').glob('*.txt'))), key=str.lower)
  38. prompts += ['None']
  39. return prompts
  40. def get_available_characters():
  41. return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
  42. def get_available_extensions():
  43. return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
  44. def get_available_softprompts():
  45. return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
  46. def get_available_loras():
  47. return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
  48. def unload_model():
  49. shared.model = shared.tokenizer = None
  50. clear_torch_cache()
  51. def load_model_wrapper(selected_model):
  52. if selected_model != shared.model_name:
  53. shared.model_name = selected_model
  54. unload_model()
  55. if selected_model != '':
  56. shared.model, shared.tokenizer = load_model(shared.model_name)
  57. return selected_model
  58. def load_lora_wrapper(selected_lora):
  59. add_lora_to_model(selected_lora)
  60. default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
  61. return selected_lora, default_text
  62. def load_preset_values(preset_menu, return_dict=False):
  63. generate_params = {
  64. 'do_sample': True,
  65. 'temperature': 1,
  66. 'top_p': 1,
  67. 'typical_p': 1,
  68. 'repetition_penalty': 1,
  69. 'encoder_repetition_penalty': 1,
  70. 'top_k': 50,
  71. 'num_beams': 1,
  72. 'penalty_alpha': 0,
  73. 'min_length': 0,
  74. 'length_penalty': 1,
  75. 'no_repeat_ngram_size': 0,
  76. 'early_stopping': False,
  77. }
  78. with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
  79. preset = infile.read()
  80. for i in preset.splitlines():
  81. i = i.rstrip(',').strip().split('=')
  82. if len(i) == 2 and i[0].strip() != 'tokens':
  83. generate_params[i[0].strip()] = eval(i[1].strip())
  84. generate_params['temperature'] = min(1.99, generate_params['temperature'])
  85. if return_dict:
  86. return generate_params
  87. else:
  88. return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
  89. def upload_soft_prompt(file):
  90. with zipfile.ZipFile(io.BytesIO(file)) as zf:
  91. zf.extract('meta.json')
  92. j = json.loads(open('meta.json', 'r').read())
  93. name = j['name']
  94. Path('meta.json').unlink()
  95. with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
  96. f.write(file)
  97. return name
  98. def create_model_and_preset_menus():
  99. with gr.Row():
  100. with gr.Column():
  101. with gr.Row():
  102. shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
  103. ui.create_refresh_button(shared.gradio['model_menu'], lambda : None, lambda : {'choices': get_available_models()}, 'refresh-button')
  104. with gr.Column():
  105. with gr.Row():
  106. shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
  107. ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
  108. def save_prompt(text):
  109. fname = f"{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"
  110. with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
  111. f.write(text)
  112. return f"Saved prompt to prompts/{fname}"
  113. def load_prompt(fname):
  114. if fname in ['None', '']:
  115. return ''
  116. else:
  117. with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
  118. return f.read()
  119. def create_prompt_menus():
  120. with gr.Row():
  121. with gr.Column():
  122. with gr.Row():
  123. shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
  124. ui.create_refresh_button(shared.gradio['prompt_menu'], lambda : None, lambda : {'choices': get_available_prompts()}, 'refresh-button')
  125. with gr.Column():
  126. with gr.Column():
  127. shared.gradio['save_prompt'] = gr.Button('Save prompt')
  128. shared.gradio['status'] = gr.Markdown('Ready')
  129. shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=True)
  130. shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
def create_settings_menus(default_preset):
    """Build the generation-parameter controls and wire their change events.

    Creates the model/preset menus, the seed box, sliders and checkboxes for
    each generation parameter (initialised from a preset), the LoRA menu and
    the soft-prompt accordion. Every component is stored in ``shared.gradio``.

    Args:
        default_preset: Preset name used to initialise the controls. It is
            ignored in FlexGen mode, where 'Naive' is used instead.
    """
    # Initial control values; FlexGen always starts from the 'Naive' preset.
    generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)

    with gr.Row():
        with gr.Column():
            create_model_and_preset_menus()
        with gr.Column():
            shared.gradio['seed'] = gr.Number(value=-1, label='Seed (-1 for random)')

    with gr.Row():
        with gr.Column():
            with gr.Box():
                gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
                with gr.Row():
                    with gr.Column():
                        shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
                        shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
                        shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
                        shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
                    with gr.Column():
                        shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
                        shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
                        shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
                        # min_length is only editable with --no-stream; otherwise it is pinned to 0.
                        shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
                        shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
        with gr.Column():
            with gr.Box():
                gr.Markdown('Contrastive search')
                shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
            with gr.Box():
                gr.Markdown('Beam search (uses a lot of VRAM)')
                with gr.Row():
                    with gr.Column():
                        shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
                    with gr.Column():
                        shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
                shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')

    with gr.Row():
        shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
        ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')

    with gr.Accordion('Soft prompt', open=False):
        with gr.Row():
            shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
            ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')

        gr.Markdown('Upload a soft prompt (.zip format):')
        with gr.Row():
            shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])

    shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
    # The output list must stay in the same order as the tuple returned by
    # load_preset_values(return_dict=False).
    shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
    # load_lora_wrapper also returns the LoRA's default prompt, which replaces the textbox contents.
    shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
    shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
    # Saving an upload returns its name, which becomes the selected soft prompt.
    shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
  181. def set_interface_arguments(interface_mode, extensions, cmd_active):
  182. modes = ["default", "notebook", "chat", "cai_chat"]
  183. cmd_list = vars(shared.args)
  184. cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
  185. shared.args.extensions = extensions
  186. for k in modes[1:]:
  187. exec(f"shared.args.{k} = False")
  188. if interface_mode != "default":
  189. exec(f"shared.args.{interface_mode} = True")
  190. for k in cmd_list:
  191. exec(f"shared.args.{k} = False")
  192. for k in cmd_active:
  193. exec(f"shared.args.{k} = True")
  194. shared.need_restart = True
  195. available_models = get_available_models()
  196. available_presets = get_available_presets()
  197. available_characters = get_available_characters()
  198. available_softprompts = get_available_softprompts()
  199. available_loras = get_available_loras()
  200. # Default extensions
  201. extensions_module.available_extensions = get_available_extensions()
  202. if shared.args.chat or shared.args.cai_chat:
  203. for extension in shared.settings['chat_default_extensions']:
  204. shared.args.extensions = shared.args.extensions or []
  205. if extension not in shared.args.extensions:
  206. shared.args.extensions.append(extension)
  207. else:
  208. for extension in shared.settings['default_extensions']:
  209. shared.args.extensions = shared.args.extensions or []
  210. if extension not in shared.args.extensions:
  211. shared.args.extensions.append(extension)
  212. # Default model
  213. if shared.args.model is not None:
  214. shared.model_name = shared.args.model
  215. else:
  216. if len(available_models) == 0:
  217. print('No models are available! Please download at least one.')
  218. sys.exit(0)
  219. elif len(available_models) == 1:
  220. i = 0
  221. else:
  222. print('The following models are available:\n')
  223. for i, model in enumerate(available_models):
  224. print(f'{i+1}. {model}')
  225. print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
  226. i = int(input())-1
  227. print()
  228. shared.model_name = available_models[i]
  229. shared.model, shared.tokenizer = load_model(shared.model_name)
  230. if shared.args.lora:
  231. add_lora_to_model(shared.args.lora)
  232. # Default UI settings
  233. default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
  234. if shared.lora_name != "None":
  235. default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
  236. else:
  237. default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
  238. title ='Text generation web UI'
  239. description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
  240. suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
  241. def create_interface():
  242. gen_events = []
  243. if shared.args.extensions is not None and len(shared.args.extensions) > 0:
  244. extensions_module.load_extensions()
  245. with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
  246. if shared.args.chat or shared.args.cai_chat:
  247. with gr.Tab("Text generation", elem_id="main"):
  248. if shared.args.cai_chat:
  249. shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
  250. else:
  251. shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
  252. shared.gradio['textbox'] = gr.Textbox(label='Input')
  253. with gr.Row():
  254. shared.gradio['Generate'] = gr.Button('Generate')
  255. shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
  256. with gr.Row():
  257. shared.gradio['Impersonate'] = gr.Button('Impersonate')
  258. shared.gradio['Regenerate'] = gr.Button('Regenerate')
  259. with gr.Row():
  260. shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
  261. shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
  262. shared.gradio['Remove last'] = gr.Button('Remove last')
  263. shared.gradio['Clear history'] = gr.Button('Clear history')
  264. shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
  265. shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
  266. with gr.Tab("Character", elem_id="chat-settings"):
  267. shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
  268. shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
  269. shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
  270. with gr.Row():
  271. shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
  272. ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
  273. with gr.Row():
  274. with gr.Tab('Chat history'):
  275. with gr.Row():
  276. with gr.Column():
  277. gr.Markdown('Upload')
  278. shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
  279. with gr.Column():
  280. gr.Markdown('Download')
  281. shared.gradio['download'] = gr.File()
  282. shared.gradio['download_button'] = gr.Button(value='Click me')
  283. with gr.Tab('Upload character'):
  284. with gr.Row():
  285. with gr.Column():
  286. gr.Markdown('1. Select the JSON file')
  287. shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
  288. with gr.Column():
  289. gr.Markdown('2. Select your character\'s profile picture (optional)')
  290. shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
  291. shared.gradio['Upload character'] = gr.Button(value='Submit')
  292. with gr.Tab('Upload your profile picture'):
  293. shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
  294. with gr.Tab('Upload TavernAI Character Card'):
  295. shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
  296. with gr.Tab("Parameters", elem_id="parameters"):
  297. with gr.Box():
  298. gr.Markdown("Chat parameters")
  299. with gr.Row():
  300. with gr.Column():
  301. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  302. shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
  303. with gr.Column():
  304. shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
  305. shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
  306. create_settings_menus(default_preset)
  307. function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
  308. shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
  309. gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
  310. gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
  311. gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
  312. gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
  313. shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events, queue=False)
  314. shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
  315. shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
  316. # Clear history with confirmation
  317. clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
  318. shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
  319. shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
  320. shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
  321. shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
  322. shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
  323. shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
  324. shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
  325. # Clearing stuff and saving the history
  326. for i in ['Generate', 'Regenerate', 'Replace last reply']:
  327. shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
  328. shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  329. shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  330. shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
  331. shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  332. shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
  333. shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
  334. shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
  335. shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
  336. reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
  337. reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
  338. shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
  339. shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
  340. shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
  341. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
  342. shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
  343. shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
  344. elif shared.args.notebook:
  345. with gr.Tab("Text generation", elem_id="main"):
  346. with gr.Row():
  347. with gr.Column(scale=4):
  348. with gr.Tab('Raw'):
  349. shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=25)
  350. with gr.Tab('Markdown'):
  351. shared.gradio['markdown'] = gr.Markdown()
  352. with gr.Tab('HTML'):
  353. shared.gradio['html'] = gr.HTML()
  354. with gr.Row():
  355. shared.gradio['Generate'] = gr.Button('Generate')
  356. shared.gradio['Stop'] = gr.Button('Stop')
  357. with gr.Column(scale=1):
  358. gr.Markdown("\n")
  359. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  360. create_prompt_menus()
  361. with gr.Tab("Parameters", elem_id="parameters"):
  362. create_settings_menus(default_preset)
  363. shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
  364. output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
  365. gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
  366. gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  367. shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
  368. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
  369. else:
  370. with gr.Tab("Text generation", elem_id="main"):
  371. with gr.Row():
  372. with gr.Column():
  373. shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
  374. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  375. shared.gradio['Generate'] = gr.Button('Generate')
  376. with gr.Row():
  377. with gr.Column():
  378. shared.gradio['Continue'] = gr.Button('Continue')
  379. with gr.Column():
  380. shared.gradio['Stop'] = gr.Button('Stop')
  381. create_prompt_menus()
  382. with gr.Column():
  383. with gr.Tab('Raw'):
  384. shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
  385. with gr.Tab('Markdown'):
  386. shared.gradio['markdown'] = gr.Markdown()
  387. with gr.Tab('HTML'):
  388. shared.gradio['html'] = gr.HTML()
  389. with gr.Tab("Parameters", elem_id="parameters"):
  390. create_settings_menus(default_preset)
  391. shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'seed']]
  392. output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
  393. gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
  394. gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  395. gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
  396. shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
  397. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
  398. with gr.Tab("Interface mode", elem_id="interface-mode"):
  399. modes = ["default", "notebook", "chat", "cai_chat"]
  400. current_mode = "default"
  401. for mode in modes[1:]:
  402. if eval(f"shared.args.{mode}"):
  403. current_mode = mode
  404. break
  405. cmd_list = vars(shared.args)
  406. cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
  407. active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]]
  408. gr.Markdown("*Experimental*")
  409. shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
  410. shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
  411. shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags")
  412. shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
  413. shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
  414. shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}')
  415. with gr.Tab("Training", elem_id="training-tab"):
  416. training.create_train_interface()
  417. if shared.args.extensions is not None:
  418. extensions_module.create_extensions_block()
  419. # Launch the interface
  420. shared.gradio['interface'].queue()
  421. if shared.args.listen:
  422. shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
  423. else:
  424. shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
  425. create_interface()
  426. while True:
  427. time.sleep(0.5)
  428. if shared.need_restart:
  429. shared.need_restart = False
  430. shared.gradio['interface'].close()
  431. create_interface()