|
@@ -76,7 +76,7 @@ def load_model(model_name):
|
|
|
num_bits=4, group_size=64,
|
|
num_bits=4, group_size=64,
|
|
|
group_dim=2, symmetric=False))
|
|
group_dim=2, symmetric=False))
|
|
|
|
|
|
|
|
- model = OptLM(f"facebook/{shared.model_name}", env, shared.model_name, policy)
|
|
|
|
|
|
|
+ model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)
|
|
|
|
|
|
|
|
# DeepSpeed ZeRO-3
|
|
# DeepSpeed ZeRO-3
|
|
|
elif shared.args.deepspeed:
|
|
elif shared.args.deepspeed:
|