|
|
@@ -77,7 +77,7 @@ def load_model(model_name):
|
|
|
t0 = time.time()
|
|
|
|
|
|
# Default settings
|
|
|
- if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None):
|
|
|
+ if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None):
|
|
|
if Path(f"torch-dumps/{model_name}.pt").exists():
|
|
|
print("Loading in .pt format...")
|
|
|
model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
|