server.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. import os
  2. os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
  3. import importlib
  4. import io
  5. import json
  6. import os
  7. import re
  8. import sys
  9. import time
  10. import traceback
  11. import zipfile
  12. from datetime import datetime
  13. from pathlib import Path
  14. import gradio as gr
  15. import requests
  16. from huggingface_hub import HfApi
  17. from PIL import Image
  18. import modules.extensions as extensions_module
  19. from modules import api, chat, shared, training, ui
  20. from modules.html_generator import chat_html_wrapper
  21. from modules.LoRA import add_lora_to_model
  22. from modules.models import load_model, load_soft_prompt, unload_model
  23. from modules.text_generation import generate_reply, stop_everything_event
  24. # Loading custom settings
  25. settings_file = None
  26. if shared.args.settings is not None and Path(shared.args.settings).exists():
  27. settings_file = Path(shared.args.settings)
  28. elif Path('settings.json').exists():
  29. settings_file = Path('settings.json')
  30. if settings_file is not None:
  31. print(f"Loading settings from {settings_file}...")
  32. new_settings = json.loads(open(settings_file, 'r').read())
  33. for item in new_settings:
  34. shared.settings[item] = new_settings[item]
  35. def get_available_models():
  36. if shared.args.flexgen:
  37. return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
  38. else:
  39. return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
  40. def get_available_presets():
  41. return sorted(set((k.stem for k in Path('presets').glob('*.txt'))), key=str.lower)
  42. def get_available_prompts():
  43. prompts = []
  44. prompts += sorted(set((k.stem for k in Path('prompts').glob('[0-9]*.txt'))), key=str.lower, reverse=True)
  45. prompts += sorted(set((k.stem for k in Path('prompts').glob('*.txt'))), key=str.lower)
  46. prompts += ['None']
  47. return prompts
  48. def get_available_characters():
  49. paths = (x for x in Path('characters').iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
  50. return ['None'] + sorted(set((k.stem for k in paths if k.stem != "instruction-following")), key=str.lower)
  51. def get_available_instruction_templates():
  52. path = "characters/instruction-following"
  53. paths = []
  54. if os.path.exists(path):
  55. paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
  56. return ['None'] + sorted(set((k.stem for k in paths)), key=str.lower)
  57. def get_available_extensions():
  58. return sorted(set(map(lambda x: x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
  59. def get_available_softprompts():
  60. return ['None'] + sorted(set((k.stem for k in Path('softprompts').glob('*.zip'))), key=str.lower)
  61. def get_available_loras():
  62. return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
  63. def load_model_wrapper(selected_model):
  64. if selected_model != shared.model_name:
  65. shared.model_name = selected_model
  66. unload_model()
  67. if selected_model != '':
  68. shared.model, shared.tokenizer = load_model(shared.model_name)
  69. return selected_model
  70. def load_lora_wrapper(selected_lora):
  71. add_lora_to_model(selected_lora)
  72. return selected_lora
  73. def load_preset_values(preset_menu, state, return_dict=False):
  74. generate_params = {
  75. 'do_sample': True,
  76. 'temperature': 1,
  77. 'top_p': 1,
  78. 'typical_p': 1,
  79. 'repetition_penalty': 1,
  80. 'encoder_repetition_penalty': 1,
  81. 'top_k': 50,
  82. 'num_beams': 1,
  83. 'penalty_alpha': 0,
  84. 'min_length': 0,
  85. 'length_penalty': 1,
  86. 'no_repeat_ngram_size': 0,
  87. 'early_stopping': False,
  88. }
  89. with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
  90. preset = infile.read()
  91. for i in preset.splitlines():
  92. i = i.rstrip(',').strip().split('=')
  93. if len(i) == 2 and i[0].strip() != 'tokens':
  94. generate_params[i[0].strip()] = eval(i[1].strip())
  95. generate_params['temperature'] = min(1.99, generate_params['temperature'])
  96. if return_dict:
  97. return generate_params
  98. else:
  99. state.update(generate_params)
  100. return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
  101. def upload_soft_prompt(file):
  102. with zipfile.ZipFile(io.BytesIO(file)) as zf:
  103. zf.extract('meta.json')
  104. j = json.loads(open('meta.json', 'r').read())
  105. name = j['name']
  106. Path('meta.json').unlink()
  107. with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
  108. f.write(file)
  109. return name
  110. def save_prompt(text):
  111. fname = f"{datetime.now().strftime('%Y-%m-%d-%H%M%S')}.txt"
  112. with open(Path(f'prompts/{fname}'), 'w', encoding='utf-8') as f:
  113. f.write(text)
  114. return f"Saved to prompts/{fname}"
  115. def load_prompt(fname):
  116. if fname in ['None', '']:
  117. return ''
  118. else:
  119. with open(Path(f'prompts/{fname}.txt'), 'r', encoding='utf-8') as f:
  120. text = f.read()
  121. if text[-1] == '\n':
  122. text = text[:-1]
  123. return text
  124. def create_prompt_menus():
  125. with gr.Row():
  126. with gr.Column():
  127. with gr.Row():
  128. shared.gradio['prompt_menu'] = gr.Dropdown(choices=get_available_prompts(), value='None', label='Prompt')
  129. ui.create_refresh_button(shared.gradio['prompt_menu'], lambda: None, lambda: {'choices': get_available_prompts()}, 'refresh-button')
  130. with gr.Column():
  131. with gr.Column():
  132. shared.gradio['save_prompt'] = gr.Button('Save prompt')
  133. shared.gradio['status'] = gr.Markdown('Ready')
  134. shared.gradio['prompt_menu'].change(load_prompt, [shared.gradio['prompt_menu']], [shared.gradio['textbox']], show_progress=False)
  135. shared.gradio['save_prompt'].click(save_prompt, [shared.gradio['textbox']], [shared.gradio['status']], show_progress=False)
  136. def download_model_wrapper(repo_id):
  137. try:
  138. downloader = importlib.import_module("download-model")
  139. model = repo_id
  140. branch = "main"
  141. check = False
  142. yield("Cleaning up the model/branch names")
  143. model, branch = downloader.sanitize_model_and_branch_names(model, branch)
  144. yield("Getting the download links from Hugging Face")
  145. links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
  146. yield("Getting the output folder")
  147. output_folder = downloader.get_output_folder(model, branch, is_lora)
  148. if check:
  149. yield("Checking previously downloaded files")
  150. downloader.check_model_files(model, branch, links, sha256, output_folder)
  151. else:
  152. yield("Downloading files")
  153. downloader.download_model_files(model, branch, links, sha256, output_folder, threads=1)
  154. yield("Done!")
  155. except:
  156. yield traceback.format_exc()
  157. def create_model_menus():
  158. with gr.Row():
  159. with gr.Column():
  160. with gr.Row():
  161. shared.gradio['model_menu'] = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
  162. ui.create_refresh_button(shared.gradio['model_menu'], lambda: None, lambda: {'choices': get_available_models()}, 'refresh-button')
  163. with gr.Column():
  164. with gr.Row():
  165. shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
  166. ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': get_available_loras()}, 'refresh-button')
  167. with gr.Row():
  168. with gr.Column():
  169. with gr.Row():
  170. with gr.Column():
  171. shared.gradio['custom_model_menu'] = gr.Textbox(label="Download Custom Model",
  172. info="Enter Hugging Face username/model path e.g: facebook/galactica-125m")
  173. with gr.Column():
  174. shared.gradio['download_button'] = gr.Button("Download", show_progress=True)
  175. shared.gradio['download_status'] = gr.Markdown()
  176. with gr.Column():
  177. pass
  178. shared.gradio['model_menu'].change(load_model_wrapper, shared.gradio['model_menu'], shared.gradio['model_menu'], show_progress=True)
  179. shared.gradio['lora_menu'].change(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['lora_menu'], show_progress=True)
  180. shared.gradio['download_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['download_status'], show_progress=False)
  181. def create_settings_menus(default_preset):
  182. generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', {}, return_dict=True)
  183. for k in ['max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size', 'chat_generation_attempts']:
  184. generate_params[k] = shared.settings[k]
  185. shared.gradio['generate_state'] = gr.State(generate_params)
  186. with gr.Row():
  187. with gr.Column():
  188. with gr.Row():
  189. shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
  190. ui.create_refresh_button(shared.gradio['preset_menu'], lambda: None, lambda: {'choices': get_available_presets()}, 'refresh-button')
  191. with gr.Column():
  192. shared.gradio['seed'] = gr.Number(value=shared.settings['seed'], label='Seed (-1 for random)')
  193. with gr.Row():
  194. with gr.Column():
  195. with gr.Box():
  196. gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
  197. with gr.Row():
  198. with gr.Column():
  199. shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
  200. shared.gradio['top_p'] = gr.Slider(0.0, 1.0, value=generate_params['top_p'], step=0.01, label='top_p')
  201. shared.gradio['top_k'] = gr.Slider(0, 200, value=generate_params['top_k'], step=1, label='top_k')
  202. shared.gradio['typical_p'] = gr.Slider(0.0, 1.0, value=generate_params['typical_p'], step=0.01, label='typical_p')
  203. with gr.Column():
  204. shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'], step=0.01, label='repetition_penalty')
  205. shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'], step=0.01, label='encoder_repetition_penalty')
  206. shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
  207. shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
  208. shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
  209. with gr.Column():
  210. with gr.Box():
  211. gr.Markdown('Contrastive search')
  212. shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
  213. with gr.Box():
  214. gr.Markdown('Beam search (uses a lot of VRAM)')
  215. with gr.Row():
  216. with gr.Column():
  217. shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
  218. with gr.Column():
  219. shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
  220. shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
  221. with gr.Accordion('Soft prompt', open=False):
  222. with gr.Row():
  223. shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
  224. ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda: None, lambda: {'choices': get_available_softprompts()}, 'refresh-button')
  225. gr.Markdown('Upload a soft prompt (.zip format):')
  226. with gr.Row():
  227. shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
  228. shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio[k] for k in ['preset_menu', 'generate_state']], [shared.gradio[k] for k in ['generate_state', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
  229. shared.gradio['softprompts_menu'].change(load_soft_prompt, shared.gradio['softprompts_menu'], shared.gradio['softprompts_menu'], show_progress=True)
  230. shared.gradio['upload_softprompt'].upload(upload_soft_prompt, shared.gradio['upload_softprompt'], shared.gradio['softprompts_menu'])
  231. def set_interface_arguments(interface_mode, extensions, bool_active):
  232. modes = ["default", "notebook", "chat", "cai_chat"]
  233. cmd_list = vars(shared.args)
  234. bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
  235. shared.args.extensions = extensions
  236. for k in modes[1:]:
  237. exec(f"shared.args.{k} = False")
  238. if interface_mode != "default":
  239. exec(f"shared.args.{interface_mode} = True")
  240. for k in bool_list:
  241. exec(f"shared.args.{k} = False")
  242. for k in bool_active:
  243. exec(f"shared.args.{k} = True")
  244. shared.need_restart = True
  245. available_models = get_available_models()
  246. available_presets = get_available_presets()
  247. available_characters = get_available_characters()
  248. available_softprompts = get_available_softprompts()
  249. available_loras = get_available_loras()
  250. # Default extensions
  251. extensions_module.available_extensions = get_available_extensions()
  252. if shared.is_chat():
  253. for extension in shared.settings['chat_default_extensions']:
  254. shared.args.extensions = shared.args.extensions or []
  255. if extension not in shared.args.extensions:
  256. shared.args.extensions.append(extension)
  257. else:
  258. for extension in shared.settings['default_extensions']:
  259. shared.args.extensions = shared.args.extensions or []
  260. if extension not in shared.args.extensions:
  261. shared.args.extensions.append(extension)
  262. # Default model
  263. if shared.args.model is not None:
  264. shared.model_name = shared.args.model
  265. else:
  266. if len(available_models) == 0:
  267. print('No models are available! Please download at least one.')
  268. sys.exit(0)
  269. elif len(available_models) == 1:
  270. i = 0
  271. else:
  272. print('The following models are available:\n')
  273. for i, model in enumerate(available_models):
  274. print(f'{i+1}. {model}')
  275. print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
  276. i = int(input()) - 1
  277. print()
  278. shared.model_name = available_models[i]
  279. shared.model, shared.tokenizer = load_model(shared.model_name)
  280. if shared.args.lora:
  281. add_lora_to_model(shared.args.lora)
  282. # Default UI settings
  283. default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
  284. if shared.lora_name != "None":
  285. default_text = load_prompt(shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')])
  286. else:
  287. default_text = load_prompt(shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')])
  288. title = 'Text generation web UI'
  289. def create_interface():
  290. gen_events = []
  291. if shared.args.extensions is not None and len(shared.args.extensions) > 0:
  292. extensions_module.load_extensions()
  293. with gr.Blocks(css=ui.css if not shared.is_chat() else ui.css + ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
  294. if shared.is_chat():
  295. shared.gradio['Chat input'] = gr.State()
  296. with gr.Tab("Text generation", elem_id="main"):
  297. shared.gradio['display'] = gr.HTML(value=chat_html_wrapper(shared.history['visible'], shared.settings['name1'], shared.settings['name2'], 'cai-chat'))
  298. shared.gradio['textbox'] = gr.Textbox(label='Input')
  299. with gr.Row():
  300. shared.gradio['Generate'] = gr.Button('Generate', elem_id='Generate')
  301. shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
  302. with gr.Row():
  303. shared.gradio['Regenerate'] = gr.Button('Regenerate')
  304. shared.gradio['Continue'] = gr.Button('Continue')
  305. shared.gradio['Impersonate'] = gr.Button('Impersonate')
  306. with gr.Row():
  307. shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
  308. shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
  309. shared.gradio['Remove last'] = gr.Button('Remove last')
  310. shared.gradio['Clear history'] = gr.Button('Clear history')
  311. shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
  312. shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
  313. shared.gradio["Chat mode"] = gr.Radio(choices=["cai-chat", "chat", "instruct"], value="cai-chat", label="Mode")
  314. shared.gradio["Instruction templates"] = gr.Dropdown(choices=get_available_instruction_templates(), label="Instruction template", value="None", visible=False)
  315. with gr.Tab("Character", elem_id="chat-settings"):
  316. with gr.Row():
  317. with gr.Column(scale=8):
  318. shared.gradio['name1'] = gr.Textbox(value=shared.settings['name1'], lines=1, label='Your name')
  319. shared.gradio['name2'] = gr.Textbox(value=shared.settings['name2'], lines=1, label='Character\'s name')
  320. shared.gradio['greeting'] = gr.Textbox(value=shared.settings['greeting'], lines=4, label='Greeting')
  321. shared.gradio['context'] = gr.Textbox(value=shared.settings['context'], lines=4, label='Context')
  322. shared.gradio['end_of_turn'] = gr.Textbox(value=shared.settings["end_of_turn"], lines=1, label='End of turn string')
  323. with gr.Column(scale=1):
  324. shared.gradio['character_picture'] = gr.Image(label='Character picture', type="pil")
  325. shared.gradio['your_picture'] = gr.Image(label='Your picture', type="pil", value=Image.open(Path("cache/pfp_me.png")) if Path("cache/pfp_me.png").exists() else None)
  326. with gr.Row():
  327. shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
  328. ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
  329. with gr.Row():
  330. with gr.Tab('Chat history'):
  331. with gr.Row():
  332. with gr.Column():
  333. gr.Markdown('Upload')
  334. shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
  335. with gr.Column():
  336. gr.Markdown('Download')
  337. shared.gradio['download'] = gr.File()
  338. shared.gradio['download_button'] = gr.Button(value='Click me')
  339. with gr.Tab('Upload character'):
  340. gr.Markdown("# JSON format")
  341. with gr.Row():
  342. with gr.Column():
  343. gr.Markdown('1. Select the JSON file')
  344. shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
  345. with gr.Column():
  346. gr.Markdown('2. Select your character\'s profile picture (optional)')
  347. shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
  348. shared.gradio['Upload character'] = gr.Button(value='Submit')
  349. gr.Markdown("# TavernAI PNG format")
  350. shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
  351. with gr.Tab("Parameters", elem_id="parameters"):
  352. with gr.Box():
  353. gr.Markdown("Chat parameters")
  354. with gr.Row():
  355. with gr.Column():
  356. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  357. shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
  358. with gr.Column():
  359. shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
  360. shared.gradio['stop_at_newline'] = gr.Checkbox(value=shared.settings['stop_at_newline'], label='Stop generating at new line character?')
  361. create_settings_menus(default_preset)
  362. shared.input_params = [shared.gradio[k] for k in ['Chat input', 'generate_state', 'name1', 'name2', 'context', 'Chat mode', 'end_of_turn']]
  363. clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
  364. reload_inputs = [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']]
  365. gen_events.append(shared.gradio['Generate'].click(
  366. lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
  367. chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
  368. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  369. )
  370. gen_events.append(shared.gradio['textbox'].submit(
  371. lambda x: (x, ''), shared.gradio['textbox'], [shared.gradio['Chat input'], shared.gradio['textbox']], show_progress=False).then(
  372. chat.cai_chatbot_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
  373. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  374. )
  375. gen_events.append(shared.gradio['Regenerate'].click(
  376. chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
  377. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  378. )
  379. gen_events.append(shared.gradio['Continue'].click(
  380. chat.continue_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream).then(
  381. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  382. )
  383. shared.gradio['Replace last reply'].click(
  384. chat.replace_last_reply, [shared.gradio[k] for k in ['textbox', 'name1', 'name2', 'Chat mode']], shared.gradio['display'], show_progress=shared.args.no_stream).then(
  385. lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False).then(
  386. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  387. shared.gradio['Clear history-confirm'].click(
  388. lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr).then(
  389. chat.clear_chat_log, [shared.gradio[k] for k in ['name1', 'name2', 'greeting', 'Chat mode']], shared.gradio['display']).then(
  390. lambda: chat.save_history(timestamp=False), None, None, show_progress=False)
  391. shared.gradio['Stop'].click(
  392. stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None).then(
  393. chat.redraw_html, reload_inputs, shared.gradio['display'])
  394. shared.gradio['Chat mode'].change(
  395. lambda x: gr.update(visible=x == 'instruct'), shared.gradio['Chat mode'], shared.gradio['Instruction templates']).then(
  396. chat.redraw_html, reload_inputs, shared.gradio['display'])
  397. shared.gradio['Instruction templates'].change(
  398. lambda character, name1, name2, mode: chat.load_character(character, name1, name2, mode), [shared.gradio[k] for k in ['Instruction templates', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]).then(
  399. chat.redraw_html, reload_inputs, shared.gradio['display'])
  400. shared.gradio['upload_chat_history'].upload(
  401. chat.load_history, [shared.gradio[k] for k in ['upload_chat_history', 'name1', 'name2']], None).then(
  402. chat.redraw_html, reload_inputs, shared.gradio['display'])
  403. gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
  404. shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, None, shared.gradio['textbox'], show_progress=shared.args.no_stream)
  405. shared.gradio['Clear history'].click(lambda: [gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
  406. shared.gradio['Clear history-cancel'].click(lambda: [gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
  407. shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio[k] for k in ['name1', 'name2', 'Chat mode']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
  408. shared.gradio['download_button'].click(chat.save_history, inputs=None, outputs=[shared.gradio['download']])
  409. shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
  410. shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'Chat mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
  411. shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
  412. shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'Chat mode']], shared.gradio['display'])
  413. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
  414. shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
  415. shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
  416. elif shared.args.notebook:
  417. with gr.Tab("Text generation", elem_id="main"):
  418. with gr.Row():
  419. with gr.Column(scale=4):
  420. with gr.Tab('Raw'):
  421. shared.gradio['textbox'] = gr.Textbox(value=default_text, elem_id="textbox", lines=27)
  422. with gr.Tab('Markdown'):
  423. shared.gradio['markdown'] = gr.Markdown()
  424. with gr.Tab('HTML'):
  425. shared.gradio['html'] = gr.HTML()
  426. with gr.Row():
  427. with gr.Column():
  428. with gr.Row():
  429. shared.gradio['Generate'] = gr.Button('Generate')
  430. shared.gradio['Stop'] = gr.Button('Stop')
  431. with gr.Column():
  432. pass
  433. with gr.Column(scale=1):
  434. gr.HTML('<div style="padding-bottom: 13px"></div>')
  435. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  436. create_prompt_menus()
  437. with gr.Tab("Parameters", elem_id="parameters"):
  438. create_settings_menus(default_preset)
  439. shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
  440. output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
  441. gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  442. gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  443. shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
  444. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
  445. else:
  446. with gr.Tab("Text generation", elem_id="main"):
  447. with gr.Row():
  448. with gr.Column():
  449. shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=21, label='Input')
  450. shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  451. shared.gradio['Generate'] = gr.Button('Generate')
  452. with gr.Row():
  453. with gr.Column():
  454. shared.gradio['Continue'] = gr.Button('Continue')
  455. with gr.Column():
  456. shared.gradio['Stop'] = gr.Button('Stop')
  457. create_prompt_menus()
  458. with gr.Column():
  459. with gr.Tab('Raw'):
  460. shared.gradio['output_textbox'] = gr.Textbox(lines=27, label='Output')
  461. with gr.Tab('Markdown'):
  462. shared.gradio['markdown'] = gr.Markdown()
  463. with gr.Tab('HTML'):
  464. shared.gradio['html'] = gr.HTML()
  465. with gr.Tab("Parameters", elem_id="parameters"):
  466. create_settings_menus(default_preset)
  467. shared.input_params = [shared.gradio[k] for k in ['textbox', 'generate_state']]
  468. output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
  469. gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  470. gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
  471. gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
  472. shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
  473. shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
  474. with gr.Tab("Model", elem_id="model-tab"):
  475. create_model_menus()
  476. with gr.Tab("Training", elem_id="training-tab"):
  477. training.create_train_interface()
  478. with gr.Tab("Interface mode", elem_id="interface-mode"):
  479. modes = ["default", "notebook", "chat", "cai_chat"]
  480. current_mode = "default"
  481. for mode in modes[1:]:
  482. if eval(f"shared.args.{mode}"):
  483. current_mode = mode
  484. break
  485. cmd_list = vars(shared.args)
  486. bool_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
  487. bool_active = [k for k in bool_list if vars(shared.args)[k]]
  488. gr.Markdown("*Experimental*")
  489. shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
  490. shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
  491. shared.gradio['bool_menu'] = gr.CheckboxGroup(choices=bool_list, value=bool_active, label="Boolean command-line flags")
  492. shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface")
  493. # Reset interface event
  494. shared.gradio['reset_interface'].click(
  495. set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
  496. lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
  497. if shared.args.extensions is not None:
  498. extensions_module.create_extensions_block()
  def change_dict_value(d, key, value):
      """Store *value* under *key* in *d* and return that same dict object.

      Wired up below as the updater for the 'generate_state' component: each
      parameter widget's change/release event passes the current state plus
      the widget value, and the mutated state is returned as the output.
      """
      target = d
      target[key] = value
      return target
  502. for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'max_new_tokens', 'seed', 'stop_at_newline', 'chat_prompt_size_slider', 'chat_generation_attempts']:
  503. if k not in shared.gradio:
  504. continue
  505. if type(shared.gradio[k]) in [gr.Checkbox, gr.Number]:
  506. shared.gradio[k].change(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
  507. else:
  508. shared.gradio[k].release(lambda state, value, copy=k: change_dict_value(state, copy, value), inputs=[shared.gradio['generate_state'], shared.gradio[k]], outputs=shared.gradio['generate_state'])
  509. if not shared.is_chat():
  510. api.create_apis()
  511. # Authentication
  512. auth = None
  513. if shared.args.gradio_auth_path is not None:
  514. gradio_auth_creds = []
  515. with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
  516. for line in file.readlines():
  517. gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
  518. auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
  519. # Launch the interface
  520. shared.gradio['interface'].queue()
  521. if shared.args.listen:
  522. shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
  523. else:
  524. shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
  # Build and launch the web UI. launch() is called with prevent_thread_lock=True,
  # so it returns immediately; this loop keeps the main thread alive and rebuilds
  # the interface whenever shared.need_restart is raised (presumably by
  # set_interface_arguments via the "Apply and restart the interface" button --
  # confirm against its definition).
  create_interface()
  while True:
      time.sleep(0.5)
      if not shared.need_restart:
          continue
      # Clear the flag, tear down the running interface, then build a new one.
      shared.need_restart = False
      shared.gradio['interface'].close()
      create_interface()