Commit 7efb764

bc
1 parent 0266104 commit 7efb764

4 files changed: +52 −10 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -319,6 +319,7 @@ class LitAutoEncoder(pl.LightningModule):
         self.automatic_optimization = False

     def training_step(self, batch, batch_idx):
+        # access your optimizers with use_pl_optimizer=False. Default is True
         opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

         loss_a = ...
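
For context, a minimal runnable sketch (not part of this commit) of the manual-optimization pattern the README snippet belongs to, assuming a toy linear layer, two SGD optimizers, and the v1.3-style `manual_backward(loss)` call; the layer and loss expressions are illustrative only:

import torch
import pytorch_lightning as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # opt out of automatic optimization; Lightning will not call optimizer.step() for us
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        # access your optimizers with use_pl_optimizer=False. Default is True
        opt_a, opt_b = self.optimizers(use_pl_optimizer=True)

        # placeholder losses; any scalar tensors built from the batch work here
        loss_a = self.layer(batch).sum()
        opt_a.zero_grad()
        self.manual_backward(loss_a)
        opt_a.step()

        loss_b = self.layer(batch).pow(2).sum()
        opt_b.zero_grad()
        self.manual_backward(loss_b)
        opt_b.step()

    def configure_optimizers(self):
        return [
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
        ]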

pytorch_lightning/trainer/training_loop.py

Lines changed: 10 additions & 5 deletions
@@ -838,14 +838,19 @@ def build_train_args(self, batch, batch_idx, opt_idx, hiddens):
         # enable not needing to add opt_idx to training_step
         args = [batch, batch_idx]

-        if len(self.trainer.optimizers) > 1 and self.automatic_optimization:
+        if len(self.trainer.optimizers) > 1:
             if self.trainer.has_arg("training_step", "optimizer_idx"):
+                if not self.automatic_optimization:
+                    self.warning_cache.warn(
+                        "`training_step` hook signature has changed in v1.3."
+                        " `optimizer_idx` argument has been removed in case of manual optimization. Support for"
+                        " the old signature will be removed in v1.5", DeprecationWarning
+                    )
                 args.append(opt_idx)
-            else:
-                num_opts = len(self.trainer.optimizers)
+            elif not self.trainer.has_arg("training_step", "optimizer_idx") and self.automatic_optimization:
                 raise ValueError(
-                    f"Your LightningModule defines {num_opts} optimizers but"
-                    f' training_step is missing the "optimizer_idx" argument.'
+                    f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
+                    ' `training_step` is missing the `optimizer_idx` argument.'
                 )

         # pass hiddens if using tbptt
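
To make the new control flow easier to follow, here is a minimal standalone sketch of the branching above; `num_optimizers`, `has_optimizer_idx_arg`, and `automatic_optimization` are hypothetical stand-ins for the trainer/module state that `build_train_args` reads, and `warnings.warn` stands in for `self.warning_cache.warn`:

import warnings


def build_train_args_sketch(batch, batch_idx, opt_idx, num_optimizers,
                            has_optimizer_idx_arg, automatic_optimization):
    # enable not needing to add opt_idx to training_step
    args = [batch, batch_idx]

    if num_optimizers > 1:
        if has_optimizer_idx_arg:
            if not automatic_optimization:
                # the old manual-optimization signature still works, but warns
                warnings.warn(
                    "`training_step` hook signature has changed in v1.3."
                    " `optimizer_idx` argument has been removed in case of manual optimization."
                    " Support for the old signature will be removed in v1.5",
                    DeprecationWarning,
                )
            args.append(opt_idx)
        elif automatic_optimization:
            # multiple optimizers under automatic optimization require the argument
            raise ValueError(
                f"Your LightningModule defines {num_optimizers} optimizers but"
                " `training_step` is missing the `optimizer_idx` argument."
            )
    return args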

tests/deprecated_api/test_remove_1-5.py

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,7 @@
 from unittest import mock

 import pytest
+from torch import optim

 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import WandbLogger
@@ -74,3 +75,25 @@ def test_v1_5_0_running_sanity_check():
     trainer = Trainer()
     with pytest.deprecated_call(match='has been renamed to `Trainer.sanity_checking`'):
         assert not trainer.running_sanity_check
+
+
+def test_old_training_step_signature_with_opt_idx_manual_opt(tmpdir):
+
+    class OldSignatureModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            assert optimizer_idx is not None
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            return [optim.SGD(self.parameters(), lr=1e-2), optim.SGD(self.parameters(), lr=1e-2)]
+
+    model = OldSignatureModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.deprecated_call(match="`training_step` .* `optimizer_idx` .* manual .* will be removed in v1.5"):
+        trainer.fit(model)
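
The new test leans on `pytest.deprecated_call` to assert that the old signature still runs but emits a `DeprecationWarning`. A tiny self-contained sketch of that mechanism, with a hypothetical `legacy_hook` standing in for the deprecated `training_step(..., optimizer_idx)` path:

import warnings

import pytest


def legacy_hook():
    # stand-in for the deprecated manual-optimization signature
    warnings.warn("Support for the old signature will be removed in v1.5", DeprecationWarning)


def test_legacy_hook_warns():
    # deprecated_call fails the test unless a matching DeprecationWarning is emitted
    with pytest.deprecated_call(match="removed in v1.5"):
        legacy_hook()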

tests/trainer/optimization/test_multiple_optimizers.py

Lines changed: 18 additions & 5 deletions
@@ -14,6 +14,7 @@
 """
 Tests to ensure that the behaviours related to multiple optimizers works
 """
+import pytest
 import torch

 import pytorch_lightning as pl
@@ -90,11 +91,6 @@ def training_epoch_end(self, outputs) -> None:
             # outputs should be an array with an entry per optimizer
             assert len(outputs) == 2

-        def configure_optimizers(self):
-            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
-            return optimizer, optimizer_2
-
     model = TestModel()
     model.val_dataloader = None

@@ -154,3 +150,20 @@ def training_epoch_end(self, outputs) -> None:
     trainer.fit(model)

     assert model.training_step_called
+
+
+def test_multiple_optimizers_no_opt_idx_argument(tmpdir):
+    """
+    Test that an error is raised if no optimizer_idx is present when
+    multiple optimizers are passed in case of automatic_optimization
+    """
+
+    class TestModel(MultiOptModel):
+
+        def training_step(self, batch, batch_idx):
+            return super().training_step(batch, batch_idx)
+
+    trainer = pl.Trainer(default_root_dir=tmpdir, fast_dev_run=2)
+
+    with pytest.raises(ValueError, match='`training_step` is missing the `optimizer_idx`'):
+        trainer.fit(TestModel())
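
For contrast with the failing case in this new test, a hedged sketch (not from the repository) of the automatic-optimization counterpart that satisfies the check: `training_step` keeps the `optimizer_idx` argument, so the `ValueError` is not raised; the layer and loss are illustrative only:

import torch
import pytorch_lightning as pl


class AutoOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx, optimizer_idx):
        # with automatic optimization, Lightning calls this once per optimizer
        # and passes that optimizer's index here
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return [
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
            torch.optim.SGD(self.layer.parameters(), lr=0.1),
        ]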
