Quellcode durchsuchen

Add progress bar for model loading

oobabooga vor 3 Jahren
Ursprung
Commit
7ef7bba6e6
1 geänderte Dateien mit 3 neuen und 1 gelöschten Zeilen
  1. 3 1
      server.py

+ 3 - 1
server.py

@@ -168,6 +168,8 @@ def load_model_wrapper(selected_model):
             torch.cuda.empty_cache()
         model, tokenizer = load_model(model_name)
 
+    return selected_model
+
 def load_preset_values(preset_menu, return_dict=False):
     generate_params = {
         'do_sample': True,
@@ -408,7 +410,7 @@ def create_settings_menus():
                 min_length = gr.Slider(0, 2000, step=1, value=generate_params["min_length"] if args.no_stream else 0, label="min_length", interactive=args.no_stream)
                 early_stopping = gr.Checkbox(value=generate_params["early_stopping"], label="early_stopping")
 
-    model_menu.change(load_model_wrapper, [model_menu], [])
+    model_menu.change(load_model_wrapper, [model_menu], [model_menu], show_progress=True)
     preset_menu.change(load_preset_values, [preset_menu], [do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping])
     return preset_menu, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping