server.py

import time
import re
import torch
import gradio as gr
import transformers
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
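
# Pick the model to serve by uncommenting exactly one name below. Each name
# must match a folder under models/ (or, for the fallback path in
# load_model, a serialized dump under torch-dumps/).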
#model_name = "bloomz-7b1-p3"
#model_name = 'gpt-j-6B-float16'
#model_name = "opt-6.7b"
#model_name = 'opt-13b'
#model_name = "gpt4chan_model_float16"
model_name = 'galactica-6.7b'
#model_name = 'gpt-neox-20b'
#model_name = 'flan-t5'
#model_name = 'OPT-13B-Erebus'
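
# Load a model and its tokenizer. The largest models are loaded in 8-bit and
# spread across the available devices; gpt-j-6B is loaded in float16;
# flan-t5 needs the seq2seq class; everything else is expected to be a
# pre-serialized dump under torch-dumps/. gpt4chan (a gpt-j finetune)
# reuses the gpt-j-6B tokenizer.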
def load_model(model_name):
    print(f"Loading {model_name}")
    t0 = time.time()

    if model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
    elif model_name in ['gpt-j-6B']:
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    elif model_name in ['flan-t5']:
        model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
    else:
        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()

    if model_name in ['gpt4chan_model_float16']:
        tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
    elif model_name in ['flan-t5']:
        tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
    else:
        tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")

    print(f"Loaded the model in {time.time()-t0} seconds.")
    return model, tokenizer
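
# gpt4chan formats its output as 4chan-style posts delimited by
# "--- <post number>" lines, and the model sometimes emits empty posts (a
# lone quote link, or blank lines between two separators). The patterns
# below collapse those, e.g. "--- 123\n>>122\n---" becomes "---"; the loop
# repeats because each pass can expose new adjacent empty posts.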
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s
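
# Gradio callback: generates a completion for the current textbox contents.
# If the user picked a different model in the dropdown, the old model is
# dropped and the new one is loaded before generating.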
def fn(question, temperature, max_length, inference_settings, selected_model):
    global model, tokenizer, model_name

    if selected_model != model_name:
        model_name = selected_model
        model = None
        tokenizer = None
        torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)

    torch.cuda.empty_cache()
    input_text = question
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
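
    # Two presets: 'Default' samples freely with typical-p filtering at the
    # user-chosen temperature; 'Verbose' runs beam search with strong
    # repetition penalties and a minimum length pinned to max_length.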
    if inference_settings == 'Default':
        output = model.generate(
            input_ids,
            do_sample=True,
            max_new_tokens=max_length,
            #max_length=max_length+len(input_ids[0]),
            top_p=1,
            typical_p=0.3,
            temperature=temperature,
        ).cuda()
    elif inference_settings == 'Verbose':
        output = model.generate(
            input_ids,
            num_beams=10,
            min_length=max_length,
            max_new_tokens=max_length,
            length_penalty=1.4,
            no_repeat_ngram_size=2,
            early_stopping=True,
            temperature=0.7,
            top_k=150,
            top_p=0.92,
            repetition_penalty=4.5,
        ).cuda()
    reply = tokenizer.decode(output[0], skip_special_tokens=True)
    if model_name.startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)
    return reply
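
# Load the initially selected model and seed the textbox with a prompt that
# matches its expected input format.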
model, tokenizer = load_model(model_name)

if model_name.startswith('gpt4chan'):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
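
# The input components below are passed to fn positionally, in the order
# (question, temperature, max_length, inference_settings, selected_model).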
interface = gr.Interface(
    fn,
    inputs=[
        gr.Textbox(value=default_text, lines=15),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
        gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
        gr.Dropdown(choices=["Default", "Verbose"], value="Default"),
        gr.Dropdown(choices=["gpt4chan_model_float16", "galactica-6.7b", "opt-6.7b", "opt-13b", "gpt-neox-20b", "gpt-j-6B-float16", "flan-t5", "bloomz-7b1-p3", "OPT-13B-Erebus"], value=model_name),
    ],
    outputs=[
        gr.Textbox(placeholder="", lines=15),
    ],
    title="Text generation lab",
    description=f"Generate text using Large Language Models. Currently working with {model_name}",
)

interface.launch(share=False, server_name="0.0.0.0")
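
# Usage sketch (assumes the relevant model folders exist under models/ and,
# for the torch-dump path, under torch-dumps/):
#   python server.py
# then browse to http://<host>:7860 (Gradio's default port).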