oobabooga пре 2 година
родитељ
комит
3d6cb5ed63
1 измењених фајлова са 4 додато и 6 уклоњено
  1. 4 6
      modules/GPTQ_loader.py

+ 4 - 6
modules/GPTQ_loader.py

@@ -65,13 +65,11 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if shared.args.pre_layer:
-        if model_type == 'llama':
-            load_quant = llama_inference_offload.load_quant
-        else:
-            print("Warning: ignoring --pre_layer because it only works for llama model type.")
-            load_quant = _load_quant
+    if shared.args.pre_layer and model_type == 'llama':
+        load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
+        if shared.args.pre_layer:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
         load_quant = _load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")