Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


- Fixed support for `torch.nn.Module` type hints in `LightningCLI` ([#7807](https://github.com/PyTorchLightning/pytorch-lightning/pull/7807))


## [1.3.2] - 2021-05-18

### Changed
Expand Down
2 changes: 1 addition & 1 deletion requirements/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ torchtext>=0.5
# onnx>=1.7.0
onnxruntime>=1.3.0
hydra-core>=1.0
jsonargparse[signatures]>=3.12.0
jsonargparse[signatures]>=3.13.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the fix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the fix is there.

54 changes: 54 additions & 0 deletions tests/utilities/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,26 @@
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO
from typing import List, Optional
from unittest import mock

import pytest
import torch
import yaml
from packaging import version

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.plugins.environments import SLURMEnvironment
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from tests.helpers import BoringDataModule, BoringModel

# Resolve the installed torchvision version once at import time so tests
# below can be version-gated with ``pytest.mark.skipif``.
if _TORCHVISION_AVAILABLE:
    torchvision_version = version.parse(__import__('torchvision').__version__)
else:
    # Sentinel "0" compares lower than any real release, so version-gated
    # tests are skipped when torchvision is not installed at all.
    torchvision_version = version.parse('0')


@mock.patch('argparse.ArgumentParser.parse_args')
def test_default_args(mock_argparse, tmpdir):
Expand Down Expand Up @@ -443,3 +451,49 @@ def __init__(
assert cli.model.submodule2 == cli.config_init['model']['submodule2']
assert isinstance(cli.config_init['model']['submodule1'], BoringModel)
assert isinstance(cli.config_init['model']['submodule2'], BoringModel)


@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required')
def test_lightning_cli_torch_modules(tmpdir):
    """``LightningCLI`` should instantiate ``torch.nn.Module``-typed parameters
    (and ``Optional[List[...]]`` of them) from ``class_path``/``init_args``
    entries in a YAML config file."""

    class TestModule(BoringModel):

        def __init__(
            self,
            # NOTE: deliberately a plain (non-Optional) Module hint with a None
            # default — this is exactly the hint style the CLI must support.
            activation: torch.nn.Module = None,
            transform: Optional[List[torch.nn.Module]] = None,
        ):
            super().__init__()
            self.activation = activation
            self.transform = transform

    # YAML selecting concrete classes (and their init args) for the
    # Module-typed parameters declared above.
    config = """model:
        activation:
          class_path: torch.nn.LeakyReLU
          init_args:
            negative_slope: 0.2
        transform:
          - class_path: torchvision.transforms.Resize
            init_args:
              size: 64
          - class_path: torchvision.transforms.CenterCrop
            init_args:
              size: 64
        """
    config_path = tmpdir / 'config.yaml'
    with open(config_path, 'w') as config_file:
        config_file.write(config)

    cli_args = [
        f'--trainer.default_root_dir={tmpdir}',
        '--trainer.max_epochs=1',
        f'--config={str(config_path)}',
    ]

    # Drive the CLI exactly as if invoked from the command line.
    with mock.patch('sys.argv', ['any.py'] + cli_args):
        cli = LightningCLI(TestModule)

    activation = cli.model.activation
    assert isinstance(activation, torch.nn.LeakyReLU)
    assert activation.negative_slope == 0.2

    transforms = cli.model.transform
    assert len(transforms) == 2
    for transform in transforms:
        assert isinstance(transform, torch.nn.Module)