Commit 36770b2

awaelchli and carmocca authored
validate manual optimization and supported features before running training (#7788)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 0bad218 commit 36770b2
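
To illustrate the change from a user's perspective: a LightningModule that sets `automatic_optimization = False` combined with `Trainer(gradient_clip_val=...)` or `Trainer(accumulate_grad_batches=...)` now fails at `trainer.fit` with a `MisconfigurationException` before any training runs. The sketch below is illustrative only; the model, dataset, and hyperparameters are not taken from this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    """Minimal module that opts into manual optimization."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch[0]).sum()
        self.manual_backward(loss)
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# `gradient_clip_val` is an automatic-optimization feature; with this commit the
# mismatch is rejected before training starts.
trainer = pl.Trainer(gradient_clip_val=0.5, max_epochs=1)
trainer.fit(ManualOptModel(), train_data)  # raises MisconfigurationException

Dropping the offending argument, or leaving `automatic_optimization` at its default of True, avoids the error.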

File tree

5 files changed: +34 -6 lines changed

CHANGELOG.md
pytorch_lightning/trainer/configuration_validator.py
tests/core/test_lightning_optimizer.py
tests/trainer/optimization/test_manual_optimization.py
tests/trainer/test_config_validator.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -119,6 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `teardown()` in `Accelerator` to allow `training_type_plugin` to customize `teardown` logic ([#7579](https://github.com/PyTorchLightning/pytorch-lightning/pull/7579))
 
 
+- `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))
+
+
 ### Deprecated
 
 
pytorch_lightning/trainer/configuration_validator.py

Lines changed: 17 additions & 0 deletions
@@ -34,6 +34,7 @@ def verify_loop_configurations(self, model: 'pl.LightningModule') -> None:
         if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
             self.__verify_train_loop_configuration(model)
             self.__verify_eval_loop_configuration(model, 'val')
+            self.__verify_manual_optimization_support(model)
         elif self.trainer.state.fn == TrainerFn.VALIDATING:
             self.__verify_eval_loop_configuration(model, 'val')
         elif self.trainer.state.fn == TrainerFn.TESTING:
@@ -112,3 +113,19 @@ def __verify_dp_batch_transfer_support(self, model: 'pl.LightningModule') -> None:
         for hook in batch_transfer_hooks:
             if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
                 raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')
+
+    def __verify_manual_optimization_support(self, model: 'pl.LightningModule') -> None:
+        if model.automatic_optimization:
+            return
+        if self.trainer.gradient_clip_val > 0:
+            raise MisconfigurationException(
+                f"Automatic gradient clipping is not supported for manual optimization."
+                f" Remove `Trainer(gradient_clip_val={self.trainer.gradient_clip_val})`"
+                f" or switch to automatic optimization."
+            )
+        if self.trainer.accumulate_grad_batches != 1:
+            raise MisconfigurationException(
+                f"Automatic gradient accumulation is not supported for manual optimization."
+                f" Remove `Trainer(accumulate_grad_batches={self.trainer.accumulate_grad_batches})`"
+                f" or switch to automatic optimization."
+            )
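
The new check only fires when the model sets `automatic_optimization = False`, and only for values that differ from the Trainer defaults (`gradient_clip_val=0`, `accumulate_grad_batches=1`). Under manual optimization, clipping and accumulation can instead be done by hand inside `training_step`. A minimal sketch of what that can look like follows; the class name, `accumulate_every`, the batch format, and the clipping threshold are assumptions for illustration, not part of this commit:

import torch
import pytorch_lightning as pl


class ManualAccumulateClipModel(pl.LightningModule):
    """Sketch: clip and accumulate by hand when automatic optimization is off."""

    def __init__(self, accumulate_every: int = 2):
        super().__init__()
        self.automatic_optimization = False
        self.accumulate_every = accumulate_every
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        # scale the loss so the accumulated gradients match a larger effective batch
        loss = self.layer(batch[0]).sum() / self.accumulate_every
        self.manual_backward(loss)

        # step (and clip) only every `accumulate_every` batches
        if (batch_idx + 1) % self.accumulate_every == 0:
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=0.5)
            opt.step()
            opt.zero_grad()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

With the logic in the step itself, the Trainer is constructed without `gradient_clip_val` or `accumulate_grad_batches`, so the new validation passes.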

tests/core/test_lightning_optimizer.py

Lines changed: 0 additions & 1 deletion
@@ -123,7 +123,6 @@ def configure_optimizers(self):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
-        accumulate_grad_batches=999,  # does not do anything if manual optimization
     )
 
     with patch.multiple(torch.optim.SGD, zero_grad=DEFAULT, step=DEFAULT) as sgd, \

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 0 additions & 5 deletions
@@ -424,7 +424,6 @@ def on_train_epoch_end(self, *_, **__):
         limit_val_batches=0,
         precision=16,
         amp_backend='native',
-        accumulate_grad_batches=4,
         gpus=1,
     )
     trainer.fit(model)
@@ -631,7 +630,6 @@ def configure_optimizers(self):
         limit_val_batches=2,
         max_epochs=1,
         log_every_n_steps=1,
-        accumulate_grad_batches=2,
     )
 
     trainer.fit(model)
@@ -682,7 +680,6 @@ def configure_optimizers(self):
         limit_val_batches=2,
         max_epochs=1,
         log_every_n_steps=1,
-        accumulate_grad_batches=2,
     )
 
     trainer.fit(model)
@@ -757,7 +754,6 @@ def configure_optimizers(self):
         limit_val_batches=2,
         max_epochs=1,
         log_every_n_steps=1,
-        accumulate_grad_batches=2,
     )
 
     trainer.fit(model)
@@ -867,7 +863,6 @@ def train_manual_optimization(tmpdir, accelerator, model_cls=TesManualOptimizati
         limit_val_batches=2,
         max_epochs=1,
         log_every_n_steps=1,
-        accumulate_grad_batches=2,
         gpus=2,
         accelerator=accelerator,
         callbacks=[TestManualOptimizationDDPCallack()]

tests/trainer/test_config_validator.py

Lines changed: 14 additions & 0 deletions
@@ -147,3 +147,17 @@ def predict_dataloader(self):
 
     with pytest.raises(MisconfigurationException, match="Dataloader not found for `Trainer.predict`"):
         trainer.predict(model)
+
+
+def test_trainer_manual_optimization_config(tmpdir):
+    """ Test error message when requesting Trainer features unsupported with manual optimization """
+    model = BoringModel()
+    model.automatic_optimization = False
+
+    trainer = Trainer(gradient_clip_val=1.0)
+    with pytest.raises(MisconfigurationException, match="Automatic gradient clipping is not supported"):
+        trainer.fit(model)
+
+    trainer = Trainer(accumulate_grad_batches=2)
+    with pytest.raises(MisconfigurationException, match="Automatic gradient accumulation is not supported"):
+        trainer.fit(model)
