Skip to content

Commit f345840

Browse files
mauvilsapre-commit-ci[bot]carmocca
authored
Fix support for torch Module type hints in LightningCLI (#7807)
* Fixed support for torch Module type hints in LightningCLI * - Fix issue with serializing values when type hint is Any. - Run unit test only on newer torchvision versions in which the base class is Module. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor change * Update CHANGELOG.md Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 36770b2 commit f345840

File tree

3 files changed

+58
-1
lines changed

3 files changed

+58
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
185185
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))
186186

187187

188+
- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))
189+
190+
188191
## [1.3.2] - 2021-05-18
189192

190193
### Changed

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ torchtext>=0.5
77
# onnx>=1.7.0
88
onnxruntime>=1.3.0
99
hydra-core>=1.0
10-
jsonargparse[signatures]>=3.12.0
10+
jsonargparse[signatures]>=3.13.1

tests/utilities/test_cli.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,26 @@
2020
from argparse import Namespace
2121
from contextlib import redirect_stdout
2222
from io import StringIO
23+
from typing import List, Optional
2324
from unittest import mock
2425

2526
import pytest
27+
import torch
2628
import yaml
29+
from packaging import version
2730

2831
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
2932
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
3033
from pytorch_lightning.plugins.environments import SLURMEnvironment
3134
from pytorch_lightning.utilities import _TPU_AVAILABLE
3235
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
36+
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
3337
from tests.helpers import BoringDataModule, BoringModel
3438

39+
torchvision_version = version.parse('0')
40+
if _TORCHVISION_AVAILABLE:
41+
torchvision_version = version.parse(__import__('torchvision').__version__)
42+
3543

3644
@mock.patch('argparse.ArgumentParser.parse_args')
3745
def test_default_args(mock_argparse, tmpdir):
@@ -443,3 +451,49 @@ def __init__(
443451
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
444452
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
445453
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)
454+
455+
456+
@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
457+
def test_lightning_cli_torch_modules(tmpdir):
458+
459+
class TestModule(BoringModel):
460+
461+
def __init__(
462+
self,
463+
activation: torch.nn.Module = None,
464+
transform: Optional[List[torch.nn.Module]] = None,
465+
):
466+
super().__init__()
467+
self.activation = activation
468+
self.transform = transform
469+
470+
config = """model:
471+
activation:
472+
class_path: torch.nn.LeakyReLU
473+
init_args:
474+
negative_slope: 0.2
475+
transform:
476+
- class_path: torchvision.transforms.Resize
477+
init_args:
478+
size: 64
479+
- class_path: torchvision.transforms.CenterCrop
480+
init_args:
481+
size: 64
482+
"""
483+
config_path = tmpdir / 'config.yaml'
484+
with open(config_path, 'w') as f:
485+
f.write(config)
486+
487+
cli_args = [
488+
f'--trainer.default_root_dir={tmpdir}',
489+
'--trainer.max_epochs=1',
490+
f'--config={str(config_path)}',
491+
]
492+
493+
with mock.patch('sys.argv', ['any.py'] + cli_args):
494+
cli = LightningCLI(TestModule)
495+
496+
assert isinstance(cli.model.activation, torch.nn.LeakyReLU)
497+
assert cli.model.activation.negative_slope == 0.2
498+
assert len(cli.model.transform) == 2
499+
assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform)

0 commit comments

Comments
 (0)