text_generation.py

import random
import re
import time
import traceback

import numpy as np
import torch
import transformers

import modules.shared as shared
from modules.callbacks import (Iteratorize, Stream,
                               _SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import clear_torch_cache, local_rank


def get_max_prompt_length(state):
    max_length = state['truncation_length'] - state['max_new_tokens']
    if shared.soft_prompt:
        max_length -= shared.soft_prompt_tensor.shape[1]

    return max_length
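
# Illustrative arithmetic (example values, not defaults from this file): with
# truncation_length=2048 and max_new_tokens=200, the prompt may use at most
# 2048 - 200 = 1848 tokens; a loaded soft prompt of, say, 20 virtual tokens
# shrinks that budget further to 1828.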


def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
    if any((shared.is_RWKV, shared.is_llamacpp)):
        input_ids = shared.tokenizer.encode(str(prompt))
        input_ids = np.array(input_ids).reshape(1, len(input_ids))
        return input_ids
    else:
        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)

        # This is a hack for making replies more creative.
        if not add_bos_token and input_ids[0][0] == shared.tokenizer.bos_token_id:
            input_ids = input_ids[:, 1:]

        # Llama adds this extra token when the first character is '\n', and this
        # compromises the stopping criteria, so we just remove it
        if type(shared.tokenizer) is transformers.LlamaTokenizer and input_ids[0][0] == 29871:
            input_ids = input_ids[:, 1:]

        # Handling truncation
        if truncation_length is not None:
            input_ids = input_ids[:, -truncation_length:]

    if any((shared.is_RWKV, shared.is_llamacpp, shared.args.cpu)):
        return input_ids
    elif shared.args.flexgen:
        return input_ids.numpy()
    elif shared.args.deepspeed:
        return input_ids.to(device=local_rank)
    elif torch.has_mps:
        device = torch.device('mps')
        return input_ids.to(device)
    else:
        return input_ids.cuda()
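
# Note (illustrative, not part of the original file): encode() always returns a
# batch of shape (1, seq_len), but the container depends on the backend that is
# loaded: a NumPy array for RWKV, llama.cpp and FlexGen, a CPU tensor with
# --cpu, and a CUDA/MPS/DeepSpeed tensor otherwise.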


def decode(output_ids):
    # Open Assistant relies on special tokens like <|endoftext|>
    if re.match('.*(oasst|galactica)-*', shared.model_name.lower()):
        return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
    else:
        reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
        reply = reply.replace(r'<|endoftext|>', '')
        return reply


def generate_softprompt_input_tensors(input_ids):
    inputs_embeds = shared.model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
    # filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
    return inputs_embeds, filler_input_ids


# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)

    return s
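
# Illustrative input/output (example post numbers): a generated block such as
#     "--- 12345\n>>12344\n---"
# (a post that only quotes another post and says nothing) collapses to "---".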


# Fix the LaTeX equations in galactica
def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'\(', r'$')
    s = s.replace(r'\)', r'$')
    s = s.replace(r'$$', r'$')
    s = re.sub(r'\n', r'\n\n', s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s
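
# Illustrative transformation (example equation): the display-math markup
#     "\[ E = mc^2 \]"
# becomes "$ E = mc^2 $"; single newlines are then expanded into blank-line
# paragraph breaks and longer runs of newlines are collapsed back to one blank line.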


def formatted_outputs(reply, model_name):
    if not shared.is_chat():
        if 'galactica' in model_name.lower():
            reply = fix_galactica(reply)
            return reply, reply, generate_basic_html(reply)
        elif any((k in shared.model_name.lower() for k in ['gpt4chan', 'gpt-4chan'])):
            reply = fix_gpt4chan(reply)
            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
        else:
            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        return reply
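
# Note (illustrative description): outside chat mode this returns a 3-tuple
# (the raw reply, a second text field that is only meaningful for GALACTICA
# models, hence the placeholder message, and an HTML rendering); in chat mode
# only the plain reply string is returned.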


def set_manual_seed(seed):
    seed = int(seed)
    if seed == -1:
        seed = random.randint(1, 2**31)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    return seed
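
# Illustrative usage (not in the original file): set_manual_seed(-1) draws and
# returns a fresh random seed, while set_manual_seed(42) seeds torch (and all
# CUDA devices, if present) so that sampling is reproducible.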


def stop_everything_event():
    shared.stop_everything = True


def generate_reply(question, state, eos_token=None, stopping_strings=[]):
    clear_torch_cache()
    seed = set_manual_seed(state['seed'])
    shared.stop_everything = False
    generate_params = {}
    t0 = time.time()

    original_question = question
    if not shared.is_chat():
        question = apply_extensions(question, 'input')

    # These models are not part of Hugging Face, so we handle them
    # separately and terminate the function call earlier
    if any((shared.is_RWKV, shared.is_llamacpp)):
        if shared.args.verbose:
            print(f'\n\n{question}\n--------------------\n')

        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
            generate_params[k] = state[k]

        generate_params['token_count'] = state['max_new_tokens']
        try:
            if shared.args.no_stream:
                reply = shared.model.generate(context=question, **generate_params)
                output = original_question + reply
                if not shared.is_chat():
                    reply = original_question + apply_extensions(reply, 'output')

                yield formatted_outputs(reply, shared.model_name)
            else:
                if not shared.is_chat():
                    yield formatted_outputs(question, shared.model_name)

                # RWKV has proper streaming, which is very nice.
                # No need to generate 8 tokens at a time.
                for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                    output = original_question + reply
                    if not shared.is_chat():
                        reply = original_question + apply_extensions(reply, 'output')

                    yield formatted_outputs(reply, shared.model_name)

        except Exception:
            traceback.print_exc()
        finally:
            t1 = time.time()
            original_tokens = len(encode(original_question)[0])
            new_tokens = len(encode(output)[0]) - original_tokens
            print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
            return

    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
    original_input_ids = input_ids
    output = input_ids[0]

    if shared.args.verbose:
        print(f'\n\n{decode(input_ids[0])}\n--------------------\n')

    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
    if eos_token is not None:
        eos_token_ids.append(int(encode(eos_token)[0][-1]))

    # Handling the stopping strings
    stopping_criteria_list = transformers.StoppingCriteriaList()
    for st in [stopping_strings, state['custom_stopping_strings']]:
        if type(st) is list and len(st) > 0:
            sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
            stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
            break

    if not shared.args.flexgen:
        for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
            generate_params[k] = state[k]

        generate_params['eos_token_id'] = eos_token_ids
        generate_params['stopping_criteria'] = stopping_criteria_list
        if state['ban_eos_token']:
            generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
    else:
        for k in ['max_new_tokens', 'do_sample', 'temperature']:
            generate_params[k] = state[k]

        # FlexGen takes a single stop token id
        generate_params['stop'] = eos_token_ids[-1]
        # When streaming, FlexGen generates the reply in chunks of 8 tokens (see the loop below)
        if not shared.args.no_stream:
            generate_params['max_new_tokens'] = 8

    if shared.args.no_cache:
        generate_params.update({'use_cache': False})

    if shared.args.deepspeed:
        generate_params.update({'synced_gpus': True})

    if shared.soft_prompt:
        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
        generate_params.update({'inputs_embeds': inputs_embeds})
        generate_params.update({'inputs': filler_input_ids})
    else:
        generate_params.update({'inputs': input_ids})

    try:
        # Generate the entire reply at once.
        if shared.args.no_stream:
            with torch.no_grad():
                output = shared.model.generate(**generate_params)[0]
                if cuda:
                    output = output.cuda()

            if shared.soft_prompt:
                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

            new_tokens = len(output) - len(input_ids[0])
            reply = decode(output[-new_tokens:])
            if not shared.is_chat():
                reply = original_question + apply_extensions(reply, 'output')

            yield formatted_outputs(reply, shared.model_name)

        # Stream the reply 1 token at a time.
        # This is based on the trick of using 'stopping_criteria' to create an iterator.
        elif not shared.args.flexgen:

            def generate_with_callback(callback=None, **kwargs):
                kwargs['stopping_criteria'].append(Stream(callback_func=callback))
                clear_torch_cache()
                with torch.no_grad():
                    shared.model.generate(**kwargs)

            def generate_with_streaming(**kwargs):
                return Iteratorize(generate_with_callback, kwargs, callback=None)

            if not shared.is_chat():
                yield formatted_outputs(original_question, shared.model_name)

            with generate_with_streaming(**generate_params) as generator:
                for output in generator:
                    if shared.soft_prompt:
                        output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

                    new_tokens = len(output) - len(input_ids[0])
                    reply = decode(output[-new_tokens:])
                    if not shared.is_chat():
                        reply = original_question + apply_extensions(reply, 'output')

                    if output[-1] in eos_token_ids:
                        break

                    yield formatted_outputs(reply, shared.model_name)

        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
        else:
            for i in range(state['max_new_tokens'] // 8 + 1):
                clear_torch_cache()
                with torch.no_grad():
                    output = shared.model.generate(**generate_params)[0]

                if shared.soft_prompt:
                    output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

                new_tokens = len(output) - len(original_input_ids[0])
                reply = decode(output[-new_tokens:])
                if not shared.is_chat():
                    reply = original_question + apply_extensions(reply, 'output')

                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                    break

                yield formatted_outputs(reply, shared.model_name)

                input_ids = np.reshape(output, (1, output.shape[0]))
                if shared.soft_prompt:
                    inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
                    generate_params.update({'inputs_embeds': inputs_embeds})
                    generate_params.update({'inputs': filler_input_ids})
                else:
                    generate_params.update({'inputs': input_ids})

            yield formatted_outputs(reply, shared.model_name)

    except Exception:
        traceback.print_exc()
    finally:
        t1 = time.time()
        original_tokens = len(original_input_ids[0])
        new_tokens = len(output) - original_tokens
        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
        return
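
# Minimal usage sketch (illustrative only; the 'state' keys below mirror the UI
# parameters this module reads, and the prompt text and values are made up):
#
#     state = {'seed': -1, 'max_new_tokens': 200, 'truncation_length': 2048,
#              'add_bos_token': True, 'ban_eos_token': False,
#              'custom_stopping_strings': [], 'do_sample': True,
#              'temperature': 0.7, 'top_p': 0.9, 'top_k': 40, 'typical_p': 1.0,
#              'repetition_penalty': 1.1, 'encoder_repetition_penalty': 1.0,
#              'min_length': 0, 'no_repeat_ngram_size': 0, 'num_beams': 1,
#              'penalty_alpha': 0.0, 'length_penalty': 1.0,
#              'early_stopping': False}
#     for reply in generate_reply('Tell me a story.', state):
#         ...  # each yield is the reply so far (a 3-tuple outside chat mode)
#
# generate_reply() is a generator: callers iterate over it and render each
# partial reply as it streams in.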