
Load default model with --model flag

oobabooga 3 years ago
commit f54a13929f
1 changed file with 24 additions and 11 deletions
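
This commit replaces the hard-coded model_name assignment in server.py with a --model command-line flag, so the default model can be chosen at launch instead of by editing the source. Assuming the argparse definition shown in the diff, a launch would look like this (the model name is illustrative, taken from one of the removed comment lines):

    python server.py --model gpt-j-6B-float16

When --model is omitted, the fallback logic added in the second hunk scans models/ and torch-dumps/: if exactly one model is found it is loaded automatically, and if several are found the script prints a numbered list and prompts for a choice.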

server.py  +24 -11

@@ -2,23 +2,19 @@ import os
 import re
 import time
 import glob
+from sys import exit
 import torch
+import argparse
 import gradio as gr
 import transformers
 from transformers import AutoTokenizer
 from transformers import GPTJForCausalLM, AutoModelForCausalLM, AutoModelForSeq2SeqLM, OPTForCausalLM, T5Tokenizer, T5ForConditionalGeneration, GPTJModel, AutoModel
 
-#model_name = "bloomz-7b1-p3"
-#model_name = 'gpt-j-6B-float16'
-#model_name = "opt-6.7b"
-#model_name = 'opt-13b'
-model_name = "gpt4chan_model_float16"
-#model_name = 'galactica-6.7b'
-#model_name = 'gpt-neox-20b'
-#model_name = 'flan-t5'
-#model_name = 'OPT-13B-Erebus'
-
+parser = argparse.ArgumentParser()
+parser.add_argument('--model', type=str, help='Name of the model to load by default')
+args = parser.parse_args()
 loaded_preset = None
+available_models = sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*[!\.][!t][!x][!t]")+ glob.glob("torch-dumps/*[!\.][!t][!x][!t]"))))
 
 def load_model(model_name):
     print(f"Loading {model_name}...")
@@ -85,7 +81,24 @@ def generate_reply(question, temperature, max_length, inference_settings, select
 
     return reply
 
+# Choosing the default model
+if args.model is not None:
+    model_name = args.model
+else:
+    if len(available_models) == 0:
+        print("No models are available! Please download at least one.")
+        exit(0)
+    elif len(available_models) == 1:
+        i = 0
+    else:
+        print("The following models are available:\n")
+        for i,model in enumerate(available_models):
+            print(f"{i+1}. {model}")
+        print(f"\nWhich one do you want to load? 1-{len(available_models)}\n")
+        i = int(input())-1
+    model_name = available_models[i]
 model, tokenizer = load_model(model_name)
+
 if model_name.startswith('gpt4chan'):
     default_text = "-----\n--- 865467536\nInput text\n--- 865467537\n"
 else:
@@ -98,7 +111,7 @@ interface = gr.Interface(
         gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Temperature', value=0.7),
         gr.Slider(minimum=1, maximum=2000, step=1, label='max_length', value=200),
         gr.Dropdown(choices=list(map(lambda x : x.split('/')[-1].split('.')[0], glob.glob("presets/*.txt"))), value="Default"),
-        gr.Dropdown(choices=sorted(set(map(lambda x : x.split('/')[-1].replace('.pt', ''), glob.glob("models/*") + glob.glob("torch-dumps/*")))), value=model_name),
+        gr.Dropdown(choices=available_models, value=model_name),
     ],
     outputs=[
         gr.Textbox(placeholder="", lines=15),
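
A note on the new available_models line: glob patterns have no "does not end with" operator, so models/*[!\.][!t][!x][!t] approximates "skip .txt files" by keeping a name only if its last four characters avoid '.', 't', 'x', and 't' position by position. That also drops names shorter than four characters and any model whose name simply ends in 't'. A minimal sketch of the presumable intent, filtering explicitly instead (this helper is not part of the commit, and the name list_available_models is made up for illustration):

    import glob

    def list_available_models():
        # Gather entries from both locations, strip the directory prefix
        # and a trailing '.pt', and skip .txt files explicitly.
        paths = glob.glob("models/*") + glob.glob("torch-dumps/*")
        names = (p.split('/')[-1].replace('.pt', '')
                 for p in paths if not p.endswith('.txt'))
        return sorted(set(names))

With an explicit endswith check the intent is readable, and a model whose name happens to end in 't' would no longer be filtered out by accident.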