# server.py — minimal Gradio front-end for local causal / seq2seq language models.
  1. import re
  2. import time
  3. import glob
  4. import torch
  5. import gradio as gr
  6. import transformers
  7. from transformers import AutoTokenizer
  8. from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
  9. #model_name = "bloomz-7b1-p3"
  10. #model_name = 'gpt-j-6B-float16'
  11. #model_name = "opt-6.7b"
  12. #model_name = 'opt-13b'
  13. #model_name = "gpt4chan_model_float16"
  14. model_name = 'galactica-6.7b'
  15. #model_name = 'gpt-neox-20b'
  16. #model_name = 'flan-t5'
  17. #model_name = 'OPT-13B-Erebus'
  18. settings_name = "Default"
  19. def load_model(model_name):
  20. print(f"Loading {model_name}")
  21. t0 = time.time()
  22. if model_name in ['gpt-neox-20b', 'opt-13b', 'OPT-13B-Erebus']:
  23. model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", device_map='auto', load_in_8bit=True)
  24. elif model_name in ['gpt-j-6B']:
  25. model = AutoModelForCausalLM.from_pretrained(f"models/{model_name}", low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
  26. elif model_name in ['flan-t5']:
  27. model = T5ForConditionalGeneration.from_pretrained(f"models/{model_name}").cuda()
  28. else:
  29. model = torch.load(f"torch-dumps/{model_name}.pt").cuda()
  30. if model_name in ['gpt4chan_model_float16']:
  31. tokenizer = AutoTokenizer.from_pretrained("models/gpt-j-6B/")
  32. elif model_name in ['flan-t5']:
  33. tokenizer = T5Tokenizer.from_pretrained(f"models/{model_name}/")
  34. else:
  35. tokenizer = AutoTokenizer.from_pretrained(f"models/{model_name}/")
  36. print(f"Loaded the model in {time.time()-t0} seconds.")
  37. return model, tokenizer
  38. def fix_gpt4chan(s):
  39. for i in range(10):
  40. s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
  41. s = re.sub("--- [0-9]*\n *\n---", "---", s)
  42. s = re.sub("--- [0-9]*\n\n\n---", "---", s)
  43. return s
  44. def fn(question, temperature, max_length, inference_settings, selected_model):
  45. global model, tokenizer, model_name, settings_name
  46. if selected_model != model_name:
  47. model_name = selected_model
  48. model = None
  49. tokenier = None
  50. torch.cuda.empty_cache()
  51. model, tokenizer = load_model(model_name)
  52. if inference_settings != settings_name:
  53. with open(f'presets/{inference_settings}.txt', 'r') as infile:
  54. preset = infile.read()
  55. settings_name = inference_settings
  56. torch.cuda.empty_cache()
  57. input_text = question
  58. input_ids = tokenizer.encode(str(input_text), return_tensors='pt').cuda()
  59. output = eval(f"model.generate(input_ids, {preset}).cuda()")
  60. reply = tokenizer.decode(output[0], skip_special_tokens=True)
  61. if model_name.startswith('gpt4chan'):
  62. reply = fix_gpt4chan(reply)
  63. return reply
  64. model, tokenizer = load_model(model_name)
  65. if model_name.startswith('gpt4chan'):
  66. default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
  67. else:
  68. default_text = "Common sense questions and answers\n\nQuestion: \nFactual answer:"
  69. interface = gr.Interface(
  70. fn,
  71. inputs=[
  72. gr.Textbox(value=default_text, lines=15),
  73. gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
  74. gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
  75. gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
  76. gr.Dropdown(choices=["gpt4chan_model_float16", "galactica-6.7b", "opt-6.7b", "opt-13b", "gpt-neox-20b", "gpt-j-6B-float16", "flan-t5", "bloomz-7b1-p3", "OPT-13B-Erebus"], value=model_name),
  77. ],
  78. outputs=[
  79. gr.Textbox(placeholder="", lines=15),
  80. ],
  81. title="Text generation lab",
  82. description=f"Generate text using Large Language Models. Currently working with {model_name}",
  83. )
  84. interface.launch(share=False, server_name="0.0.0.0")