Skip to content

Commit a272e41

Browse files
pacman100sgugger
andcommitted
fix bugs with trainer (#24134)
* fix the deepspeed test failures * apex fix * FSDP save ckpt fix * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <[email protected]> --------- Co-authored-by: Sylvain Gugger <[email protected]>
1 parent 50ed793 commit a272e41

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

src/transformers/trainer.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1749,7 +1749,16 @@ def _inner_training_loop(
17491749

17501750
# prepare using `accelerator` prepare
17511751
if use_accelerator_prepare:
1752-
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
1752+
if hasattr(self.lr_scheduler, "step"):
1753+
if self.use_apex:
1754+
model = self.accelerator.prepare(self.model)
1755+
else:
1756+
model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
1757+
else:
1758+
# to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
1759+
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
1760+
self.model, self.optimizer, self.lr_scheduler
1761+
)
17531762

17541763
if self.is_fsdp_enabled:
17551764
self.model = model
@@ -2841,6 +2850,7 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
28412850
or self.is_fsdp_enabled
28422851
):
28432852
if self.is_fsdp_enabled:
2853+
os.makedirs(output_dir, exist_ok=True)
28442854
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
28452855
else:
28462856
state_dict = self.model.state_dict()

0 commit comments

Comments
 (0)