chat.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418
  1. import base64
  2. import copy
  3. import io
  4. import json
  5. import re
  6. from datetime import datetime
  7. from pathlib import Path
  8. import yaml
  9. from PIL import Image
  10. import modules.extensions as extensions_module
  11. import modules.shared as shared
  12. from modules.extensions import apply_extensions
  13. from modules.html_generator import fix_newlines, generate_chat_html
  14. from modules.text_generation import (encode, generate_reply,
  15. get_max_prompt_length)
  16. def generate_chat_output(history, name1, name2, character):
  17. if shared.args.cai_chat:
  18. return generate_chat_html(history, name1, name2, character)
  19. else:
  20. return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False, also_return_rows=False):
    """Build the text prompt fed to the model for the next chat turn.

    The prompt is ``context`` followed by as much of
    ``shared.history['internal']`` as fits the token budget, then either the
    new user message plus the bot prefix (normal turn) or a dangling
    ``"{name1}:"`` line (impersonate mode, where the model writes the user's
    next line).

    Args:
        user_input: The new user message; normalized with ``fix_newlines``.
        max_new_tokens: Generation budget, also passed to
            ``get_max_prompt_length`` and ``encode``.
        name1: Speaker label for the user.
        name2: Speaker label for the bot/character.
        context: Text placed at the top of the prompt.
        chat_prompt_size: Upper bound on the prompt length.
        impersonate: If True, end the prompt with ``"{name1}:"`` instead of
            the user message and the bot prefix.
        also_return_rows: If True, also return the list of prompt rows.

    Returns:
        The prompt string, or ``(prompt, rows)`` when ``also_return_rows``.
    """
    user_input = fix_newlines(user_input)
    # rows[0] is always the context; history rows are inserted right after it.
    rows = [f"{context.strip()}\n"]

    # A soft prompt consumes part of the budget; shape[1] is presumably its
    # length in tokens — TODO confirm against the soft-prompt loader.
    if shared.soft_prompt:
        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
    max_length = min(get_max_prompt_length(max_new_tokens), chat_prompt_size)

    # Walk history newest-to-oldest, inserting at index 1 so the final order
    # is chronological, until the encoded prompt reaches the budget.
    i = len(shared.history['internal'])-1
    while i >= 0 and len(encode(''.join(rows), max_new_tokens)[0]) < max_length:
        rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
        prev_user_input = shared.history['internal'][i][0]
        # Greeting entries carry the '<|BEGIN-VISIBLE-CHAT|>' marker instead of
        # a user message, so only the bot line is emitted for them.
        if len(prev_user_input) > 0 and prev_user_input != '<|BEGIN-VISIBLE-CHAT|>':
            rows.insert(1, f"{name1}: {prev_user_input.strip()}\n")
        i -= 1

    if not impersonate:
        if len(user_input) > 0:
            rows.append(f"{name1}: {user_input}\n")
        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
        limit = 3
    else:
        rows.append(f"{name1}:")
        limit = 2

    # The loop above may overshoot, and the rows just appended add tokens:
    # drop the oldest history row (index 1) until the prompt fits or only
    # `limit` rows remain (limit assumes context + this turn's appended rows).
    while len(rows) > limit and len(encode(''.join(rows), max_new_tokens)[0]) >= max_length:
        rows.pop(1)

    prompt = ''.join(rows)

    if also_return_rows:
        return prompt, rows
    else:
        return prompt
  49. def extract_message_from_reply(reply, name1, name2, stop_at_newline):
  50. next_character_found = False
  51. if stop_at_newline:
  52. lines = reply.split('\n')
  53. reply = lines[0].strip()
  54. if len(lines) > 1:
  55. next_character_found = True
  56. else:
  57. for string in [f"\n{name1}:", f"\n{name2}:"]:
  58. idx = reply.find(string)
  59. if idx != -1:
  60. reply = reply[:idx]
  61. next_character_found = True
  62. # If something like "\nYo" is generated just before "\nYou:"
  63. # is completed, trim it
  64. if not next_character_found:
  65. for string in [f"\n{name1}:", f"\n{name2}:"]:
  66. for j in range(len(string)-1, 0, -1):
  67. if reply[-j:] == string[:j]:
  68. reply = reply[:-j]
  69. break
  70. reply = fix_newlines(reply)
  71. return reply, next_character_found
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
    """Generate the bot's reply to ``text`` and stream the visible history.

    Once generation produces output, a new ``[user, bot]`` exchange is appended
    to both ``shared.history['internal']`` and ``shared.history['visible']`` and
    then updated in place as the reply grows.

    Extensions may replace the input (``input_hijack``) or supply their own
    ``custom_generate_chat_prompt``. With ``chat_generation_attempts > 1`` the
    model is re-invoked to continue the reply accumulated so far.

    Yields:
        A placeholder history ending in ``shared.processing_message`` (skipped
        when ``regenerate`` is True), the visible history after each streamed
        step (unless ``shared.args.no_stream``), and finally the visible
        history once more.
    """
    just_started = True
    eos_token = '\n' if stop_at_newline else None
    # Keep the caller's user name for display substitution below, since
    # name1 may be overridden for the prompt.
    name1_original = name1
    # NOTE(review): Pygmalion models presumably expect the user to be named
    # "You" in the prompt — confirm.
    if 'pygmalion' in shared.model_name.lower():
        name1 = "You"

    # Check if any extension wants to hijack this function call
    visible_text = None
    custom_generate_chat_prompt = None
    for extension, _ in extensions_module.iterator():
        # A hijack is consumed once: its state is reset before its value is used.
        if hasattr(extension, 'input_hijack') and extension.input_hijack['state'] == True:
            extension.input_hijack['state'] = False
            text, visible_text = extension.input_hijack['value']
        # The first extension providing a custom prompt builder wins.
        if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
            custom_generate_chat_prompt = extension.custom_generate_chat_prompt

    if visible_text is None:
        visible_text = text
        # Plain chat mode shows the text as HTML-ish markup, so newlines
        # become <br>.
        if shared.args.chat:
            visible_text = visible_text.replace('\n', '<br>')
    text = apply_extensions(text, "input")

    if custom_generate_chat_prompt is None:
        prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
    else:
        prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)

    # Yield *Is typing...*
    if not regenerate:
        yield shared.history['visible']+[[visible_text, shared.processing_message]]

    # Generate
    cumulative_reply = ''
    for i in range(chat_generation_attempts):
        reply = None
        # Later attempts continue from the reply built so far (space-separated).
        for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
            reply = cumulative_reply + reply

            # Extracting the reply
            reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
            visible_reply = re.sub("(<USER>|<user>|{{user}})", name1_original, reply)
            visible_reply = apply_extensions(visible_reply, "output")
            if shared.args.chat:
                visible_reply = visible_reply.replace('\n', '<br>')

            # We need this global variable to handle the Stop event,
            # otherwise gradio gets confused
            # (in a generator the return value is not yielded; iteration ends)
            if shared.stop_everything:
                return shared.history['visible']
            # Append the new exchange lazily, only once output exists.
            if just_started:
                just_started = False
                shared.history['internal'].append(['', ''])
                shared.history['visible'].append(['', ''])

            shared.history['internal'][-1] = [text, reply]
            shared.history['visible'][-1] = [visible_text, visible_reply]
            if not shared.args.no_stream:
                yield shared.history['visible']
            if next_character_found:
                break

        if reply is not None:
            cumulative_reply = reply

    yield shared.history['visible']
  128. def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
  129. eos_token = '\n' if stop_at_newline else None
  130. if 'pygmalion' in shared.model_name.lower():
  131. name1 = "You"
  132. prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
  133. # Yield *Is typing...*
  134. yield shared.processing_message
  135. cumulative_reply = ''
  136. for i in range(chat_generation_attempts):
  137. reply = None
  138. for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, eos_token=eos_token, stopping_strings=[f"\n{name1}:", f"\n{name2}:"]):
  139. reply = cumulative_reply + reply
  140. reply, next_character_found = extract_message_from_reply(reply, name1, name2, stop_at_newline)
  141. yield reply
  142. if next_character_found:
  143. break
  144. if reply is not None:
  145. cumulative_reply = reply
  146. yield reply
  147. def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
  148. for history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts):
  149. yield generate_chat_html(history, name1, name2, shared.character)
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts=1):
    """Discard the last bot reply and generate a new one for the same input.

    ``text`` is not used; the user input is taken from the last internal
    history entry. Yields the history rendered by ``generate_chat_output``:
    first a placeholder ending in ``shared.processing_message``, then each
    streamed step. If there is nothing to regenerate, yields the current
    history once.
    """
    # Nothing to regenerate when history is empty, or when a character is
    # loaded and only one exchange exists (presumably its greeting).
    if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
        yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
    else:
        # Remove the last exchange; chatbot_wrapper appends a fresh one.
        last_visible = shared.history['visible'].pop()
        last_internal = shared.history['internal'].pop()
        # Yield '*Is typing...*'
        yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
        # regenerate=True makes chatbot_wrapper skip its own placeholder yield.
        for history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, seed, name1, name2, context, stop_at_newline, chat_prompt_size, chat_generation_attempts, regenerate=True):
            # Restore the user text exactly as it was displayed before.
            # NOTE(review): list vs tuple presumably matches what each UI
            # widget expects — confirm.
            if shared.args.cai_chat:
                shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
            else:
                shared.history['visible'][-1] = (last_visible[0], history[-1][1])
            yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  164. def remove_last_message(name1, name2):
  165. if len(shared.history['visible']) > 0 and not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
  166. last = shared.history['visible'].pop()
  167. shared.history['internal'].pop()
  168. else:
  169. last = ['', '']
  170. if shared.args.cai_chat:
  171. return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
  172. else:
  173. return shared.history['visible'], last[0]
  174. def send_last_reply_to_input():
  175. if len(shared.history['internal']) > 0:
  176. return shared.history['internal'][-1][1]
  177. else:
  178. return ''
  179. def replace_last_reply(text, name1, name2):
  180. if len(shared.history['visible']) > 0:
  181. if shared.args.cai_chat:
  182. shared.history['visible'][-1][1] = text
  183. else:
  184. shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
  185. shared.history['internal'][-1][1] = apply_extensions(text, "input")
  186. return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  187. def clear_html():
  188. return generate_chat_html([], "", "", shared.character)
  189. def clear_chat_log(name1, name2):
  190. if shared.character != 'None':
  191. found = False
  192. for i in range(len(shared.history['internal'])):
  193. if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
  194. shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
  195. shared.history['internal'] = [shared.history['internal'][i]]
  196. found = True
  197. break
  198. if not found:
  199. shared.history['visible'] = []
  200. shared.history['internal'] = []
  201. else:
  202. shared.history['internal'] = []
  203. shared.history['visible'] = []
  204. return generate_chat_output(shared.history['visible'], name1, name2, shared.character)
  205. def redraw_html(name1, name2):
  206. return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  207. def tokenize_dialogue(dialogue, name1, name2):
  208. history = []
  209. dialogue = re.sub('<START>', '', dialogue)
  210. dialogue = re.sub('<start>', '', dialogue)
  211. dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
  212. dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
  213. idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
  214. if len(idx) == 0:
  215. return history
  216. messages = []
  217. for i in range(len(idx)-1):
  218. messages.append(dialogue[idx[i]:idx[i+1]].strip())
  219. messages.append(dialogue[idx[-1]:].strip())
  220. entry = ['', '']
  221. for i in messages:
  222. if i.startswith(f'{name1}:'):
  223. entry[0] = i[len(f'{name1}:'):].strip()
  224. elif i.startswith(f'{name2}:'):
  225. entry[1] = i[len(f'{name2}:'):].strip()
  226. if not (len(entry[0]) == 0 and len(entry[1]) == 0):
  227. history.append(entry)
  228. entry = ['', '']
  229. print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
  230. for row in history:
  231. for column in row:
  232. print("\n")
  233. for line in column.strip().split('\n'):
  234. print("| "+line+"\n")
  235. print("|\n")
  236. print("------------------------------")
  237. return history
  238. def save_history(timestamp=True):
  239. prefix = '' if shared.character == 'None' else f"{shared.character}_"
  240. if timestamp:
  241. fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
  242. else:
  243. fname = f"{prefix}persistent.json"
  244. if not Path('logs').exists():
  245. Path('logs').mkdir()
  246. with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
  247. f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
  248. return Path(f'logs/{fname}')
  249. def load_history(file, name1, name2):
  250. file = file.decode('utf-8')
  251. try:
  252. j = json.loads(file)
  253. if 'data' in j:
  254. shared.history['internal'] = j['data']
  255. if 'data_visible' in j:
  256. shared.history['visible'] = j['data_visible']
  257. else:
  258. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  259. # Compatibility with Pygmalion AI's official web UI
  260. elif 'chat' in j:
  261. shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
  262. if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
  263. shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
  264. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  265. shared.history['visible'][0][0] = ''
  266. else:
  267. shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
  268. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  269. except:
  270. shared.history['internal'] = tokenize_dialogue(file, name1, name2)
  271. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  272. def load_default_history(name1, name2):
  273. shared.character = 'None'
  274. if Path('logs/persistent.json').exists():
  275. load_history(open(Path('logs/persistent.json'), 'rb').read(), name1, name2)
  276. else:
  277. shared.history['internal'] = []
  278. shared.history['visible'] = []
  279. def replace_character_names(text, name1, name2):
  280. text = text.replace('{{user}}', name1).replace('{{char}}', name2)
  281. return text.replace('<USER>', name1).replace('<BOT>', name2)
  282. def build_pygmalion_style_context(data):
  283. context = ""
  284. if 'char_persona' in data and data['char_persona'] != '':
  285. context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
  286. if 'world_scenario' in data and data['world_scenario'] != '':
  287. context += f"Scenario: {data['world_scenario']}\n"
  288. context = f"{context.strip()}\n<START>\n"
  289. return context
  290. def load_character(_character, name1, name2):
  291. shared.history['internal'] = []
  292. shared.history['visible'] = []
  293. if _character != 'None':
  294. shared.character = _character
  295. for extension in ["yml", "yaml", "json"]:
  296. filepath = Path(f'characters/{_character}.{extension}')
  297. if filepath.exists():
  298. break
  299. file_contents = open(filepath, 'r', encoding='utf-8').read()
  300. data = json.loads(file_contents) if extension == "json" else yaml.safe_load(file_contents)
  301. name2 = data['name'] if 'name' in data else data['char_name']
  302. for field in ['context', 'greeting', 'example_dialogue', 'char_persona', 'char_greeting', 'world_scenario']:
  303. if field in data:
  304. data[field] = replace_character_names(data[field], name1, name2)
  305. if 'context' in data:
  306. context = f"{data['context'].strip()}\n\n"
  307. greeting_field = 'greeting'
  308. else:
  309. context = build_pygmalion_style_context(data)
  310. greeting_field = 'char_greeting'
  311. if 'example_dialogue' in data and data['example_dialogue'] != '':
  312. context += f"{data['example_dialogue'].strip()}\n"
  313. if greeting_field in data and len(data[greeting_field].strip()) > 0:
  314. shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data[greeting_field]]]
  315. shared.history['visible'] += [['', apply_extensions(data[greeting_field], "output")]]
  316. else:
  317. shared.character = 'None'
  318. context = shared.settings['context']
  319. name2 = shared.settings['name2']
  320. if Path(f'logs/{shared.character}_persistent.json').exists():
  321. load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
  322. if shared.args.cai_chat:
  323. return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  324. else:
  325. return name2, context, shared.history['visible']
  326. def upload_character(json_file, img, tavern=False):
  327. json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
  328. data = json.loads(json_file)
  329. outfile_name = data["char_name"]
  330. i = 1
  331. while Path(f'characters/{outfile_name}.json').exists():
  332. outfile_name = f'{data["char_name"]}_{i:03d}'
  333. i += 1
  334. if tavern:
  335. outfile_name = f'TavernAI-{outfile_name}'
  336. with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
  337. f.write(json_file)
  338. if img is not None:
  339. img = Image.open(io.BytesIO(img))
  340. img.save(Path(f'characters/{outfile_name}.png'))
  341. print(f'New character saved to "characters/{outfile_name}.json".')
  342. return outfile_name
  343. def upload_tavern_character(img, name1, name2):
  344. _img = Image.open(io.BytesIO(img))
  345. _img.getexif()
  346. decoded_string = base64.b64decode(_img.info['chara'])
  347. _json = json.loads(decoded_string)
  348. _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
  349. return upload_character(json.dumps(_json), img, tavern=True)
  350. def upload_your_profile_picture(img):
  351. img = Image.open(io.BytesIO(img))
  352. img.save(Path('img_me.png'))
  353. print('Profile picture saved to "img_me.png"')