|
|
@@ -16,9 +16,8 @@ transformers.logging.set_verbosity_error()
|
|
|
local_rank = None
|
|
|
|
|
|
if shared.args.flexgen:
|
|
|
- from flexgen.flex_opt import (CompressionConfig, Env, OptLM, Policy,
|
|
|
- TorchDevice, TorchDisk, TorchMixedDevice,
|
|
|
- get_opt_config)
|
|
|
+ from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
|
|
|
+ Policy, str2bool)
|
|
|
|
|
|
if shared.args.deepspeed:
|
|
|
import deepspeed
|
|
|
@@ -48,10 +47,8 @@ def load_model(model_name):
|
|
|
|
|
|
# FlexGen
|
|
|
elif shared.args.flexgen:
|
|
|
- gpu = TorchDevice("cuda:0")
|
|
|
- cpu = TorchDevice("cpu")
|
|
|
- disk = TorchDisk(shared.args.disk_cache_dir)
|
|
|
- env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
|
|
|
+ # Initialize environment
|
|
|
+ env = ExecutionEnv.create(shared.args.disk_cache_dir)
|
|
|
|
|
|
# Offloading policy
|
|
|
policy = Policy(1, 1,
|
|
|
@@ -69,9 +66,7 @@ def load_model(model_name):
|
|
|
num_bits=4, group_size=64,
|
|
|
group_dim=2, symmetric=False))
|
|
|
|
|
|
- opt_config = get_opt_config(f"facebook/{shared.model_name}")
|
|
|
- model = OptLM(opt_config, env, "models", policy)
|
|
|
- model.init_all_weights()
|
|
|
+ model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
|
|
|
|
|
|
# DeepSpeed ZeRO-3
|
|
|
elif shared.args.deepspeed:
|