Ver código fonte

Merge branch 'oobabooga:main' into ds

81300 3 anos atrás
pai
commit
248ec4fa21
1 arquivos alterados com 2 adições e 2 exclusões
  1. 2 2
      server.py

+ 2 - 2
server.py

@@ -200,7 +200,7 @@ def load_model(model_name):
     # Custom
     else:
         command = "AutoModelForCausalLM.from_pretrained"
-        settings = []
+        settings = ["low_cpu_mem_usage=True"]
 
         if args.cpu:
             settings.append("low_cpu_mem_usage=True")
@@ -211,7 +211,7 @@ def load_model(model_name):
 
             if args.gpu_memory:
                 settings.append(f"max_memory={{0: '{args.gpu_memory or '99'}GiB', 'cpu': '{args.cpu_memory or '99'}GiB'}}")
-            elif not args.load_in_8bit:
+            elif (args.gpu_memory or args.cpu_memory) and not args.load_in_8bit:
                 total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
                 suggestion = round((total_mem-1000)/1000)*1000
                 if total_mem-suggestion < 800: