# chat.py
  1. import base64
  2. import copy
  3. import io
  4. import json
  5. import re
  6. from datetime import datetime
  7. from io import BytesIO
  8. from pathlib import Path
  9. from PIL import Image
  10. import modules.shared as shared
  11. from modules.extensions import apply_extensions
  12. from modules.html_generator import generate_chat_html
  13. from modules.text_generation import encode, generate_reply, get_max_prompt_length
  14. if shared.args.picture and (shared.args.cai_chat or shared.args.chat):
  15. import modules.bot_picture as bot_picture
  16. # This gets the new line characters right.
  17. def clean_chat_message(text):
  18. text = text.replace('\n', '\n\n')
  19. text = re.sub(r"\n{3,}", "\n\n", text)
  20. text = text.strip()
  21. return text
def generate_chat_prompt(user_input, tokens, name1, name2, context, chat_prompt_size, impersonate=False):
    """Build the full prompt for the model from the context plus as much chat
    history as fits in the token budget.

    Rows are inserted at index 1 (right after the context) while walking the
    history backwards, so the newest exchanges survive truncation. When
    impersonate=True the prompt ends with "{name1}:" so the model speaks as
    the user instead of the bot.
    """
    user_input = clean_chat_message(user_input)
    rows = [f"{context.strip()}\n"]

    # Soft prompt tokens eat into the budget; shrink the chat window to match.
    if shared.soft_prompt:
        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
    max_length = min(get_max_prompt_length(tokens), chat_prompt_size)

    # Walk history newest-to-oldest, inserting each turn just after the context
    # so older rows are the first ones dropped by the trimming loop below.
    i = len(shared.history['internal'])-1
    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
        rows.insert(1, f"{name2}: {shared.history['internal'][i][1].strip()}\n")
        # '<|BEGIN-VISIBLE-CHAT|>' marks the character greeting; it has no
        # real user message, so only the bot half is included.
        if not (shared.history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
            rows.insert(1, f"{name1}: {shared.history['internal'][i][0].strip()}\n")
        i -= 1

    if not impersonate:
        rows.append(f"{name1}: {user_input}\n")
        # Extensions may rewrite the bot prefix (e.g. to seed the reply).
        rows.append(apply_extensions(f"{name2}:", "bot_prefix"))
        limit = 3  # keep context + latest user line + bot prefix
    else:
        rows.append(f"{name1}:")
        limit = 2  # keep context + the bare user prefix

    # Drop the oldest history rows (index 1) until the prompt fits.
    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
        rows.pop(1)

    prompt = ''.join(rows)
    return prompt
  45. def extract_message_from_reply(question, reply, current, other, check, extensions=False):
  46. next_character_found = False
  47. substring_found = False
  48. previous_idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", question)]
  49. idx = [m.start() for m in re.finditer(f"(^|\n){re.escape(current)}:", reply)]
  50. idx = idx[len(previous_idx)-1]
  51. if extensions:
  52. reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):]
  53. else:
  54. reply = reply[idx + 1 + len(f"{current}:"):]
  55. if check:
  56. reply = reply.split('\n')[0].strip()
  57. else:
  58. idx = reply.find(f"\n{other}:")
  59. if idx != -1:
  60. reply = reply[:idx]
  61. next_character_found = True
  62. reply = clean_chat_message(reply)
  63. # Detect if something like "\nYo" is generated just before
  64. # "\nYou:" is completed
  65. tmp = f"\n{other}:"
  66. for j in range(1, len(tmp)):
  67. if reply[-j:] == tmp[:j]:
  68. substring_found = True
  69. return reply, next_character_found, substring_found
  70. def generate_chat_picture(picture, name1, name2):
  71. text = f'*{name1} sends {name2} a picture that contains the following: "{bot_picture.caption_image(picture)}"*'
  72. buffer = BytesIO()
  73. picture.save(buffer, format="JPEG")
  74. img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
  75. visible_text = f'<img src="data:image/jpeg;base64,{img_str}">'
  76. return text, visible_text
def stop_everything_event():
    # Set the global stop flag; chatbot_wrapper polls it between streamed
    # chunks and returns early once it is True (wired to gradio's Stop button).
    shared.stop_everything = True
def chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
    """Generate one bot reply as a stream, yielding the updated visible
    history after each chunk. Mutates shared.history in place."""
    shared.stop_everything = False
    just_started = True
    # check=True means "stop at end of line": use newline as the EOS token.
    eos_token = '\n' if check else None
    if 'pygmalion' in shared.model_name.lower():
        # Pygmalion models were trained with the user named "You".
        name1 = "You"

    # Create the prompt
    if shared.args.picture and picture is not None:
        text, visible_text = generate_chat_picture(picture, name1, name2)
    else:
        visible_text = text
        if shared.args.chat:
            visible_text = visible_text.replace('\n', '<br>')
    text = apply_extensions(text, "input")
    prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size)

    # Generate
    for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
        # Extracting the reply
        reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name2, name1, check, extensions=True)
        visible_reply = apply_extensions(reply, "output")
        if shared.args.chat:
            visible_reply = visible_reply.replace('\n', '<br>')

        # We need this global variable to handle the Stop event,
        # otherwise gradio gets confused
        if shared.stop_everything:
            return shared.history['visible']
        if just_started:
            # First chunk of a new reply: open a fresh history row.
            just_started = False
            shared.history['internal'].append(['', ''])
            shared.history['visible'].append(['', ''])

        # Overwrite the last row with the partial reply so the UI streams.
        shared.history['internal'][-1] = [text, reply]
        shared.history['visible'][-1] = [visible_text, visible_reply]
        # Skip the yield while the tail might be a half-written "\n{name1}:".
        if not substring_found:
            yield shared.history['visible']
        if next_character_found:
            break
    yield shared.history['visible']
