Parcourir la source

Don't show .pt models in the list

oobabooga il y a 2 ans
Parent
commit
9849aac0f1
2 fichiers modifiés avec 4 ajouts et 1 suppressions
  1. 3 0
      modules/models.py
  2. 1 1
      server.py

+ 3 - 0
modules/models.py

@@ -105,6 +105,9 @@ def load_model(model_name):
         if not Path(f"models/{pt_model}").exists():
         if not Path(f"models/{pt_model}").exists():
             print(f"Could not find models/{pt_model}, exiting...")
             print(f"Could not find models/{pt_model}, exiting...")
             exit()
             exit()
+        elif pt_model == '':
+            print(f"Could not find the .pt model for {model_name}, exiting...")
+            exit()
 
 
         model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
         model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
         model = model.to(torch.device('cuda:0'))
         model = model.to(torch.device('cuda:0'))

+ 1 - 1
server.py

@@ -37,7 +37,7 @@ def get_available_models():
     if shared.args.flexgen:
     if shared.args.flexgen:
         return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
         return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
     else:
-        return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np'))], key=str.lower)
+        return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
 
 
 def get_available_presets():
 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)