server.py

import os
import re
import time
import glob
import torch
import gradio as gr
import transformers
from transformers import AutoTokenizer
from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel

#model_name = "bloomz-7b1-p3"
#model_name = 'gpt-j-6B-float16'
#model_name = "opt-6.7b"
#model_name = 'opt-13b'
model_name = "gpt4chan_model_float16"
#model_name = 'galactica-6.7b'
#model_name = 'gpt-neox-20b'
#model_name = 'flan-t5'
#model_name = 'OPT-13B-Erebus'

loaded_preset = None
def load_model(model_name):
    print(f"Loading {model_name}...")
    t0 = time.time()

    # Loading the model
    if os.path.exists(f"torch-dumps/{model_name}.pt"):
        print("Loading in .pt format...")
        model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
    elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')):
        if any(size in model_name for size in ('13b', '20b', '30b')):
            model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
        else:
            model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    elif model_name in ['gpt-j-6B']:
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
    elif model_name in ['flan-t5', 't5-large']:
        model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
    else:
        model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()

    # Loading the tokenizer
    if model_name.startswith('gpt4chan'):
        tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
    elif model_name in ['flan-t5']:
        tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
    else:
        tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")

    print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
    return model, tokenizer
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    for i in range(10):
        s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
        s = re.sub("--- [0-9]*\n *\n---", "---", s)
        s = re.sub("--- [0-9]*\n\n\n---", "---", s)
    return s
def generate_reply(question, temperature, max_length, inference_settings, selected_model):
    global model, tokenizer, model_name, loaded_preset, preset

    # Reload the model if a different one was selected in the UI
    if selected_model != model_name:
        model_name = selected_model
        model = None
        tokenizer = None
        torch.cuda.empty_cache()
        model, tokenizer = load_model(model_name)

    # Reload the generation parameters if a different preset was selected
    if inference_settings != loaded_preset:
        with open(f'presets/{inference_settings}.txt', 'r') as infile:
            preset = infile.read()
        loaded_preset = inference_settings

    torch.cuda.empty_cache()
    input_text = question
    input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
    output = eval(f"model.generate(input_ids, {preset}).cuda()")
    reply = tokenizer.decode(output[0], skip_special_tokens=True)
    if model_name.startswith('gpt4chan'):
        reply = fix_gpt4chan(reply)
    return reply
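# Note on presets: `preset` is spliced into the model.generate() call through eval(),
# so each presets/<name>.txt file is expected to hold a comma-separated list of
# generate() keyword arguments. A minimal, hypothetical presets/Default.txt (the exact
# contents of the shipped presets are an assumption, not part of this file) could be:
#
#   do_sample=True,
#   temperature=temperature,
#   max_length=max_length,
#   top_p=1,
#   top_k=50,
#   repetition_penalty=1.05
#
# Names like `temperature` and `max_length` resolve against the parameters of
# generate_reply() because the string is eval'd inside that function's scope.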
model, tokenizer = load_model(model_name)

if model_name.startswith('gpt4chan'):
    default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
else:
    default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"

interface = gr.Interface(
    generate_reply,
    inputs=[
        gr.Textbox(value=default_text, lines=15),
        gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
        gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
        gr.Dropdown(choices=list(map(lambda x: x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
        gr.Dropdown(choices=sorted(set(map(lambda x: x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
    ],
    outputs=[
        gr.Textbox(placeholder="", lines=15),
    ],
    title="Text generation lab",
    description="Generate text using Large Language Models.",
)
interface.launch(share=False, server_name="0.0.0.0")
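# Expected directory layout, inferred from the paths used above (an assumption about
# the surrounding repository, not something this script creates):
#
#   models/<model_name>/         Hugging Face checkpoints loaded with from_pretrained()
#   torch-dumps/<model_name>.pt  whole models saved with torch.save() and loaded with torch.load()
#   presets/<name>.txt           comma-separated generate() keyword arguments
#
# With share=False and server_name="0.0.0.0", the Gradio app gets no public share link
# but is reachable from other machines on the local network on Gradio's default port (7860).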