text_generation.py

import gc
import re
import time

import numpy as np
import torch
import transformers
from tqdm import tqdm
from rwkv.utils import PIPELINE, PIPELINE_ARGS

import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_4chan_html, generate_basic_html
from modules.models import local_rank
from modules.stopping_criteria import _SentinelTokenStoppingCriteria

def get_max_prompt_length(tokens):
    max_length = 2048 - tokens
    if shared.soft_prompt:
        max_length -= shared.soft_prompt_tensor.shape[1]
    return max_length
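
# Illustrative example (added for clarity, not in the original file): with
# max_new_tokens=200 and no soft prompt loaded, get_max_prompt_length(200)
# returns 2048 - 200 = 1848, i.e. how many prompt tokens still fit in the
# 2048-token budget hard-coded above. A loaded soft prompt shrinks it further
# by the soft prompt's length.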

def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
    if shared.is_RWKV:
        return prompt

    input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
    if shared.args.cpu:
        return input_ids
    elif shared.args.flexgen:
        return input_ids.numpy()
    elif shared.args.deepspeed:
        return input_ids.to(device=local_rank)
    else:
        return input_ids.cuda()

def decode(output_ids):
    reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
    reply = reply.replace(r'<|endoftext|>', '')
    return reply
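
# Illustrative round trip (added sketch, not part of the original module; assumes a
# Hugging Face tokenizer is already loaded into shared.tokenizer and CUDA is available):
#   input_ids = encode("Hello there", tokens_to_generate=200)  # shape (1, seq_len), on GPU
#   text = decode(input_ids[0])                                # -> "Hello there"
# With --flexgen the ids come back as a NumPy array; with --cpu or --deepspeed they stay
# on the corresponding device instead of being moved with .cuda().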

def generate_softprompt_input_tensors(input_ids):
    inputs_embeds = shared.model.transformer.wte(input_ids)
    inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
    filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
    # filler_input_ids += shared.model.config.bos_token_id  # setting dummy input_ids to bos tokens
    return inputs_embeds, filler_input_ids
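
# Shape sketch (added for clarity, not in the original file): if the soft prompt tensor
# is (1, S, hidden_size) and the prompt embeds to (1, T, hidden_size), inputs_embeds
# becomes (1, S + T, hidden_size) and filler_input_ids is a (1, S + T) tensor of zeros
# that exists only to satisfy generate()'s positional input_ids argument further below.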

# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s
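
# Illustrative effect (added example): a post header followed by nothing but a quote
# link, e.g. "--- 123\n>>456\n---", collapses to "---", so up to ten passes of the
# three substitutions strip runs of empty replies from gpt4chan output.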

# Fix the LaTeX equations in galactica
def fix_galactica(s):
    s = s.replace(r'\[', r'$')
    s = s.replace(r'\]', r'$')
    s = s.replace(r'\(', r'$')
    s = s.replace(r'\)', r'$')
    s = s.replace(r'$$', r'$')
    s = re.sub(r'\n', r'\n\n', s)
    s = re.sub(r"\n{3,}", "\n\n", s)
    return s
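
# Illustrative effect (added example): r"The result is \[ E = mc^2 \]" becomes
# "The result is $ E = mc^2 $". Display and inline LaTeX delimiters are collapsed to
# single "$", every newline is doubled, and runs of 3+ newlines are squeezed back to 2
# so paragraphs render with blank lines between them.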

def formatted_outputs(reply, model_name):
    if not (shared.args.chat or shared.args.cai_chat):
        if shared.model_name.lower().startswith('galactica'):
            reply = fix_galactica(reply)
            return reply, reply, generate_basic_html(reply)
        elif shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')):
            reply = fix_gpt4chan(reply)
            return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
        else:
            return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
    else:
        return reply
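
# Return shape (added note, not in the original file): outside of chat mode this returns
# a 3-tuple -- the cleaned reply, either the reply again (GALACTICA) or a placeholder
# string, and an HTML rendering -- which presumably feeds three separate UI outputs;
# in chat/cai-chat mode only the raw reply is returned.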

def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
    gc.collect()
    if not shared.args.cpu:
        torch.cuda.empty_cache()

    if shared.is_RWKV:
        args = PIPELINE_ARGS(temperature=temperature, top_p=top_p,
                             alpha_frequency=0.25,  # Frequency Penalty (as in GPT-3)
                             alpha_presence=0.25,  # Presence Penalty (as in GPT-3)
                             token_ban=[0],  # ban the generation of some tokens
                             token_stop=[])  # stop generation whenever you see any token here
        if shared.args.no_stream:
            reply = question + shared.model.generate(question, token_count=max_new_tokens, args=args, callback=None)
            yield formatted_outputs(reply, None)
            return formatted_outputs(reply, None)
        else:
            for i in range(max_new_tokens // 8):
                reply = question + shared.model.generate(question, token_count=8, args=args, callback=None)
                yield formatted_outputs(reply, None)
                question = reply
            return formatted_outputs(reply, None)

    original_question = question
    if not (shared.args.chat or shared.args.cai_chat):
        question = apply_extensions(question, "input")
    if shared.args.verbose:
        print(f"\n\n{question}\n--------------------\n")

    input_ids = encode(question, max_new_tokens)
    cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
    n = shared.tokenizer.eos_token_id if eos_token is None else encode(eos_token)[0][-1]

    if stopping_string is not None:
        # The stopping_criteria code below was copied from
        # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
        t = encode(stopping_string, 0, add_special_tokens=False)
        stopping_criteria_list = transformers.StoppingCriteriaList([
            _SentinelTokenStoppingCriteria(
                sentinel_token_ids=t,
                starting_idx=len(input_ids[0])
            )
        ])
    else:
        stopping_criteria_list = None

    if not shared.args.flexgen:
        generate_params = [
            f"eos_token_id={n}",
            f"stopping_criteria=stopping_criteria_list",
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"top_p={top_p}",
            f"typical_p={typical_p}",
            f"repetition_penalty={repetition_penalty}",
            f"top_k={top_k}",
            f"min_length={min_length if shared.args.no_stream else 0}",
            f"no_repeat_ngram_size={no_repeat_ngram_size}",
            f"num_beams={num_beams}",
            f"penalty_alpha={penalty_alpha}",
            f"length_penalty={length_penalty}",
            f"early_stopping={early_stopping}",
        ]
    else:
        generate_params = [
            f"do_sample={do_sample}",
            f"temperature={temperature}",
            f"stop={n}",
        ]
    if shared.args.deepspeed:
        generate_params.append("synced_gpus=True")
    if shared.args.no_stream:
        generate_params.append("max_new_tokens=max_new_tokens")
    else:
        generate_params.append("max_new_tokens=8")

    if shared.soft_prompt:
        inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
        generate_params.insert(0, "inputs_embeds=inputs_embeds")
        generate_params.insert(0, "filler_input_ids")
    else:
        generate_params.insert(0, "input_ids")
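
    # Illustrative expansion (added comment; the sampling values shown are assumed, not
    # defaults from this file): with no soft prompt, the eval() calls below end up roughly
    # equivalent to
    #   shared.model.generate(input_ids, eos_token_id=<n>, stopping_criteria=stopping_criteria_list,
    #                         do_sample=True, temperature=0.7, top_p=0.9, ...,
    #                         max_new_tokens=max_new_tokens).cuda()
    # i.e. the list of strings is just a way to assemble that keyword argument list dynamically.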

    # Generate the entire reply at once
    if shared.args.no_stream:
        t0 = time.time()
        with torch.no_grad():
            output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
        if shared.soft_prompt:
            output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

        reply = decode(output)
        if not (shared.args.chat or shared.args.cai_chat):
            reply = original_question + apply_extensions(reply[len(question):], "output")
        yield formatted_outputs(reply, shared.model_name)

        t1 = time.time()
        print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)")

    # Generate the reply 8 tokens at a time
    else:
        yield formatted_outputs(original_question, shared.model_name)
        for i in tqdm(range(max_new_tokens // 8 + 1)):
            with torch.no_grad():
                output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
            if shared.soft_prompt:
                output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))

            reply = decode(output)
            if not (shared.args.chat or shared.args.cai_chat):
                reply = original_question + apply_extensions(reply[len(question):], "output")
            yield formatted_outputs(reply, shared.model_name)

            if not shared.args.flexgen:
                if output[-1] == n:
                    break
                input_ids = torch.reshape(output, (1, output.shape[0]))
            else:
                if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n):
                    break
                input_ids = np.reshape(output, (1, output.shape[0]))

            if shared.soft_prompt:
                inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
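
# Illustrative (hypothetical) caller, not part of this file: generate_reply is a
# generator, so a front end would iterate over it to stream partial results, e.g.
#   for reply in generate_reply("Write a haiku.", 200, True, 0.7, 0.9, 1.0, 1.1,
#                               40, 0, 0, 1, 0, 1.0, False):
#       update_ui(reply)  # update_ui is a placeholder name, not a real function here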