def impersonate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
    """Stream a reply written *as the user* (name1) instead of the bot.

    Unlike chatbot_wrapper, this does not touch shared.history: it only
    yields the text so the UI can place it in the input box.
    """
    if 'pygmalion' in shared.model_name.lower():
        name1 = "You"

    prompt = generate_chat_prompt(text, tokens, name1, name2, context, chat_prompt_size, impersonate=True)
    eos_token = '\n' if check else None
    for reply in generate_reply(prompt, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
        # Roles are swapped relative to chatbot_wrapper: extract name1's text.
        reply, next_character_found, substring_found = extract_message_from_reply(prompt, reply, name1, name2, check, extensions=False)
        # Hold back the yield while the tail might be a half-written "\n{name2}:".
        if not substring_found:
            yield reply
        if next_character_found:
            break
    yield reply
  128. def cai_chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
  129. for _history in chatbot_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
  130. yield generate_chat_html(_history, name1, name2, shared.character)
def regenerate_wrapper(text, tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture=None):
    """Regenerate the bot's last reply by popping the final exchange and
    re-running chatbot_wrapper on the same user message."""
    # With a character loaded and only the greeting in history there is
    # nothing to regenerate: just re-yield the current view.
    if shared.character != 'None' and len(shared.history['visible']) == 1:
        if shared.args.cai_chat:
            yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
        else:
            yield shared.history['visible']
    else:
        last_visible = shared.history['visible'].pop()
        last_internal = shared.history['internal'].pop()

        # Re-run generation on the popped user message; chatbot_wrapper
        # appends a fresh row which we overwrite below to keep the original
        # visible user text (e.g. the <img> tag for pictures).
        for _history in chatbot_wrapper(last_internal[0], tokens, do_sample, max_new_tokens, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, picture):
            if shared.args.cai_chat:
                shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
                yield generate_chat_html(shared.history['visible'], name1, name2, shared.character)
            else:
                shared.history['visible'][-1] = (last_visible[0], _history[-1][1])
                yield shared.history['visible']
  147. def remove_last_message(name1, name2):
  148. if not shared.history['internal'][-1][0] == '<|BEGIN-VISIBLE-CHAT|>':
  149. last = shared.history['visible'].pop()
  150. shared.history['internal'].pop()
  151. else:
  152. last = ['', '']
  153. if shared.args.cai_chat:
  154. return generate_chat_html(shared.history['visible'], name1, name2, shared.character), last[0]
  155. else:
  156. return shared.history['visible'], last[0]
  157. def send_last_reply_to_input():
  158. if len(shared.history['internal']) > 0:
  159. return shared.history['internal'][-1][1]
  160. else:
  161. return ''
def replace_last_reply(text, name1, name2):
    """Overwrite the bot's last reply with `text` and return the updated view."""
    if len(shared.history['visible']) > 0:
        if shared.args.cai_chat:
            shared.history['visible'][-1][1] = text
        else:
            # In non-cai mode visible rows are tuples, so rebuild the pair.
            shared.history['visible'][-1] = (shared.history['visible'][-1][0], text)
        # NOTE(review): this runs the "input" extension hook on a bot reply;
        # the symmetric operations elsewhere use "output" — confirm intended.
        shared.history['internal'][-1][1] = apply_extensions(text, "input")

    if shared.args.cai_chat:
        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
    else:
        return shared.history['visible']
def clear_html():
    # Render an empty cai-chat window (no messages, blank names).
    return generate_chat_html([], "", "", shared.character)
def clear_chat_log(name1, name2):
    """Reset the chat log; with a character loaded, keep everything up to and
    including the greeting row so the character's intro survives the clear."""
    if shared.character != 'None':
        for i in range(len(shared.history['internal'])):
            if '<|BEGIN-VISIBLE-CHAT|>' in shared.history['internal'][i][0]:
                # Keep the greeting visible (re-running the output extension),
                # truncate internal history just after it.
                shared.history['visible'] = [['', apply_extensions(shared.history['internal'][i][1], "output")]]
                shared.history['internal'] = shared.history['internal'][:i+1]
                break
    else:
        shared.history['internal'] = []
        shared.history['visible'] = []

    if shared.args.cai_chat:
        return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
    else:
        return shared.history['visible']
def redraw_html(name1, name2):
    # Re-render the current visible history as cai-chat HTML (e.g. after a
    # name or character change).
    return generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  191. def tokenize_dialogue(dialogue, name1, name2):
  192. _history = []
  193. dialogue = re.sub('<START>', '', dialogue)
  194. dialogue = re.sub('<start>', '', dialogue)
  195. dialogue = re.sub('(\n|^)[Aa]non:', '\\1You:', dialogue)
  196. dialogue = re.sub('(\n|^)\[CHARACTER\]:', f'\\g<1>{name2}:', dialogue)
  197. idx = [m.start() for m in re.finditer(f"(^|\n)({re.escape(name1)}|{re.escape(name2)}):", dialogue)]
  198. if len(idx) == 0:
  199. return _history
  200. messages = []
  201. for i in range(len(idx)-1):
  202. messages.append(dialogue[idx[i]:idx[i+1]].strip())
  203. messages.append(dialogue[idx[-1]:].strip())
  204. entry = ['', '']
  205. for i in messages:
  206. if i.startswith(f'{name1}:'):
  207. entry[0] = i[len(f'{name1}:'):].strip()
  208. elif i.startswith(f'{name2}:'):
  209. entry[1] = i[len(f'{name2}:'):].strip()
  210. if not (len(entry[0]) == 0 and len(entry[1]) == 0):
  211. _history.append(entry)
  212. entry = ['', '']
  213. print(f"\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
  214. for row in _history:
  215. for column in row:
  216. print("\n")
  217. for line in column.strip().split('\n'):
  218. print("| "+line+"\n")
  219. print("|\n")
  220. print("------------------------------")
  221. return _history
  222. def save_history(timestamp=True):
  223. prefix = '' if shared.character == 'None' else f"{shared.character}_"
  224. if timestamp:
  225. fname = f"{prefix}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
  226. else:
  227. fname = f"{prefix}persistent.json"
  228. if not Path('logs').exists():
  229. Path('logs').mkdir()
  230. with open(Path(f'logs/{fname}'), 'w') as f:
  231. f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
  232. return Path(f'logs/{fname}')
  233. def load_history(file, name1, name2):
  234. file = file.decode('utf-8')
  235. try:
  236. j = json.loads(file)
  237. if 'data' in j:
  238. shared.history['internal'] = j['data']
  239. if 'data_visible' in j:
  240. shared.history['visible'] = j['data_visible']
  241. else:
  242. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  243. # Compatibility with Pygmalion AI's official web UI
  244. elif 'chat' in j:
  245. shared.history['internal'] = [':'.join(x.split(':')[1:]).strip() for x in j['chat']]
  246. if len(j['chat']) > 0 and j['chat'][0].startswith(f'{name2}:'):
  247. shared.history['internal'] = [['<|BEGIN-VISIBLE-CHAT|>', shared.history['internal'][0]]] + [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(1, len(shared.history['internal'])-1, 2)]
  248. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  249. shared.history['visible'][0][0] = ''
  250. else:
  251. shared.history['internal'] = [[shared.history['internal'][i], shared.history['internal'][i+1]] for i in range(0, len(shared.history['internal'])-1, 2)]
  252. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  253. except:
  254. shared.history['internal'] = tokenize_dialogue(file, name1, name2)
  255. shared.history['visible'] = copy.deepcopy(shared.history['internal'])
  256. def load_default_history(name1, name2):
  257. if Path(f'logs/persistent.json').exists():
  258. load_history(open(Path(f'logs/persistent.json'), 'rb').read(), name1, name2)
  259. else:
  260. shared.history['internal'] = []
  261. shared.history['visible'] = []
  262. def load_character(_character, name1, name2):
  263. context = ""
  264. shared.history['internal'] = []
  265. shared.history['visible'] = []
  266. if _character != 'None':
  267. shared.character = _character
  268. data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read())
  269. name2 = data['char_name']
  270. if 'char_persona' in data and data['char_persona'] != '':
  271. context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
  272. if 'world_scenario' in data and data['world_scenario'] != '':
  273. context += f"Scenario: {data['world_scenario']}\n"
  274. context = f"{context.strip()}\n<START>\n"
  275. if 'example_dialogue' in data and data['example_dialogue'] != '':
  276. shared.history['internal'] = tokenize_dialogue(data['example_dialogue'], name1, name2)
  277. if 'char_greeting' in data and len(data['char_greeting'].strip()) > 0:
  278. shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', data['char_greeting']]]
  279. shared.history['visible'] += [['', apply_extensions(data['char_greeting'], "output")]]
  280. else:
  281. shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', "Hello there!"]]
  282. shared.history['visible'] += [['', "Hello there!"]]
  283. else:
  284. shared.character = None
  285. context = shared.settings['context_pygmalion']
  286. name2 = shared.settings['name2_pygmalion']
  287. if Path(f'logs/{shared.character}_persistent.json').exists():
  288. load_history(open(Path(f'logs/{shared.character}_persistent.json'), 'rb').read(), name1, name2)
  289. if shared.args.cai_chat:
  290. return name2, context, generate_chat_html(shared.history['visible'], name1, name2, shared.character)
  291. else:
  292. return name2, context, shared.history['visible']
  293. def upload_character(json_file, img, tavern=False):
  294. json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
  295. data = json.loads(json_file)
  296. outfile_name = data["char_name"]
  297. i = 1
  298. while Path(f'characters/{outfile_name}.json').exists():
  299. outfile_name = f'{data["char_name"]}_{i:03d}'
  300. i += 1
  301. if tavern:
  302. outfile_name = f'TavernAI-{outfile_name}'
  303. with open(Path(f'characters/{outfile_name}.json'), 'w') as f:
  304. f.write(json_file)
  305. if img is not None:
  306. img = Image.open(io.BytesIO(img))
  307. img.save(Path(f'characters/{outfile_name}.png'))
  308. print(f'New character saved to "characters/{outfile_name}.json".')
  309. return outfile_name
  310. def upload_tavern_character(img, name1, name2):
  311. _img = Image.open(io.BytesIO(img))
  312. _img.getexif()
  313. decoded_string = base64.b64decode(_img.info['chara'])
  314. _json = json.loads(decoded_string)
  315. _json = {"char_name": _json['name'], "char_persona": _json['description'], "char_greeting": _json["first_mes"], "example_dialogue": _json['mes_example'], "world_scenario": _json['scenario']}
  316. _json['example_dialogue'] = _json['example_dialogue'].replace('{{user}}', name1).replace('{{char}}', _json['char_name'])
  317. return upload_character(json.dumps(_json), img, tavern=True)
  318. def upload_your_profile_picture(img):
  319. img = Image.open(io.BytesIO(img))
  320. img.save(Path(f'img_me.png'))
  321. print(f'Profile picture saved to "img_me.png"')