@@ -178,7 +178,7 @@ def load_model(model_name):
# DeepSpeed ZeRO-3
elif args.deepspeed:
- model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}", no_split_module_classes=["GPTJBlock"]))
+ model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"))
model = deepspeed.initialize(model=model,
config_params=ds_config,
model_parameters=None,