diff --git a/CHANGELOG.md b/CHANGELOG.md
index fa206f3784c49..7515e9a40d980 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -177,6 +177,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848))
 
+- Added a `__len__` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895))
+
 ### Changed
 
diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py
index f3a5c855fe07a..8131bbb896bbe 100644
--- a/pytorch_lightning/core/datamodule.py
+++ b/pytorch_lightning/core/datamodule.py
@@ -22,7 +22,10 @@
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.utilities import rank_zero_deprecation
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types
+from pytorch_lightning.utilities.data import has_len
+from pytorch_lightning.utilities.warnings import rank_zero_warn
 
 
 class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin):
@@ -481,3 +484,37 @@ def __getstate__(self) -> dict:
         for fn in ("prepare_data", "setup", "teardown"):
             del d[fn]
         return d
+
+    def __len__(self) -> int:
+        """Returns the total number of batches in all dataloaders defined in the datamodule."""
+
+        from pytorch_lightning.trainer.supporters import CombinedLoader
+
+        num_batches = 0
+        not_implemented_count = 0
+
+        def get_num_batches(dataloader: DataLoader, name: str) -> None:
+            nonlocal num_batches
+            if not has_len(dataloader):
+                rank_zero_warn(
+                    f"The number of batches for a dataloader in `{name}` is counted as 0 "
+                    "because it does not have `__len__` defined."
+                )
+            else:
+                num_batches += len(dataloader)
+
+        for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
+            dataloader_method = getattr(self, method_name)
+            try:
+                dataloader = dataloader_method()
+            except NotImplementedError:
+                not_implemented_count += 1
+                continue
+            if isinstance(dataloader, CombinedLoader):
+                dataloader = dataloader.loaders
+            apply_to_collection(dataloader, DataLoader, get_num_batches, method_name)
+
+        if not_implemented_count == 4:
+            rank_zero_warn("Your datamodule does not have any valid dataloader so `__len__` will return 0.")
+
+        return num_batches
diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py
index 2f84032593472..eaebb922a936d 100644
--- a/tests/core/test_datamodules.py
+++ b/tests/core/test_datamodules.py
@@ -21,13 +21,15 @@
 import pytest
 import torch
 from omegaconf import OmegaConf
+from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningDataModule, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities import AttributeDict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers import BoringDataModule, BoringModel, RandomDataset
 from tests.helpers.datamodules import ClassifDataModule
 from tests.helpers.runif import RunIf
 from tests.helpers.simple_models import ClassificationModel
@@ -564,13 +566,14 @@ class BoringDataModule1(LightningDataModule):
         batch_size: int
         dims: int = 2
 
-        def __post_init__(self):
-            super().__init__(dims=self.dims)
+        def train_dataloader(self):
+            return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size)
 
     # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
     # __repr__, __eq__, __lt__, __le__, etc.
     assert BoringDataModule1(batch_size=64).dims == 2
     assert BoringDataModule1(batch_size=32)
+    assert len(BoringDataModule1(batch_size=32)) == 2
     assert hasattr(BoringDataModule1, "__repr__")
     assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32)
 
@@ -581,7 +584,9 @@ class BoringDataModule2(LightningDataModule):
 
     # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
     # __init__, __repr__, __eq__, __lt__, __le__, etc.
-    assert BoringDataModule2(batch_size=32)
+    assert BoringDataModule2(batch_size=32) is not None
+    assert BoringDataModule2(batch_size=32).batch_size == 32
+    assert len(BoringDataModule2(batch_size=32)) == 0
     assert hasattr(BoringDataModule2, "__repr__")
     assert BoringDataModule2(batch_size=32).prepare_data() is None
     assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32)
@@ -625,3 +630,69 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
     trainer.model = model
     trainer.datamodule = dm
     trainer.data_connector.prepare_data()
+
+
+DATALOADER = DataLoader(RandomDataset(1, 32))
+
+
+@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
+@pytest.mark.parametrize(
+    ["dataloader", "expected"],
+    [
+        [DATALOADER, 32],
+        [[DATALOADER, DATALOADER], 64],
+        [[[DATALOADER], [DATALOADER, DATALOADER]], 96],
+        [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96],
+        [{"foo": DATALOADER, "bar": DATALOADER}, 64],
+        [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96],
+        [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96],
+        [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64],
+    ],
+)
+def test_len_different_types(method_name, dataloader, expected):
+    dm = LightningDataModule()
+    setattr(dm, method_name, lambda: dataloader)
+    assert len(dm) == expected
+
+
+@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"])
+def test_len_dataloader_no_len(method_name):
+    class CustomNotImplementedErrorDataloader(DataLoader):
+        def __len__(self):
+            raise NotImplementedError
+
+    dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32))
+    dm = LightningDataModule()
+    setattr(dm, method_name, lambda: dataloader)
+    with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"):
+        assert len(dm) == 0
+
+
+def test_len_all_dataloader_methods_implemented():
+    class BoringDataModule(LightningDataModule):
+        def __init__(self, dataloader):
+            super().__init__()
+            self.dataloader = dataloader
+
+        def train_dataloader(self):
+            return {"foo": self.dataloader, "bar": self.dataloader}
+
+        def val_dataloader(self):
+            return self.dataloader
+
+        def test_dataloader(self):
+            return [self.dataloader]
+
+        def predict_dataloader(self):
+            return [self.dataloader, self.dataloader]
+
+    dm = BoringDataModule(DATALOADER)
+
+    # 6 dataloaders, each producing 32 batches: 6 * 32 = 192
+    assert len(dm) == 192
+
+
+def test_len_no_dataloader_methods_implemented():
+    dm = LightningDataModule()
+    with pytest.warns(UserWarning, match="Your datamodule does not have any valid dataloader"):
+        assert len(dm) == 0
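
For context, a minimal usage sketch of the behavior this patch adds (not part of the diff; `MyDataModule`, its datasets, and the batch sizes below are invented for illustration). `len(datamodule)` sums the batch counts of every dataloader returned by the four dataloader hooks, counting hooks that raise `NotImplementedError` as 0 batches:

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):  # hypothetical example module
    def train_dataloader(self):
        # 100 samples at batch_size=10 -> 10 batches
        return DataLoader(TensorDataset(torch.randn(100, 3)), batch_size=10)

    def val_dataloader(self):
        # two loaders of 20 samples each at the default batch_size=1 -> 20 + 20 batches
        dataset = TensorDataset(torch.randn(20, 3))
        return [DataLoader(dataset), DataLoader(dataset)]


dm = MyDataModule()
# test_dataloader/predict_dataloader are not overridden, so they raise
# NotImplementedError and contribute 0 batches each.
assert len(dm) == 10 + 20 + 20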