|
20 | 20 | from argparse import Namespace |
21 | 21 | from contextlib import redirect_stdout |
22 | 22 | from io import StringIO |
| 23 | +from typing import List, Optional |
23 | 24 | from unittest import mock |
24 | 25 |
|
25 | 26 | import pytest |
| 27 | +import torch |
26 | 28 | import yaml |
| 29 | +from packaging import version |
27 | 30 |
|
28 | 31 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer |
29 | 32 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint |
30 | 33 | from pytorch_lightning.plugins.environments import SLURMEnvironment |
31 | 34 | from pytorch_lightning.utilities import _TPU_AVAILABLE |
32 | 35 | from pytorch_lightning.utilities.cli import LightningArgumentParser, LightningCLI, SaveConfigCallback |
| 36 | +from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE |
33 | 37 | from tests.helpers import BoringDataModule, BoringModel |
34 | 38 |
|
| 39 | +torchvision_version = version.parse('0') |
| 40 | +if _TORCHVISION_AVAILABLE: |
| 41 | + torchvision_version = version.parse(__import__('torchvision').__version__) |
| 42 | + |
35 | 43 |
|
36 | 44 | @mock.patch('argparse.ArgumentParser.parse_args') |
37 | 45 | def test_default_args(mock_argparse, tmpdir): |
@@ -443,3 +451,49 @@ def __init__( |
443 | 451 | assert cli.model.submodule2 == cli.config_init['model']['submodule2'] |
444 | 452 | assert isinstance(cli.config_init['model']['submodule1'], BoringModel) |
445 | 453 | assert isinstance(cli.config_init['model']['submodule2'], BoringModel) |
| 454 | + |
| 455 | + |
| 456 | +@pytest.mark.skipif(torchvision_version < version.parse('0.8.0'), reason='torchvision>=0.8.0 is required') |
| 457 | +def test_lightning_cli_torch_modules(tmpdir): |
| 458 | + |
| 459 | + class TestModule(BoringModel): |
| 460 | + |
| 461 | + def __init__( |
| 462 | + self, |
| 463 | + activation: torch.nn.Module = None, |
| 464 | + transform: Optional[List[torch.nn.Module]] = None, |
| 465 | + ): |
| 466 | + super().__init__() |
| 467 | + self.activation = activation |
| 468 | + self.transform = transform |
| 469 | + |
| 470 | + config = """model: |
| 471 | + activation: |
| 472 | + class_path: torch.nn.LeakyReLU |
| 473 | + init_args: |
| 474 | + negative_slope: 0.2 |
| 475 | + transform: |
| 476 | + - class_path: torchvision.transforms.Resize |
| 477 | + init_args: |
| 478 | + size: 64 |
| 479 | + - class_path: torchvision.transforms.CenterCrop |
| 480 | + init_args: |
| 481 | + size: 64 |
| 482 | + """ |
| 483 | + config_path = tmpdir / 'config.yaml' |
| 484 | + with open(config_path, 'w') as f: |
| 485 | + f.write(config) |
| 486 | + |
| 487 | + cli_args = [ |
| 488 | + f'--trainer.default_root_dir={tmpdir}', |
| 489 | + '--trainer.max_epochs=1', |
| 490 | + f'--config={str(config_path)}', |
| 491 | + ] |
| 492 | + |
| 493 | + with mock.patch('sys.argv', ['any.py'] + cli_args): |
| 494 | + cli = LightningCLI(TestModule) |
| 495 | + |
| 496 | + assert isinstance(cli.model.activation, torch.nn.LeakyReLU) |
| 497 | + assert cli.model.activation.negative_slope == 0.2 |
| 498 | + assert len(cli.model.transform) == 2 |
| 499 | + assert all(isinstance(v, torch.nn.Module) for v in cli.model.transform) |
0 commit comments