chat.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import base64
  2. import copy
  3. import io
  4. import json
  5. import re
  6. from datetime import datetime
  7. from pathlib import Path
  8. from PIL import Image
  9. import modules.extensions as extensions_module
  10. import modules.shared as shared
  11. from modules.extensions import apply_extensions
  12. from modules.html_generator import fix_newlines, generate_chat_html
  13. from modules.text_generation import (encode, generate_reply,
  14. get_max_prompt_length)
  15. def generate_chat_output(history, name1, name2, character):
  16. if shared.args.cai_chat:
  17. return generate_chat_html(history, name1, name2, character)
  18. else:
  19. return history
  20. def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
  21. user_input = fix_newlines(user_input)
  22. rows = [f"{context.strip()}\n"]
  23. if shared.soft_prompt:
  24. chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
  25. max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)
  26. i = len(shared.history['internal'])-1
  27. while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
  28. rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
  29. prev_user_input = shared.history['internal'][i][0]
  30. if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>':
  31. rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
  32. i -= 1
  33. if not impersonate:
  34. if len(user_input) > 0:
  35. rows.append(f"{name1}: {user_input}\n")
  36. rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
  37. limit = 3
  38. else:
  39. rows.append(f"{name1}:")
  40. limit = 2
  41. while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
  42. rows.pop(1)
  43. prompt = ''.join(rows)
  44. return prompt
  45. def extract_message_from_reply(reply, name1, name2, check):
  46. next_character_found = False
  47. if check:
  48. lines = reply.split('\n')
  49. reply = lines[0].strip()
  50. if len(lines) > 1:
  51. next_character_found = True
  52. else:
  53. for string in [f"\n{name1}:", f"\n{name2}:"]:
  54. idx = reply.find(string)
  55. if idx != -1:
  56. reply = reply[:idx]
  57. next_character_found = True
  58. # If something like "\nYo" is generated just before "\nYou:"
  59. # is completed, trim it
  60. if not next_character_found:
  61. for string in [f"\n{name1}:", f"\n{name2}:"]:
  62. for j in range(len(string)-1, 0, -1):
  63. if reply[-j:] == string[:j]:
  64. reply = reply[:-j]
  65. break
  66. reply = fix_newlines(reply)
  67. return reply, next_character_found
  68. def stop_everything_event():
  69. shared.stop_everything = True
  70. def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
  71. shared.stop_everything = False
  72. just_started = True
  73. eos_token = '\n' if check else None
  74. name1_original = name1
  75. if 'pygmalion' in shared.model_name.lower():
  76. name1 = "You"
  77. # Check if any extension wants to hijack this function call
  78. visible_text = None
  79. custom_generate_chat_prompt = None
  80. for extension, _ in extensions_module.iterator():
  81. if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
  82. extension.input_hijack['state'] = False
  83. text, visible_text = extension.input_hijack['value']
  84. if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
  85. custom_generate_chat_prompt = extension.custom_generate_chat_prompt
  86. if visible_text is None:
  87. visible_text = text
  88. if shared.args.chat:
  89. visible_text = visible_text.replace('\n', '<br>')
  90. text = apply_extensions(text, "input")
  91. if custom_generate_chat_prompt is None:
  92. prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
  93. else:
  94. prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
  95. # Yield *Is typing...*
  96. if not regenerate:
  97. yield shared.history['visible']+[[visible_text, shared.processing_message]]
  98. # Generate
  99. cumulative_reply = ''
  100. for i in range(chat_generation_attempts):
  101. for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
  102. reply = cumulative_reply + reply
  103. # Extracting the reply
  104. reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
  105. visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
  106. visible_reply = apply_extensions(visible_reply, "output")
  107. if shared.args.chat:
  108. visible_reply = visible_reply.replace('\n', '<br>')
  109. # We need this global variable to handle the Stop event,
  110. # otherwise gradio gets confused
  111. if shared.stop_everything:
  112. return shared.history['visible']
  113. if just_started:
  114. just_started = False
  115. shared.history['internal'].append(['', ''])
  116. shared.history['visible'].append(['', ''])
  117. shared.history['internal'][-1] = [text, reply]
  118. shared.history['visible'][-1] = [visible_text, visible_reply]
  119. if not shared.args.no_stream:
  120. yield shared.history['visible']
  121. if next_character_found:
  122. break
  123. cumulative_reply = reply
  124. yield shared.history['visible']
  125. def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
  126. eos_token = '\n' if check else None
  127. if 'pygmalion' in shared.model_name.lower():
  128. name1 = "You"
  129. prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
  130. # Yield *Is typing...*
  131. yield shared.processing_message
  132. cumulative_reply = ''
  133. for i in range(chat_generation_attempts):
  134. for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
  135. reply = cumulative_reply + reply
  136. reply, next_character_found = extract_message_from_reply(reply, name1, name2, check)
  137. yield reply
  138. if next_character_found:
  139. break
  140. cumulative_reply = reply
  141. yield reply
  142. def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
  143. for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
  144. yield generate_chat_html(_history, name1, name2, shared.character)
  145. def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
  146. if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
  147. yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  148. else:
  149. last_visible = shared.history['visible'].pop()
  150. last_internal = shared.history['internal'].pop()
  151. # Yield '*Is typing...*'
  152. yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
  153. for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
  154. if shared.args.cai_chat:
  155. shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
  156. else:
  157. shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
  158. yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  159. def remove_last_message(name1, name2):
  160. if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
  161. last = shared.history['visible'].pop()
  162. shared.history['internal'].pop()
  163. else:
  164. last = ['', '']
  165. if shared.args.cai_chat:
  166. return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
  167. else:
  168. return shared.history['visible'], last[0]
  169. def send_last_reply_to_input():
  170. if len(shared.history['internal']) > 0:
  171. return shared.history['internal'][-1][1]
  172. else:
  173. return ''
  174. def replace_last_reply(text, name1, name2):
  175. if len(shared.history['visible']) > 0:
  176. if shared.args.cai_chat:
  177. shared.history['visible'][-1][1] = text
  178. else:
  179. shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
  180. shared.history['internal'][-1][1] = apply_extensions(text, "input")
  181. return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  182. def clear_html():
  183. return generate_chat_html([], "", "", shared.character)
  184. def clear_chat_log(name1, name2):
  185. if shared.character != 'None':
  186. found = False
  187. for i in range(len(shared.history['internal'])):
  188. if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
  189. shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
  190. shared.history['internal'] = [shared.history['internal'][i]]
  191. found = True
  192. break
  193. if not found:
  194. shared.history['visible'] = []
  195. shared.history['internal'] = []
  196. else:
  197. shared.history['internal'] = []
  198. shared.history['visible'] = []
  199. return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  200. def redraw_html(name1, name2):
  201. return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  202. def tokenize_dialogue(dialogue, name1, name2):
  203. _history = []
  204. dialogue = re.sub('<START>', '', dialogue)
  205. dialogue = re.sub('<start>', '', dialogue)
  206. dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
  207. dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
  208. idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
  209. if len(idx) == 0:
  210. return _history
  211. messages = []
  212. for i in range(len(idx)-1):
  213. messages.append(dialogue[idx[i]:idx[i+1]].strip())
  214. messages.append(dialogue[idx[-1]:].strip())
  215. entry = ['', '']
  216. for i in messages:
  217. if i.startswith(f'{name1}:'):
  218. entry[0] = i[len(f'{name1}:'):].strip()
  219. elif i.startswith(f'{name2}:'):
  220. entry[1] = i[len(f'{name2}:'):].strip()
  221. if not (len(entry[0]) == 0 and len(entry[1]) == 0):
  222. _history.append(entry)
  223. entry = ['', '']
  224. print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
  225. for row in _history:
  226. for column in row:
  227. print("\n")
  228. for line in column.strip().split('\n'):
  229. print("| "+line+"\n")
  230. print("|\n")
  231. print("------------------------------")
  232. return _history
  233. def save_history(timestamp=True):
  234. prefix = '' if shared.character == 'None' else f"{shared.character}_"
  235. if timestamp:
  236. fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
  237. else:
  238. fname = f"{prefix}persistent.json"
  239. if not Path('logs').exists():
  240. Path('logs').mkdir()
  241. with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
  242. f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
  243. return Path(f'logs/{fname}')
  244. def load_history(file, name1, name2):
  245. file = file.decode('utf-8')
  246. try:
  247. j = json.loads(file)
  248. if 'data' in j:
  249. shared.history['internal'] = j['data']
  250. if 'data_visible' in j:
  251. shared.history['visible'] = j['data_visible']
  252. else:
  253. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  254. # Compatibility with Pygmalion AI's official web UI
  255. elif 'chat' in j:
  256. shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
  257. if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
  258. shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
  259. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  260. shared.history['visible'][0][0] = ''
  261. else:
  262. shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
  263. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  264. except:
  265. shared.history['internal'] = tokenize_dialogue(file, name1, name2)
  266. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  267. def load_default_history(name1, name2):
  268. if Path('logs/persistent.json').exists():
  269. load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
  270. else:
  271. shared.history['internal'] = []
  272. shared.history['visible'] = []
  273. def load_character(_character, name1, name2):
  274. context = ""
  275. shared.history['internal'] = []
  276. shared.history['visible'] = []
  277. if _character != 'None':
  278. shared.character = _character
  279. data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read())
  280. name2 = data['char_name']
  281. if 'char_persona' in data and data['char_persona'] != '':
  282. context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
  283. if 'world_scenario' in data and data['world_scenario'] != '':
  284. context += f"Scenario: {data['world_scenario']}\n"
  285. context = f"{context.strip()}\n<START>\n"
  286. if 'example_dialogue' in data and data['example_dialogue'] != '':
  287. data['example_dialogue'] = data['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', name2)
  288. data['example_dialogue'] = data['example_dialogue'].replace('<USER>', name1).replace('<BOT>', name2)
  289. context += f"{data['example_dialogue'].strip()}\n"
  290. if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
  291. shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
  292. shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
  293. else:
  294. shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
  295. shared.history['visible'] += [['', "Hello there!"]]
  296. else:
  297. shared.character = None
  298. context = shared.settings['context_pygmalion']
  299. name2 = shared.settings['name2_pygmalion']
  300. if Path(f'logs/{shared.character}_persistent.json').exists():
  301. load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
  302. if shared.args.cai_chat:
  303. return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  304. else:
  305. return name2, context, shared.history['visible']
  306. def upload_character(json_file, img, tavern=False):
  307. json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
  308. data = json.loads(json_file)
  309. outfile_name = data["char_name"]
  310. i = 1
  311. while Path(f'characters/{outfile_name}.json').exists():
  312. outfile_name = f'{data["char_name"]}_{i:03d}'
  313. i += 1
  314. if tavern:
  315. outfile_name = f'TavernAI-{outfile_name}'
  316. with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
  317. f.write(json_file)
  318. if img is not None:
  319. img = Image.open(io.BytesIO(img))
  320. img.save(Path(f'characters/{outfile_name}.png'))
  321. print(f'New character saved to "characters/{outfile_name}.json".')
  322. return outfile_name
  323. def upload_tavern_character(img, name1, name2):
  324. _img = Image.open(io.BytesIO(img))
  325. _img.getexif()
  326. decoded_string = base64.b64decode(_img.info['chara'])
  327. _json = json.loads(decoded_string)
  328. _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
  329. return upload_character(json.dumps(_json), img, tavern=True)
  330. def upload_your_profile_picture(img):
  331. img = Image.open(io.BytesIO(img))
  332. img.save(Path('img_me.png'))
  333. print('Profile picture saved to "img_me.png"')