# text_generation.py

import gc
import re
import time

import numpy as np
import torch
import transformers

import modules.shared as shared
from modules.callbacks import (Iteratorize, Stream,
                               _SentinelTokenStoppingCriteria)
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank


def get_max_prompt_length(tokens):
    max_length = 2048 - tokens
    if shared.soft_prompt:
        max_length -= shared.soft_prompt_tensor.shape[1]
    return max_length


def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    if shared.is_RWKV:
        input_ids = shared.tokenizer.encode(str(prompt))
        input_ids = np.array(input_ids).reshape(1, len(input_ids))
        return input_ids
    else:
        input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
        if shared.args.cpu:
            return input_ids
        elif shared.args.flexgen:
            return input_ids.numpy()
        elif shared.args.deepspeed:
            return input_ids.to(device=local_rank)
        else:
            return input_ids.cuda()


def decode(output_ids):
    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
    reply = reply.replace(r'<|endoftext|>', '')
    return reply


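# Prepends the soft prompt embeddings to the prompt embeddings and builds a
# tensor of dummy input ids of matching length, since model.generate() still
# expects token ids alongside inputs_embeds.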
def generate_softprompt_input_tensors(input_ids):
    inputs_embeds = shared.model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
    # filler_input_ids += shared.model.config.bos_token_id  # setting dummy input_ids to bos tokens
    return inputs_embeds, filler_input_ids


# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s


# Fix the LaTeX equations in galactica
def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'\(', r'$')
    s = s.replace(r'\)', r'$')
    s = s.replace(r'$$', r'$')
    s = re.sub(r'\n', r'\n\n', s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s


def formatted_outputs(reply, model_name):
    if not (shared.args.chat or shared.args.cai_chat):
        if model_name.lower().startswith('galactica'):
            reply = fix_galactica(reply)
            return reply, reply, generate_basic_html(reply)
        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
            reply = fix_gpt4chan(reply)
            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
        else:
            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        return reply


def clear_torch_cache():
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()


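# Generates a reply to 'question'. Implemented as a generator: each yield is
# the reply accumulated so far, formatted by formatted_outputs(), so the UI
# can stream partial output.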
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
    clear_torch_cache()
    t0 = time.time()

    # These models are not part of Hugging Face, so we handle them
    # separately and terminate the function call earlier
    if shared.is_RWKV:
        if shared.args.no_stream:
            reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
            yield formatted_outputs(reply, shared.model_name)
        else:
            yield formatted_outputs(question, shared.model_name)
            # RWKV has proper streaming, which is very nice.
            # No need to generate 8 tokens at a time.
            for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
                yield formatted_outputs(reply, shared.model_name)

        t1 = time.time()
        output = encode(reply)[0]
        input_ids = encode(question)
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)")
        return

    original_question = question
    if not (shared.args.chat or shared.args.cai_chat):
        question = apply_extensions(question, "input")
    if shared.args.verbose:
        print(f"\n\n{question}\n--------------------\n")

    input_ids = encode(question, max_new_tokens)
    original_input_ids = input_ids
    cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
    n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
    if stopping_string is not None:
        # The stopping_criteria code below was copied from
        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
        t = encode(stopping_string, 0, add_special_tokens=False)
        stopping_criteria_list = transformers.StoppingCriteriaList([
            _SentinelTokenStoppingCriteria(
                sentinel_token_ids=t,
                starting_idx=len(input_ids[0])
            )
        ])
    else:
        stopping_criteria_list = []
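
    # The parameters below are collected as "keyword=value" strings; they are
    # later joined and evaluated with eval() around shared.model.generate(...),
    # so every name referenced here must exist in this function's scope.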
    if not shared.args.flexgen:
        generate_params = [
            f"max_new_tokens=max_new_tokens",
            f"eos_token_id={n}",
            f"stopping_criteria=stopping_criteria_list",
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"top_p={top_p}",
            f"typical_p={typical_p}",
            f"repetition_penalty={repetition_penalty}",
            f"top_k={top_k}",
            f"min_length={min_length if shared.args.no_stream else 0}",
            f"no_repeat_ngram_size={no_repeat_ngram_size}",
            f"num_beams={num_beams}",
            f"penalty_alpha={penalty_alpha}",
            f"length_penalty={length_penalty}",
            f"early_stopping={early_stopping}",
        ]
    else:
        generate_params = [
            f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"stop={n}",
        ]
    if shared.args.deepspeed:
        generate_params.append("synced_gpus=True")
    if shared.soft_prompt:
        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
        generate_params.insert(0, "inputs_embeds=inputs_embeds")
        generate_params.insert(0, "inputs=filler_input_ids")
    else:
        generate_params.insert(0, "inputs=input_ids")

    # Generate the entire reply at once.
    if shared.args.no_stream:
        with torch.no_grad():
            output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
        if shared.soft_prompt:
            output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

        reply = decode(output)
        if not (shared.args.chat or shared.args.cai_chat):
            reply = original_question + apply_extensions(reply[len(question):], "output")

        yield formatted_outputs(reply, shared.model_name)

    # Stream the reply 1 token at a time.
    # This is based on the trick of using 'stopping_criteria' to create an iterator.
    elif not shared.args.flexgen:
        def generate_with_callback(callback=None, **kwargs):
            if 'stopping_criteria' not in kwargs:
                kwargs['stopping_criteria'] = []
            kwargs['stopping_criteria'].append(Stream(callback_func=callback))
            shared.model.generate(**kwargs)[0]

        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs, callback=None)

        yield formatted_outputs(original_question, shared.model_name)
        for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
            if shared.soft_prompt:
                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

            reply = decode(output)
            if not (shared.args.chat or shared.args.cai_chat):
                reply = original_question + apply_extensions(reply[len(question):], "output")

            yield formatted_outputs(reply, shared.model_name)

            if not shared.args.flexgen:
                if output[-1] == n:
                    break
            else:
                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
                    break

    # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
    else:
        for i in range(max_new_tokens//8+1):
            clear_torch_cache()
            with torch.no_grad():
                output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
            if shared.soft_prompt:
                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

            reply = decode(output)
            if not (shared.args.chat or shared.args.cai_chat):
                reply = original_question + apply_extensions(reply[len(question):], "output")

            yield formatted_outputs(reply, shared.model_name)

            if not shared.args.flexgen:
                if output[-1] == n:
                    break
                input_ids = torch.reshape(output, (1, output.shape[0]))
            else:
                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
                    break
                input_ids = np.reshape(output, (1, output.shape[0]))

            if shared.soft_prompt:
                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)

    t1 = time.time()
    print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)")
    return
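

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's public
# interface). It assumes shared.model, shared.tokenizer and shared.args have
# already been populated elsewhere in the project; the prompt and sampling
# values below are arbitrary example settings, not project defaults.
if __name__ == '__main__':
    last = None
    for item in generate_reply(
        'Write a haiku about autumn.',
        max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9,
        typical_p=1.0, repetition_penalty=1.1, top_k=40, min_length=0,
        no_repeat_ngram_size=0, num_beams=1, penalty_alpha=0,
        length_penalty=1, early_stopping=False,
    ):
        # In chat mode each item is the reply text; otherwise it is the
        # (reply, alternate reply, html) tuple built by formatted_outputs().
        last = item
    print(last)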