server.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. import gc
  2. import io
  3. import json
  4. import re
  5. import sys
  6. import time
  7. import zipfile
  8. from pathlib import Path
  9. import gradio as gr
  10. import torch
  11. import modules.chat as chat
  12. import modules.extensions as extensions_module
  13. import modules.shared as shared
  14. import modules.ui as ui
  15. from modules.extensions import extension_state, load_extensions, update_extensions_parameters
  16. from modules.html_generator import generate_chat_html
  17. from modules.models import load_model, load_soft_prompt
  18. from modules.text_generation import generate_reply
  19. if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream:
  20. print("Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n")
  21. # Loading custom settings
  22. if shared.args.settings is not None and Path(shared.args.settings).exists():
  23. new_settings = json.loads(open(Path(shared.args.settings), 'r').read())
  24. for item in new_settings:
  25. shared.settings[item] = new_settings[item]
  26. def get_available_models():
  27. return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
  28. def get_available_presets():
  29. return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
  30. def get_available_characters():
  31. return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('characters').glob('*.json'))), key=str.lower)
  32. def get_available_extensions():
  33. return sorted(set(map(lambda x : x.parts[1], Path('extensions').glob('*/script.py'))), key=str.lower)
  34. def get_available_softprompts():
  35. return ["None"] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
  36. def load_model_wrapper(selected_model):
  37. if selected_model != shared.model_name:
  38. shared.model_name = selected_model
  39. shared.model = shared.tokenizer = None
  40. if not shared.args.cpu:
  41. gc.collect()
  42. torch.cuda.empty_cache()
  43. shared.model, shared.tokenizer = load_model(shared.model_name)
  44. return selected_model
  45. def load_preset_values(preset_menu, return_dict=False):
  46. generate_params = {
  47. 'do_sample': True,
  48. 'temperature': 1,
  49. 'top_p': 1,
  50. 'typical_p': 1,
  51. 'repetition_penalty': 1,
  52. 'top_k': 50,
  53. 'num_beams': 1,
  54. 'penalty_alpha': 0,
  55. 'min_length': 0,
  56. 'length_penalty': 1,
  57. 'no_repeat_ngram_size': 0,
  58. 'early_stopping': False,
  59. }
  60. with open(Path(f'presets/{preset_menu}.txt'), 'r') as infile:
  61. preset = infile.read()
  62. for i in preset.splitlines():
  63. i = i.rstrip(',').strip().split('=')
  64. if len(i) == 2 and i[0].strip() != 'tokens':
  65. generate_params[i[0].strip()] = eval(i[1].strip())
  66. generate_params['temperature'] = min(1.99, generate_params['temperature'])
  67. if return_dict:
  68. return generate_params
  69. else:
  70. return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
  71. def upload_soft_prompt(file):
  72. with zipfile.ZipFile(io.BytesIO(file)) as zf:
  73. zf.extract('meta.json')
  74. j = json.loads(open('meta.json', 'r').read())
  75. name = j['name']
  76. Path('meta.json').unlink()
  77. with open(Path(f'softprompts/{name}.zip'), 'wb') as f:
  78. f.write(file)
  79. return name
  80. def create_extensions_block():
  81. extensions_ui_elements = []
  82. default_values = []
  83. if not (shared.args.chat or shared.args.cai_chat):
  84. gr.Markdown('## Extensions parameters')
  85. for ext in sorted(extension_state, key=lambda x : extension_state[x][1]):
  86. if extension_state[ext][0] == True:
  87. params = extensions_module.get_params(ext)
  88. for param in params:
  89. _id = f"{ext}-{param}"
  90. default_value = shared.settings[_id] if _id in shared.settings else params[param]
  91. default_values.append(default_value)
  92. if type(params[param]) == str:
  93. extensions_ui_elements.append(gr.Textbox(value=default_value, label=f"{ext}-{param}"))
  94. elif type(params[param]) in [int, float]:
  95. extensions_ui_elements.append(gr.Number(value=default_value, label=f"{ext}-{param}"))
  96. elif type(params[param]) == bool:
  97. extensions_ui_elements.append(gr.Checkbox(value=default_value, label=f"{ext}-{param}"))
  98. update_extensions_parameters(*default_values)
  99. btn_extensions = gr.Button("Apply")
  100. btn_extensions.click(update_extensions_parameters, [*extensions_ui_elements], [])
  101. def create_settings_menus():
  102. generate_params = load_preset_values(shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', return_dict=True)
  103. with gr.Row():
  104. with gr.Column():
  105. with gr.Row():
  106. model_menu = gr.Dropdown(choices=available_models, value=shared.model_name, label='Model')
  107. ui.create_refresh_button(model_menu, lambda : None, lambda : {"choices": get_available_models()}, "refresh-button")
  108. with gr.Column():
  109. with gr.Row():
  110. preset_menu = gr.Dropdown(choices=available_presets, value=shared.settings[f'preset{suffix}'] if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
  111. ui.create_refresh_button(preset_menu, lambda : None, lambda : {"choices": get_available_presets()}, "refresh-button")
  112. with gr.Accordion("Custom generation parameters", open=False, elem_id="accordion"):
  113. with gr.Row():
  114. do_sample = gr.Checkbox(value=generate_params['do_sample'], label="do_sample")
  115. temperature = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label="temperature")
  116. with gr.Row():
  117. top_k = gr.Slider(0,200,value=generate_params['top_k'],step=1,label="top_k")
  118. top_p = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label="top_p")
  119. with gr.Row():
  120. repetition_penalty = gr.Slider(1.0,4.99,value=generate_params['repetition_penalty'],step=0.01,label="repetition_penalty")
  121. no_repeat_ngram_size = gr.Slider(0, 20, step=1, value=generate_params["no_repeat_ngram_size"], label="no_repeat_ngram_size")
  122. with gr.Row():
  123. typical_p = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label="typical_p")
  124. min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if shared.args.no_stream else 0, label="min_length", interactive=shared.args.no_stream)
  125. gr.Markdown("Contrastive search:")
  126. penalty_alpha = gr.Slider(0, 5, value=generate_params["penalty_alpha"], label="penalty_alpha")
  127. gr.Markdown("Beam search (uses a lot of VRAM):")
  128. with gr.Row():
  129. num_beams = gr.Slider(1, 20, step=1, value=generate_params["num_beams"], label="num_beams")
  130. length_penalty = gr.Slider(-5, 5, value=generate_params["length_penalty"], label="length_penalty")
  131. early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
  132. with gr.Accordion("Soft prompt", open=False, elem_id="accordion"):
  133. with gr.Row():
  134. softprompts_menu = gr.Dropdown(choices=available_softprompts, value="None", label='Soft prompt')
  135. ui.create_refresh_button(softprompts_menu, lambda : None, lambda : {"choices": get_available_softprompts()}, "refresh-button")
  136. gr.Markdown('Upload a soft prompt (.zip format):')
  137. with gr.Row():
  138. upload_softprompt = gr.File(type='binary', file_types=[".zip"])
  139. model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
  140. preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
  141. softprompts_menu.change(load_soft_prompt, [softprompts_menu], [softprompts_menu], show_progress=True)
  142. upload_softprompt.upload(upload_soft_prompt, [upload_softprompt], [softprompts_menu])
  143. return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping
  144. available_models = get_available_models()
  145. available_presets = get_available_presets()
  146. available_characters = get_available_characters()
  147. available_softprompts = get_available_softprompts()
  148. extensions_module.available_extensions = get_available_extensions()
  149. if shared.args.extensions is not None:
  150. load_extensions()
  151. # Choosing the default model
  152. if shared.args.model is not None:
  153. shared.model_name = shared.args.model
  154. else:
  155. if len(available_models) == 0:
  156. print("No models are available! Please download at least one.")
  157. sys.exit(0)
  158. elif len(available_models) == 1:
  159. i = 0
  160. else:
  161. print("The following models are available:\n")
  162. for i, model in enumerate(available_models):
  163. print(f"{i+1}. {model}")
  164. print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
  165. i = int(input())-1
  166. print()
  167. shared.model_name = available_models[i]
  168. shared.model, shared.tokenizer = load_model(shared.model_name)
  169. # UI settings
  170. buttons = {}
  171. gen_events = []
  172. suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
  173. description = f"\n\n# Text generation lab\nGenerate text using Large Language Models.\n"
  174. if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
  175. default_text = shared.settings['prompt_gpt4chan']
  176. elif re.match('(rosey|chip|joi)_.*_instruct.*', shared.model_name.lower()) is not None:
  177. default_text = 'User: \n'
  178. else:
  179. default_text = shared.settings['prompt']
  180. if shared.args.chat or shared.args.cai_chat:
  181. if Path(f'logs/persistent.json').exists():
  182. chat.load_history(open(Path(f'logs/persistent.json'), 'rb').read(), shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'])
  183. with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False) as interface:
  184. if shared.args.cai_chat:
  185. display = gr.HTML(value=generate_chat_html(chat.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], chat.character))
  186. else:
  187. display = gr.Chatbot(value=chat.history['visible'])
  188. textbox = gr.Textbox(label='Input')
  189. with gr.Row():
  190. buttons["Stop"] = gr.Button("Stop")
  191. buttons["Generate"] = gr.Button("Generate")
  192. buttons["Regenerate"] = gr.Button("Regenerate")
  193. with gr.Row():
  194. buttons["Impersonate"] = gr.Button("Impersonate")
  195. buttons["Remove last"] = gr.Button("Remove last")
  196. buttons["Clear history"] = gr.Button("Clear history")
  197. with gr.Row():
  198. buttons["Send last reply to input"] = gr.Button("Send last reply to input")
  199. buttons["Replace last reply"] = gr.Button("Replace last reply")
  200. if shared.args.picture:
  201. with gr.Row():
  202. picture_select = gr.Image(label="Send a picture", type='pil')
  203. with gr.Tab("Chat settings"):
  204. name1 = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
  205. name2 = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
  206. context = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=2, label='Context')
  207. with gr.Row():
  208. character_menu = gr.Dropdown(choices=available_characters, value="None", label='Character')
  209. ui.create_refresh_button(character_menu, lambda : None, lambda : {"choices": get_available_characters()}, "refresh-button")
  210. with gr.Row():
  211. check = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
  212. with gr.Row():
  213. with gr.Tab('Chat history'):
  214. with gr.Row():
  215. with gr.Column():
  216. gr.Markdown('Upload')
  217. upload_chat_history = gr.File(type='binary', file_types=[".json", ".txt"])
  218. with gr.Column():
  219. gr.Markdown('Download')
  220. download = gr.File()
  221. buttons["Download"] = gr.Button(value="Click me")
  222. with gr.Tab('Upload character'):
  223. with gr.Row():
  224. with gr.Column():
  225. gr.Markdown('1. Select the JSON file')
  226. upload_char = gr.File(type='binary', file_types=[".json"])
  227. with gr.Column():
  228. gr.Markdown('2. Select your character\'s profile picture (optional)')
  229. upload_img = gr.File(type='binary', file_types=["image"])
  230. buttons["Upload character"] = gr.Button(value="Submit")
  231. with gr.Tab('Upload your profile picture'):
  232. upload_img_me = gr.File(type='binary', file_types=["image"])
  233. with gr.Tab('Upload TavernAI Character Card'):
  234. upload_img_tavern = gr.File(type='binary', file_types=["image"])
  235. with gr.Tab("Generation settings"):
  236. with gr.Row():
  237. with gr.Column():
  238. max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  239. with gr.Column():
  240. chat_prompt_size_slider = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
  241. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  242. if shared.args.extensions is not None:
  243. with gr.Tab("Extensions"):
  244. create_extensions_block()
  245. input_params = [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size_slider]
  246. if shared.args.picture:
  247. input_params.append(picture_select)
  248. function_call = "chat.cai_chatbot_wrapper" if shared.args.cai_chat else "chat.chatbot_wrapper"
  249. gen_events.append(buttons["Generate"].click(eval(function_call), input_params, display, show_progress=shared.args.no_stream, api_name="textgen"))
  250. gen_events.append(textbox.submit(eval(function_call), input_params, display, show_progress=shared.args.no_stream))
  251. if shared.args.picture:
  252. picture_select.upload(eval(function_call), input_params, display, show_progress=shared.args.no_stream)
  253. gen_events.append(buttons["Regenerate"].click(chat.regenerate_wrapper, input_params, display, show_progress=shared.args.no_stream))
  254. gen_events.append(buttons["Impersonate"].click(chat.impersonate_wrapper, input_params, textbox, show_progress=shared.args.no_stream))
  255. buttons["Stop"].click(chat.stop_everything_event, [], [], cancels=gen_events)
  256. buttons["Send last reply to input"].click(chat.send_last_reply_to_input, [], textbox, show_progress=shared.args.no_stream)
  257. buttons["Replace last reply"].click(chat.replace_last_reply, [textbox, name1, name2], display, show_progress=shared.args.no_stream)
  258. buttons["Clear history"].click(chat.clear_chat_log, [character_menu, name1, name2], display)
  259. buttons["Remove last"].click(chat.remove_last_message, [name1, name2], [display, textbox], show_progress=False)
  260. buttons["Download"].click(chat.save_history, inputs=[], outputs=[download])
  261. buttons["Upload character"].click(chat.upload_character, [upload_char, upload_img], [character_menu])
  262. # Clearing stuff and saving the history
  263. for i in ["Generate", "Regenerate", "Replace last reply"]:
  264. buttons[i].click(lambda x: "", textbox, textbox, show_progress=False)
  265. buttons[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  266. buttons["Clear history"].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  267. textbox.submit(lambda x: "", textbox, textbox, show_progress=False)
  268. textbox.submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
  269. character_menu.change(chat.load_character, [character_menu, name1, name2], [name2, context, display])
  270. upload_chat_history.upload(chat.load_history, [upload_chat_history, name1, name2], [])
  271. upload_img_tavern.upload(chat.upload_tavern_character, [upload_img_tavern, name1, name2], [character_menu])
  272. upload_img_me.upload(chat.upload_your_profile_picture, [upload_img_me], [])
  273. if shared.args.picture:
  274. picture_select.upload(lambda : None, [], [picture_select], show_progress=False)
  275. if shared.args.cai_chat:
  276. upload_chat_history.upload(chat.redraw_html, [name1, name2], [display])
  277. upload_img_me.upload(chat.redraw_html, [name1, name2], [display])
  278. else:
  279. upload_chat_history.upload(lambda : chat.history['visible'], [], [display])
  280. upload_img_me.upload(lambda : chat.history['visible'], [], [display])
  281. elif shared.args.notebook:
  282. with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
  283. gr.Markdown(description)
  284. with gr.Tab('Raw'):
  285. textbox = gr.Textbox(value=default_text, lines=23)
  286. with gr.Tab('Markdown'):
  287. markdown = gr.Markdown()
  288. with gr.Tab('HTML'):
  289. html = gr.HTML()
  290. buttons["Generate"] = gr.Button("Generate")
  291. buttons["Stop"] = gr.Button("Stop")
  292. max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  293. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  294. if shared.args.extensions is not None:
  295. create_extensions_block()
  296. gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
  297. gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [textbox, markdown, html], show_progress=shared.args.no_stream))
  298. buttons["Stop"].click(None, None, None, cancels=gen_events)
  299. else:
  300. with gr.Blocks(css=ui.css, analytics_enabled=False) as interface:
  301. gr.Markdown(description)
  302. with gr.Row():
  303. with gr.Column():
  304. textbox = gr.Textbox(value=default_text, lines=15, label='Input')
  305. max_new_tokens = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
  306. buttons["Generate"] = gr.Button("Generate")
  307. with gr.Row():
  308. with gr.Column():
  309. buttons["Continue"] = gr.Button("Continue")
  310. with gr.Column():
  311. buttons["Stop"] = gr.Button("Stop")
  312. preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping = create_settings_menus()
  313. if shared.args.extensions is not None:
  314. create_extensions_block()
  315. with gr.Column():
  316. with gr.Tab('Raw'):
  317. output_textbox = gr.Textbox(lines=15, label='Output')
  318. with gr.Tab('Markdown'):
  319. markdown = gr.Markdown()
  320. with gr.Tab('HTML'):
  321. html = gr.HTML()
  322. gen_events.append(buttons["Generate"].click(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream, api_name="textgen"))
  323. gen_events.append(textbox.submit(generate_reply, [textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
  324. gen_events.append(buttons["Continue"].click(generate_reply, [output_textbox, max_new_tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping], [output_textbox, markdown, html], show_progress=shared.args.no_stream))
  325. buttons["Stop"].click(None, None, None, cancels=gen_events)
  326. interface.queue()
  327. if shared.args.listen:
  328. interface.launch(prevent_thread_lock=True, share=shared.args.share, server_name="0.0.0.0", server_port=shared.args.listen_port)
  329. else:
  330. interface.launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port)
  331. # I think that I will need this later
  332. while True:
  333. time.sleep(0.5)