Explorar el Código

Load default model with --model flag

oobabooga hace 3 años
padre
commit
f54a13929f
Se ha modificado 1 fichero con 24 adiciones y 11 borrados
  1. 24 11
      server.py

+ 24 - 11
server.py

@@ -2,23 +2,19 @@ import os
 import re
 import time
 import glob
+from sys import exit
 import torch
+import argparse
 import gradio as gr
 import transformers
 from transformers import AutoTokenizer
 from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
 
-#model_name = "bloomz-7b1-p3"
-#model_name = 'gpt-j-6B-float16'
-#model_name = "opt-6.7b"
-#model_name = 'opt-13b'
-model_name = "gpt4chan_model_float16"
-#model_name = 'galactica-6.7b'
-#model_name = 'gpt-neox-20b'
-#model_name = 'flan-t5'
-#model_name = 'OPT-13B-Erebus'
-
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, help='Name of the model to load by default')
+args = parser.parse_args()
 loaded_preset = None
+available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -85,7 +81,24 @@ def generate_reply(question, temperature, max_length, inference_settings, select
 
 
     return reply
 
+# Choosing the default model
+if args.model is not None:
+    model_name = args.model
+else:
+    if len(available_models) == 0:
+        print("No models are available! Please download at least one.")
+        exit(0)
+    elif len(available_models) == 1:
+        i = 0
+    else:
+        print("The following models are available:\n")
+        for i,model in enumerate(available_models):
+            print(f"{i+1}. {model}")
+        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
+        i = int(input())-1
+    model_name = available_models[i]
 model, tokenizer = load_model(model_name)
+
 if model_name.startswith('gpt4chan'):
     default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
 else:
@@ -98,7 +111,7 @@ interface = gr.Interface(
         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
         gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
         gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
-        gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
+        gr.Dropdown(choices=available_models, value=model_name),
     ],
     outputs=[
          gr.Textbox(placeholder="", lines=15),