|
22 | 22 | import torch.distributed as torch_distrib |
23 | 23 | import torch.nn.functional as F |
24 | 24 |
|
| 25 | +from lightning.fabric.utilities.exceptions import MisconfigurationException |
25 | 26 | from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 |
26 | 27 | from lightning.pytorch import seed_everything, Trainer |
27 | 28 | from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel |
@@ -886,3 +887,37 @@ def configure_optimizers(self): |
886 | 887 |
|
887 | 888 | assert set(trainer.logged_metrics) == {"loss_d", "loss_g"} |
888 | 889 | assert set(trainer.progress_bar_metrics) == {"loss_d", "loss_g"} |
| 890 | + |
| 891 | + |
| 892 | +@pytest.mark.parametrize("automatic_optimization", [True, False]) |
| 893 | +def test_manual_optimization_with_non_pytorch_scheduler(automatic_optimization): |
| 894 | + """In manual optimization, the user can provide a custom scheduler that doesn't follow PyTorch's interface.""" |
| 895 | + |
| 896 | + class IncompatibleScheduler: |
| 897 | + def __init__(self, optimizer): |
| 898 | + self.optimizer = optimizer |
| 899 | + |
| 900 | + def state_dict(self): |
| 901 | + return {} |
| 902 | + |
| 903 | + def load_state_dict(self, _): |
| 904 | + pass |
| 905 | + |
| 906 | + class Model(BoringModel): |
| 907 | + def __init__(self): |
| 908 | + super().__init__() |
| 909 | + self.automatic_optimization = automatic_optimization |
| 910 | + |
| 911 | + def configure_optimizers(self): |
| 912 | + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) |
| 913 | + scheduler = IncompatibleScheduler(optimizer) |
| 914 | + return [optimizer], [scheduler] |
| 915 | + |
| 916 | + model = Model() |
| 917 | + trainer = Trainer(accelerator="cpu", max_epochs=0) |
| 918 | + if automatic_optimization: |
| 919 | + with pytest.raises(MisconfigurationException, match="doesn't follow PyTorch's LRScheduler"): |
| 920 | + trainer.fit(model) |
| 921 | + else: |
| 922 | + # No error for manual optimization |
| 923 | + trainer.fit(model) |
0 commit comments