text_generation.py

import gc
import re
import time

import numpy as np
import torch
import transformers
from tqdm import tqdm

import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria

def get_max_prompt_length(tokens):
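    # Number of prompt tokens that fit in the 2048-token context window once
    # room for the reply (and the soft prompt, if any) has been reserved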
    max_length = 2048 - tokens
    if shared.soft_prompt:
        max_length -= shared.soft_prompt_tensor.shape[1]
    return max_length

def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
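    # Tokenize the prompt, truncating it so that prompt + reply fit in the
    # context window, and place the result on the appropriate device/backend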
    # These models do not have explicit tokenizers for now, so
    # we return an estimate for the number of tokens
    if shared.is_RWKV or shared.is_LLaMA:
        return np.zeros((1, len(prompt)//4))

    input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
    if shared.args.cpu:
        return input_ids
    elif shared.args.flexgen:
        return input_ids.numpy()
    elif shared.args.deepspeed:
        return input_ids.to(device=local_rank)
    else:
        return input_ids.cuda()

def decode(output_ids):
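    # Convert generated token ids back into text, dropping special tokens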
    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
    reply = reply.replace(r'<|endoftext|>', '')
    return reply

def generate_softprompt_input_tensors(input_ids):
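    # Prepend the soft prompt embeddings to the prompt embeddings and build
    # dummy input_ids of matching length (the model is called with
    # inputs_embeds, so these ids are only placeholders)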
    inputs_embeds = shared.model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
    #filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
    return inputs_embeds, filler_input_ids

# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s

# Fix the LaTeX equations in galactica
def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'\(', r'$')
    s = s.replace(r'\)', r'$')
    s = s.replace(r'$$', r'$')
    s = re.sub(r'\n', r'\n\n', s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s

def formatted_outputs(reply, model_name):
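    # Post-process the reply depending on the model and return it in the
    # formats expected by the UI (chat mode only needs the raw string)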
    if not (shared.args.chat or shared.args.cai_chat):
        if model_name.lower().startswith('galactica'):
            reply = fix_galactica(reply)
            return reply, reply, generate_basic_html(reply)
        elif model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
            reply = fix_gpt4chan(reply)
            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
        else:
            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        return reply

def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
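    # Main generation entry point. This is a generator: it yields the
    # formatted reply either in a single step (--no-stream) or repeatedly,
    # about 8 tokens at a time, while text is being produced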
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()
    t0 = time.time()

    # These models are not part of Hugging Face, so we handle them
    # separately and terminate the function call earlier
    if shared.is_RWKV or shared.is_LLaMA:
        if shared.args.no_stream:
            reply = shared.model.generate(question, token_count=max_new_tokens, temperature=temperature, top_p=top_p)
            t1 = time.time()
            print(f"Output generated in {(t1-t0):.2f} seconds.")
            yield formatted_outputs(reply, shared.model_name)
        else:
            for i in tqdm(range(max_new_tokens//8+1)):
                reply = shared.model.generate(question, token_count=8, temperature=temperature, top_p=top_p)
                yield formatted_outputs(reply, shared.model_name)
                question = reply
        return

    original_question = question
    if not (shared.args.chat or shared.args.cai_chat):
        question = apply_extensions(question, "input")
    if shared.args.verbose:
        print(f"\n\n{question}\n--------------------\n")

    input_ids = encode(question, max_new_tokens)
    cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
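    # Token id that marks the end of the reply (used to stop streaming below)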
    n = shared.tokenizer.eos_token_id if eos_token is None else encode(eos_token)[0][-1]

    if stopping_string is not None:
        # The stopping_criteria code below was copied from
        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
        t = encode(stopping_string, 0, add_special_tokens=False)
        stopping_criteria_list = transformers.StoppingCriteriaList([
            _SentinelTokenStoppingCriteria(
                sentinel_token_ids=t,
                starting_idx=len(input_ids[0])
            )
        ])
    else:
        stopping_criteria_list = None

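    # The generation parameters are collected as strings and later assembled
    # into a single eval()'d call to shared.model.generate()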
    if not shared.args.flexgen:
        generate_params = [
            f"eos_token_id={n}",
            f"stopping_criteria=stopping_criteria_list",
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"top_p={top_p}",
            f"typical_p={typical_p}",
            f"repetition_penalty={repetition_penalty}",
            f"top_k={top_k}",
            f"min_length={min_length if shared.args.no_stream else 0}",
            f"no_repeat_ngram_size={no_repeat_ngram_size}",
            f"num_beams={num_beams}",
            f"penalty_alpha={penalty_alpha}",
            f"length_penalty={length_penalty}",
            f"early_stopping={early_stopping}",
        ]
    else:
        generate_params = [
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"stop={n}",
        ]
    if shared.args.deepspeed:
        generate_params.append("synced_gpus=True")
    if shared.args.no_stream:
        generate_params.append("max_new_tokens=max_new_tokens")
    else:
        generate_params.append("max_new_tokens=8")

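    # With a soft prompt, the model is called with inputs_embeds plus dummy
    # filler ids; otherwise the encoded prompt is passed as input_ids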
    if shared.soft_prompt:
        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
        generate_params.insert(0, "inputs_embeds=inputs_embeds")
        generate_params.insert(0, "filler_input_ids")
    else:
        generate_params.insert(0, "input_ids")

    # Generate the entire reply at once
    if shared.args.no_stream:
        with torch.no_grad():
            output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
        if shared.soft_prompt:
            output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

        reply = decode(output)
        if not (shared.args.chat or shared.args.cai_chat):
            reply = original_question + apply_extensions(reply[len(question):], "output")

        t1 = time.time()
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")
        yield formatted_outputs(reply, shared.model_name)

    # Generate the reply 8 tokens at a time
    else:
        yield formatted_outputs(original_question, shared.model_name)
        shared.still_streaming = True
        for i in tqdm(range(max_new_tokens//8+1)):
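            # Each iteration generates ~8 new tokens and feeds the result back
            # in as the prompt for the next call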
            with torch.no_grad():
                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
            if shared.soft_prompt:
                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

            reply = decode(output)
            if not (shared.args.chat or shared.args.cai_chat):
                reply = original_question + apply_extensions(reply[len(question):], "output")

            if not shared.args.flexgen:
                if output[-1] == n:
                    break
                input_ids = torch.reshape(output, (1, output.shape[0]))
            else:
                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
                    break
                input_ids = np.reshape(output, (1, output.shape[0]))
            # Mid-stream yield, reached only if no break occurred
            yield formatted_outputs(reply, shared.model_name)

            if shared.soft_prompt:
                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)

        # Stream finished (max_new_tokens reached or break). Do the final yield.
        shared.still_streaming = False
        yield formatted_outputs(reply, shared.model_name)