
Commit 38a5fe7

rohitgr7 authored
Remove optimizer_idx arg in manual optimization (#6093)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: chaton <[email protected]>
1 parent 966184a commit 38a5fe7

File tree

10 files changed (+89, -38 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
 
 
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
 ### Fixed
 
 - Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))

README.md

Lines changed: 2 additions & 2 deletions
@@ -318,9 +318,9 @@ class LitAutoEncoder(pl.LightningModule):
         super().__init__()
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
-        (opt_a, opt_b) = self.optimizers(use_pl_optimizer=True)
+        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)
 
         loss_a = ...
         self.manual_backward(loss_a, opt_a)
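
For context (not part of this commit), a minimal sketch of a full LightningModule using the new manual-optimization signature shown above, assuming the 1.2-era API; the class name, layers, and losses are illustrative:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class TwoOptimizerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # opt in to manual optimization; optimizer_idx is no longer passed to training_step
        self.automatic_optimization = False
        self.net_a = torch.nn.Linear(32, 2)
        self.net_b = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt_a, opt_b = self.optimizers()  # one entry per optimizer returned by configure_optimizers

        # update the first network
        loss_a = F.cross_entropy(self.net_a(x), y)
        opt_a.zero_grad()
        self.manual_backward(loss_a, opt_a)
        opt_a.step()

        # update the second network
        loss_b = F.cross_entropy(self.net_b(x), y)
        opt_b.zero_grad()
        self.manual_backward(loss_b, opt_b)
        opt_b.step()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.net_a.parameters(), lr=0.1),
            torch.optim.SGD(self.net_b.parameters(), lr=0.1),
        )
```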

docs/source/common/lightning_module.rst

Lines changed: 1 addition & 3 deletions
@@ -952,14 +952,12 @@ When set to ``False``, Lightning does not automate the optimization process. Thi
 
 This is recommended only if using 2+ optimizers AND if you know how to perform the optimization procedure properly. Note that automatic optimization can still be used with multiple optimizers by relying on the ``optimizer_idx`` parameter. Manual optimization is most useful for research topics like reinforcement learning, sparse coding, and GAN research.
 
-In the multi-optimizer case, ignore the ``optimizer_idx`` argument and use the optimizers directly
-
 .. code-block:: python
 
     def __init__(self):
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
         opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

docs/source/common/optimizers.rst

Lines changed: 11 additions & 3 deletions
@@ -51,7 +51,10 @@ to manually manage the optimization process. To do so, do the following:
 
 .. code-block:: python
 
-    def training_step(batch, batch_idx):
+    def __init__(self):
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         loss = self.compute_loss(batch)
@@ -69,7 +72,10 @@ Here is the same example as above using a ``closure``.
 
 .. testcode:: python
 
-    def training_step(batch, batch_idx):
+    def __init__(self):
+        self.automatic_optimization = False
+
+    def training_step(self, batch, batch_idx):
         opt = self.optimizers()
 
         def forward_and_backward():
@@ -126,7 +132,6 @@ Here is the same example as above using a ``closure``.
         # Optimize Discriminator #
         ###########################
         d_opt.zero_grad()
-
         d_x = self.D(X)
         errD_real = self.criterion(d_x, real_label)
 
@@ -179,6 +184,9 @@ Here is an example for advanced use-case.
 
     ...
 
+    def __init__(self):
+        self.automatic_optimization = False
+
     def training_step(self, batch, batch_idx):
         # Implementation follows https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
         g_opt, d_opt = self.optimizers()
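
The optimizers.rst hunks above reference a closure-based variant; a hedged sketch of that pattern, assuming `self.automatic_optimization = False` was set in `__init__` and treating `compute_loss` as a hypothetical helper:

```python
def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        # forward, backward, and loss bookkeeping happen inside the closure
        loss = self.compute_loss(batch)  # hypothetical helper, as in the docs above
        opt.zero_grad()
        self.manual_backward(loss, opt)
        return loss

    # optimizers such as LBFGS invoke the closure themselves, possibly several times
    opt.step(closure=closure)
```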

docs/source/starter/new-project.rst

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ Turn off automatic optimization and you control the train loop!
     def __init__(self):
         self.automatic_optimization = False
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
         # access your optimizers with use_pl_optimizer=False. Default is True
         opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

pytorch_lightning/trainer/training_loop.py

Lines changed: 10 additions & 5 deletions
@@ -310,7 +310,7 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         closure_loss = None
         untouched_loss = None
 
-        if self.trainer.train_loop.automatic_optimization:
+        if self.automatic_optimization:
             # accumulate loss
             # (if accumulate_grad_batches = 1 no effect)
             if is_result_obj:
@@ -840,12 +840,17 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
 
         if len(self.trainer.optimizers) > 1:
             if self.trainer.has_arg("training_step", "optimizer_idx"):
+                if not self.automatic_optimization:
+                    self.warning_cache.warn(
+                        "`training_step` hook signature has changed in v1.3."
+                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
+                        " the old signature will be removed in v1.5", DeprecationWarning
+                    )
                 args.append(opt_idx)
-            else:
-                num_opts = len(self.trainer.optimizers)
+            elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
                 raise ValueError(
-                    f"Your LightningModule defines {num_opts} optimizers but "
-                    f'training_step is missing the "optimizer_idx" argument.'
+                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
+                    ' `training_step` is missing the `optimizer_idx` argument.'
                 )
 
         # pass hiddens if using tbptt
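
`Trainer.has_arg` itself is not touched by this diff; as a rough illustration of how such a signature check can work, a standalone sketch using `inspect` (not the actual Lightning implementation):

```python
import inspect


def has_arg(fn, arg_name: str) -> bool:
    # True when fn's signature declares a parameter with this name
    return arg_name in inspect.signature(fn).parameters


# the new-style manual-optimization hook no longer declares optimizer_idx
def training_step(self, batch, batch_idx):
    ...


assert has_arg(training_step, "batch_idx")
assert not has_arg(training_step, "optimizer_idx")
```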

tests/core/test_lightning_optimizer.py

Lines changed: 2 additions & 1 deletion
@@ -89,8 +89,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx=None):
+        def training_step(self, batch, batch_idx):
             opt_1, opt_2 = self.optimizers()
+
             assert isinstance(opt_1, LightningOptimizer)
             assert isinstance(opt_2, LightningOptimizer)

tests/deprecated_api/test_remove_1-5.py

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,7 @@
 from unittest import mock
 
 import pytest
+from torch import optim
 
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import WandbLogger
@@ -74,3 +75,25 @@ def test_v1_5_0_running_sanity_check():
     trainer = Trainer()
     with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):
         assert not trainer.running_sanity_check
+
+
+def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
+
+    class OldSignatureModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            assert optimizer_idx is not None
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
+
+    model = OldSignatureModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
+        trainer.fit(model)
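
The new test hinges on `pytest.deprecated_call`; a self-contained toy sketch of that mechanism, unrelated to Lightning, with hypothetical names:

```python
import warnings

import pytest


def legacy_hook(batch, batch_idx, optimizer_idx=None):
    # toy stand-in for a hook that still accepts a deprecated argument
    if optimizer_idx is not None:
        warnings.warn("`optimizer_idx` support will be removed in v1.5", DeprecationWarning)
    return batch


def test_legacy_hook_warns():
    # deprecated_call fails unless a DeprecationWarning (or PendingDeprecationWarning)
    # matching the pattern is emitted inside the block
    with pytest.deprecated_call(match="will be removed in v1.5"):
        legacy_hook([1, 2], 0, optimizer_idx=0)
```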

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 17 additions & 17 deletions
@@ -40,9 +40,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -107,9 +107,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -176,9 +176,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -251,9 +251,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             loss_1 = self.step(batch[0])
 
             # make sure there are no grads
@@ -321,9 +321,9 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             x = batch[0]
 
             loss_1 = self(x)
@@ -610,9 +610,9 @@ def on_after_backward(self):
             if not (torch.isinf(norm) or torch.isnan(norm)):
                 assert norm.item() < 100, norm.item()
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             # manual
-            (opt_a, opt_b) = self.optimizers()
+            opt_a, opt_b = self.optimizers()
             x = batch[0]
 
             loss_1 = self(x)
@@ -886,7 +886,7 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
 
             # emulate gans training
             opt_gen, opt_dis = self.optimizers()
@@ -981,7 +981,7 @@ def manual_sync_grad(self) -> bool:
         torch_distrib.all_reduce(self.layer.weight.grad.data, async_op=False)
         return True
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
 
        # emulate gans training
        opt_gen, opt_dis = self.optimizers()
@@ -1088,9 +1088,9 @@ def test_step_with_optimizer_closure_with_different_frequencies_ddp_spawn(tmpdir
     train_manual_optimization(tmpdir, "ddp_spawn")
 
 
-class TesManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
+class TestManualOptimizationDDPModelToggleModel(TesManualOptimizationDDPModel):
 
-    def training_step(self, batch, batch_idx, optimizer_idx):
+    def training_step(self, batch, batch_idx):
 
         # emulate gans training
         opt_gen, opt_dis = self.optimizers()
@@ -1147,4 +1147,4 @@ def dis_closure():
 
 @RunIf(min_gpus=2, special=True)
 def test_step_with_optimizer_closure_with_different_frequencies_ddp_with_toggle_model(tmpdir):
-    train_manual_optimization(tmpdir, "ddp", model_cls=TesManualOptimizationDDPModelToggleModel)
+    train_manual_optimization(tmpdir, "ddp", model_cls=TestManualOptimizationDDPModelToggleModel)

tests/trainer/optimization/test_multiple_optimizers.py

Lines changed: 19 additions & 6 deletions
@@ -14,6 +14,7 @@
 """
 Tests to ensure that the behaviours related to multiple optimizers works
 """
+import pytest
 import torch
 
 import pytorch_lightning as pl
@@ -90,11 +91,6 @@ def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2
 
-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            return optimizer, optimizer_2
-
     model = TestModel()
     model.val_dataloader = None
 
@@ -119,7 +115,7 @@ def __init__(self):
             super().__init__()
             self.automatic_optimization = False
 
-        def training_step(self, batch, batch_idx, optimizer_idx):
+        def training_step(self, batch, batch_idx):
             self.training_step_called = True
 
             # manual optimization
@@ -154,3 +150,20 @@ def training_epoch_end(self, outputs) -> None:
     trainer.fit(model)
 
     assert model.training_step_called
+
+
+def test_multiple_optimizers_no_opt_idx_argument(tmpdir):
+    """
+    Test that an error is raised if no optimizer_idx is present when
+    multiple optimizeres are passed in case of automatic_optimization
+    """
+
+    class TestModel(MultiOptModel):
+
+        def training_step(self, batch, batch_idx):
+            return super().training_step(batch, batch_idx)
+
+    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
+        trainer.fit(TestModel())
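
For the error path exercised by this test: with automatic optimization and more than one optimizer, `training_step` must still accept `optimizer_idx`. A hedged sketch of a module that satisfies the check (the class name and layer are illustrative, not from this commit):

```python
import torch
import pytorch_lightning as pl


class AutoMultiOptModel(pl.LightningModule):
    # automatic_optimization stays at its default (True), so with two optimizers
    # Lightning calls training_step once per optimizer and requires optimizer_idx
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # a real model would usually branch on optimizer_idx; this placeholder
        # loss just keeps the sketch runnable
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
        )
```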
