소스 검색

Minor rewrite

oobabooga 2년 전
부모
커밋
3d6cb5ed63
1개의 변경된 파일에 4개의 추가와 6개의 삭제
  1. 4 6
      modules/GPTQ_loader.py

+ 4 - 6
modules/GPTQ_loader.py

@@ -65,13 +65,11 @@ def load_quantized(model_name):
     else:
         model_type = shared.args.model_type.lower()
 
-    if shared.args.pre_layer:
-        if model_type == 'llama':
-            load_quant = llama_inference_offload.load_quant
-        else:
-            print("Warning: ignoring --pre_layer because it only works for llama model type.")
-            load_quant = _load_quant
+    if shared.args.pre_layer and model_type == 'llama':
+        load_quant = llama_inference_offload.load_quant
     elif model_type in ('llama', 'opt', 'gptj'):
+        if shared.args.pre_layer:
+            print("Warning: ignoring --pre_layer because it only works for llama model type.")
         load_quant = _load_quant
     else:
         print("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")