# server.py
  1. import os
  2. import re
  3. import time
  4. import glob
  5. import torch
  6. import gradio as gr
  7. import transformers
  8. from transformers import AutoTokenizer
  9. from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
# --- Model selection --------------------------------------------------------
# Exactly one checkpoint is active at startup; the commented lines are the
# other checkpoints this script has been run with (each must exist under
# models/<name> or torch-dumps/<name>.pt). The UI dropdown can switch models
# at runtime via fn().
#model_name = "bloomz-7b1-p3"
#model_name = 'gpt-j-6B-float16'
#model_name = "opt-6.7b"
#model_name = 'opt-13b'
#model_name = "gpt4chan_model_float16"
model_name = 'galactica-6.7b'
#model_name = 'gpt-neox-20b'
#model_name = 'flan-t5'
#model_name = 'OPT-13B-Erebus'
# Name of the last generation-settings preset read from presets/<name>.txt.
# None forces the first fn() call to load whichever preset the UI selects.
loaded_preset = None
  20. def load_model(model_name):
  21. print(f"Loading {model_name}...")
  22. t0 = time.time()
  23. if os.path.exists(f"torch-dumps/{model_name}.pt"):
  24. print("Loading in .pt format...")
  25. model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
  26. elif model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
  27. model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
  28. elif model_name in ['gpt-j-6B']:
  29. model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
  30. elif model_name in ['flan-t5', 't5-large']:
  31. model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
  32. if model_name in ['gpt4chan_model_float16']:
  33. tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
  34. elif model_name in ['flan-t5']:
  35. tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
  36. else:
  37. tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")
  38. print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
  39. return model, tokenizer
  40. def fix_gpt4chan(s):
  41. for i in range(10):
  42. s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
  43. s = re.sub("--- [0-9]*\n *\n---", "---", s)
  44. s = re.sub("--- [0-9]*\n\n\n---", "---", s)
  45. return s
  46. def fn(question, temperature, max_length, inference_settings, selected_model):
  47. global model, tokenizer, model_name, loaded_preset, preset
  48. if selected_model != model_name:
  49. model_name = selected_model
  50. model = None
  51. tokenier = None
  52. torch.cuda.empty_cache()
  53. model, tokenizer = load_model(model_name)
  54. if inference_settings != loaded_preset:
  55. with open(f'presets/{inference_settings}.txt', 'r') as infile:
  56. preset = infile.read()
  57. loaded_preset = inference_settings
  58. torch.cuda.empty_cache()
  59. input_text = question
  60. input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
  61. output = eval(f"model.generate(input_ids, {preset}).cuda()")
  62. reply = tokenizer.decode(output[0], skip_special_tokens=True)
  63. if model_name.startswith('gpt4chan'):
  64. reply = fix_gpt4chan(reply)
  65. return reply
  66. model, tokenizer = load_model(model_name)
  67. if model_name.startswith('gpt4chan'):
  68. default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
  69. else:
  70. default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
  71. interface = gr.Interface(
  72. fn,
  73. inputs=[
  74. gr.Textbox(value=default_text, lines=15),
  75. gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
  76. gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
  77. gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
  78. gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
  79. ],
  80. outputs=[
  81. gr.Textbox(placeholder="", lines=15),
  82. ],
  83. title="Text generation lab",
  84. description=f"Generate text using Large Language Models. Currently working with {model_name}",
  85. )
  86. interface.launch(share=False, server_name="0.0.0.0")