server.py

import os
import re
import time
import glob
import torch
import gradio as gr
import transformers
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel

#model_name = "bloomz-7b1-p3"
#model_name = 'gpt-j-6B-float16'
#model_name = "opt-6.7b"
#model_name = 'opt-13b'
#model_name = "gpt4chan_model_float16"
model_name = 'galactica-6.7b'
#model_name = 'gpt-neox-20b'
#model_name = 'flan-t5'
#model_name = 'OPT-13B-Erebus'

loaded_preset = None
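
# Loads a model (and its tokenizer) either from a pickled dump under
# torch-dumps/ or from a Hugging Face checkpoint under models/.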
def load_model(model_name):
    print(f"Loading {model_name}...")
    t0 = time.time()

    if os.path.exists(f"torch-dumps/{model_name}.pt"):
        print("Loading in .pt format...")
        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
    elif model_name.lower().startswith(('gpt-neo', 'opt-')):
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
    elif model_name in ['gpt-j-6B']:
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    elif model_name in ['flan-t5', 't5-large']:
        model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
    else:
        # Fallback so models without a special case above (e.g. the default
        # galactica-6.7b when no .pt dump exists) still load instead of
        # leaving `model` undefined; mirrors the gpt-j-6B branch.
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    if model_name in ['gpt4chan_model_float16']:
        tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
    elif model_name in ['flan-t5']:
        tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
    else:
        tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer

# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s
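
# The inference presets under presets/ hold keyword arguments that get spliced
# verbatim into the model.generate() call via eval() below. Their exact
# contents are an assumption here, but a presets/Default.txt along these lines
# would work:
#
#     do_sample=True, temperature=temperature, top_p=0.73, max_length=max_length
#
# Note that `temperature` and `max_length` refer to local variables of
# generate_reply(), which is why the string is eval'd in that scope rather
# than parsed.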
def generate_reply(question, temperature, max_length, inference_settings, selected_model):
    global model, tokenizer, model_name, loaded_preset, preset

    if selected_model != model_name:
        model_name = selected_model
        model = None
        tokenizer = None
        torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)
    if inference_settings != loaded_preset:
        with open(f'presets/{inference_settings}.txt', 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    torch.cuda.empty_cache()
    input_text = question
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
    output = eval(f"model.generate(input_ids, {preset}).cuda()")
    reply = tokenizer.decode(output[0], skip_special_tokens=True)
    if model_name.startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)

    return reply
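
# Load the default model at startup, then build the Gradio interface. The
# model and preset dropdowns are populated by globbing the models/,
# torch-dumps/ and presets/ directories.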
model, tokenizer = load_model(model_name)

if model_name.startswith('gpt4chan'):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"

interface = gr.Interface(
    generate_reply,
    inputs=[
        gr.Textbox(value=default_text, lines=15),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
        gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
        gr.Dropdown(choices=list(map(lambda x: x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
        gr.Dropdown(choices=sorted(set(map(lambda x: x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
    ],
    outputs=[
        gr.Textbox(placeholder="", lines=15),
    ],
    title="Text generation lab",
    description="Generate text using Large Language Models.",
)

interface.launch(share=False, server_name="0.0.0.0")
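
# Run with `python server.py`; binding to 0.0.0.0 serves the UI on all
# interfaces, on Gradio's default port (7860 unless overridden, e.g. via the
# GRADIO_SERVER_PORT environment variable).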