
Make FlexGen work with the newest API

oobabooga 2 years ago
Parent
Commit
8e3e8a070f
1 changed file with 5 additions and 10 deletions

modules/models.py  +5 -10

@@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
 local_rank = None
 
 if shared.args.flexgen:
-    from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
-                                  TorchDevice, TorchDisk, TorchMixedDevice,
-                                  get_opt_config)
+    from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
+                                  Policy, str2bool)
 
 if shared.args.deepspeed:
     import deepspeed
@@ -48,10 +47,8 @@ def load_model(model_name):
 
     # FlexGen
     elif shared.args.flexgen:
-        gpu = TorchDevice("cuda:0")
-        cpu = TorchDevice("cpu")
-        disk = TorchDisk(shared.args.disk_cache_dir)
-        env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
+        # Initialize environment
+        env = ExecutionEnv.create(shared.args.disk_cache_dir)
 
         # Offloading policy
         policy = Policy(1, 1,
@@ -69,9 +66,7 @@ def load_model(model_name):
                             num_bits=4, group_size=64,
                             group_dim=2, symmetric=False))
 
-        opt_config = get_opt_config(f"facebook/{shared.model_name}")
-        model = OptLM(opt_config, env, "models", policy)
-        model.init_all_weights()
+        model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
 
     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
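
For context, a minimal sketch of how the FlexGen branch of load_model() fits together after this commit. Only ExecutionEnv.create(), the string-based OptLM constructor, and the CompressionConfig arguments are visible in the diff above; the Policy keyword names follow the flexgen.flex_opt.Policy dataclass of the FlexGen release this commit targets, and the placement percentages, compress_weight value, comp_weight_config settings, and the load_flexgen_model() wrapper are illustrative placeholders rather than the webui's actual values.

# Sketch only: anything marked "illustrative" below is an assumption,
# not a value taken from modules/models.py.
from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
                              Policy)


def load_flexgen_model(model_name, disk_cache_dir="cache"):
    # ExecutionEnv.create() replaces the old manual TorchDevice / TorchDisk /
    # TorchMixedDevice setup: it builds the GPU, CPU, and disk devices from
    # the given offload directory.
    env = ExecutionEnv.create(disk_cache_dir)

    # Offloading policy: batch size 1, one GPU batch. The six placement
    # percentages (weights, KV cache, activations split across GPU/CPU) are
    # illustrative; the webui fills them in from its command-line arguments.
    policy = Policy(1, 1,
                    100, 0, 100, 0, 100, 0,
                    overlap=True, sep_layer=True, pin_weight=True,
                    cpu_cache_compute=False, attn_sparsity=1.0,
                    compress_weight=False,
                    comp_weight_config=CompressionConfig(
                        num_bits=4, group_size=64,
                        group_dim=0, symmetric=False),
                    compress_cache=False,
                    comp_cache_config=CompressionConfig(
                        num_bits=4, group_size=64,
                        group_dim=2, symmetric=False))

    # The new OptLM constructor accepts the model name directly, so the old
    # get_opt_config() + init_all_weights() steps are no longer needed.
    return OptLM(f"facebook/{model_name}", env, "models", policy)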