From d6d5fe6a22526783463498067e2e3eafeb2d82e1 Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Mon, 11 Oct 2021 11:47:02 -0500 Subject: [PATCH 01/11] Add support for len(datamodule) --- pytorch_lightning/core/datamodule.py | 25 +++++++++++++ tests/core/test_datamodules.py | 56 +++++++++++++++++++++++++++- 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index f3a5c855fe07a..91d165f3cbe67 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -23,6 +23,8 @@ from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.data import has_len +from pytorch_lightning.utilities.distributed import rank_zero_warn class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -481,3 +483,26 @@ def __getstate__(self) -> dict: for fn in ("prepare_data", "setup", "teardown"): del d[fn] return d + + def __len__(self) -> int: + """Returns the sum of the length of all dataloaders defined in DataModule.""" + def get_num_batches(dataloader): + if isinstance(dataloader, Sequence): + return sum(get_num_batches(dl) for dl in dataloader) + if isinstance(dataloader, Mapping): + return sum(get_num_batches(dl) for dl in dataloader.values()) + if not has_len(dataloader): + rank_zero_warn("`__len__` is not implemented for a `Dataloader`.") + return 0 + return len(dataloader) + + num_batches = 0 + for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"): + dataloader_method = getattr(self, method_name) + try: + dataloader = dataloader_method() + num_batches += get_num_batches(dataloader) + except NotImplementedError: + pass + + return num_batches diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 2f84032593472..6302b1fa98ca1 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -21,13 +21,15 @@ import pytest import torch from omegaconf import OmegaConf +from torch.utils.data import DataLoader from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden -from tests.helpers import BoringDataModule, BoringModel +from tests.helpers import BoringDataModule, BoringModel, RandomDataset +from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel @@ -625,3 +627,55 @@ def test_inconsistent_prepare_data_per_node(tmpdir): trainer.model = model trainer.datamodule = dm trainer.data_connector.prepare_data() + + +DATALOADER = DataLoader(RandomDataset(1, 32)) + + +@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"]) +@pytest.mark.parametrize( + ["dataloader", "expected"], + [ + [DATALOADER, 32], + [[DATALOADER, DATALOADER], 64], + [[[DATALOADER], [DATALOADER, DATALOADER]], 96], + [[{"foo": DATALOADER}, {"foo": DATALOADER, "bar": DATALOADER}], 96], + [{"foo": DATALOADER, "bar": DATALOADER}, 64], + [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96], + [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96], + ], +) +def test_len_different_types(method_name, dataloader, expected): + dm = LightningDataModule() + setattr(dm, method_name, lambda: dataloader) + assert len(dm) == expected + + +@pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"]) +def test_len_dataloader_no_len(method_name): + dataloader = CustomNotImplementedErrorDataloader(DATALOADER) + dm = LightningDataModule() + setattr(dm, method_name, lambda: dataloader) + with pytest.warns(UserWarning, match="`__len__` is not implemented for a `Dataloader`."): + assert len(dm) == 0 + + +def test_len_all_dataloader_methods_implemented(): + class BoringDataModule(LightningDataModule): + def __init__(self, dataloader): + self.dataloader = dataloader + + def train_dataloader(self): + return {"foo": self.dataloader, "bar": self.dataloader} + + def val_dataloader(self): + return self.dataloader + + def test_dataloader(self): + return [self.dataloader] + + def predict_dataloader(self): + return [self.dataloader, self.dataloader] + + dm = BoringDataModule(DATALOADER) + assert len(dm) == 192 From 32792be45110e9bf6159033fe3978db0e8a13ca1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 11 Oct 2021 16:51:08 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/core/datamodule.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 91d165f3cbe67..029ba06d89cf9 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -486,6 +486,7 @@ def __getstate__(self) -> dict: def __len__(self) -> int: """Returns the sum of the length of all dataloaders defined in DataModule.""" + def get_num_batches(dataloader): if isinstance(dataloader, Sequence): return sum(get_num_batches(dl) for dl in dataloader) From 46c10dee8f2827dc9550d5a3689bcc70291cbe9f Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Mon, 11 Oct 2021 11:57:15 -0500 Subject: [PATCH 03/11] Update --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fa206f3784c49..7515e9a40d980 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -177,6 +177,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for `torch.autograd.set_detect_anomaly` through `Trainer` constructor argument `detect_anomaly` ([#9848](https://github.com/PyTorchLightning/pytorch-lightning/pull/9848)) +- Added a `len` method to `LightningDataModule` ([#9895](https://github.com/PyTorchLightning/pytorch-lightning/pull/9895)) ### Changed From b872802beda01842d9bef77255dad58e11152e49 Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Mon, 11 Oct 2021 11:57:23 -0500 Subject: [PATCH 04/11] Update --- pytorch_lightning/core/datamodule.py | 2 -- tests/core/test_datamodules.py | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 029ba06d89cf9..0917528d116cd 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -485,8 +485,6 @@ def __getstate__(self) -> dict: return d def __len__(self) -> int: - """Returns the sum of the length of all dataloaders defined in DataModule.""" - def get_num_batches(dataloader): if isinstance(dataloader, Sequence): return sum(get_num_batches(dl) for dl in dataloader) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 6302b1fa98ca1..5a1a29b8d7a46 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -663,6 +663,7 @@ def test_len_dataloader_no_len(method_name): def test_len_all_dataloader_methods_implemented(): class BoringDataModule(LightningDataModule): def __init__(self, dataloader): + super().__init__() self.dataloader = dataloader def train_dataloader(self): From 8a2330753e460dc3491efe0b9cfe28e4ef6f85a5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 12 Oct 2021 09:48:57 +0100 Subject: [PATCH 05/11] resolve test --- pytorch_lightning/core/datamodule.py | 20 +++++++------------- tests/core/test_datamodules.py | 9 ++++++--- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 0917528d116cd..98130a585d2b6 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -22,9 +22,8 @@ from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.data import has_len -from pytorch_lightning.utilities.distributed import rank_zero_warn class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -485,22 +484,17 @@ def __getstate__(self) -> dict: return d def __len__(self) -> int: - def get_num_batches(dataloader): - if isinstance(dataloader, Sequence): - return sum(get_num_batches(dl) for dl in dataloader) - if isinstance(dataloader, Mapping): - return sum(get_num_batches(dl) for dl in dataloader.values()) - if not has_len(dataloader): - rank_zero_warn("`__len__` is not implemented for a `Dataloader`.") - return 0 - return len(dataloader) - num_batches = 0 + + def get_num_batches(dataloader: DataLoader) -> None: + nonlocal num_batches + num_batches += len(dataloader) + for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"): dataloader_method = getattr(self, method_name) try: dataloader = dataloader_method() - num_batches += get_num_batches(dataloader) + apply_to_collection(dataloader, DataLoader, get_num_batches) except NotImplementedError: pass diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5a1a29b8d7a46..44319f088c4db 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -566,13 +566,14 @@ class BoringDataModule1(LightningDataModule): batch_size: int dims: int = 2 - def __post_init__(self): - super().__init__(dims=self.dims) + def train_dataloader(self): + return DataLoader(torch.randn(self.batch_size * 2, 10), batch_size=self.batch_size) # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e. # __repr__, __eq__, __lt__, __le__, etc. assert BoringDataModule1(batch_size=64).dims == 2 assert BoringDataModule1(batch_size=32) + assert len(BoringDataModule1(batch_size=32)) == 2 assert hasattr(BoringDataModule1, "__repr__") assert BoringDataModule1(batch_size=32) == BoringDataModule1(batch_size=32) @@ -583,7 +584,9 @@ class BoringDataModule2(LightningDataModule): # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e. # __init__, __repr__, __eq__, __lt__, __le__, etc. - assert BoringDataModule2(batch_size=32) + assert BoringDataModule2(batch_size=32) is not None + assert BoringDataModule2(batch_size=32).batch_size == 32 + assert len(BoringDataModule2(batch_size=32)) == 0 assert hasattr(BoringDataModule2, "__repr__") assert BoringDataModule2(batch_size=32).prepare_data() is None assert BoringDataModule2(batch_size=32) == BoringDataModule2(batch_size=32) From f91f8ae90a74193204c5a7e9c0263caa68da08d1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 12 Oct 2021 15:08:46 +0100 Subject: [PATCH 06/11] update --- pytorch_lightning/core/datamodule.py | 13 +++++++++++-- tests/core/test_datamodules.py | 3 +-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 98130a585d2b6..89281f6b1f074 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -24,6 +24,7 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types +from pytorch_lightning.utilities.warnings import _warn class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -484,11 +485,15 @@ def __getstate__(self) -> dict: return d def __len__(self) -> int: - num_batches = 0 + num_batches = None def get_num_batches(dataloader: DataLoader) -> None: nonlocal num_batches - num_batches += len(dataloader) + L = len(dataloader) + if num_batches is None: + num_batches = L + else: + num_batches += L for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"): dataloader_method = getattr(self, method_name) @@ -498,4 +503,8 @@ def get_num_batches(dataloader: DataLoader) -> None: except NotImplementedError: pass + if not num_batches: + _warn("You datamodule didn't find any valid dataloader and the `__len__` will be returned as 0.") + return 0 + return num_batches diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 44319f088c4db..7b815b78e9474 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -659,8 +659,7 @@ def test_len_dataloader_no_len(method_name): dataloader = CustomNotImplementedErrorDataloader(DATALOADER) dm = LightningDataModule() setattr(dm, method_name, lambda: dataloader) - with pytest.warns(UserWarning, match="`__len__` is not implemented for a `Dataloader`."): - assert len(dm) == 0 + assert len(dm) == 0 def test_len_all_dataloader_methods_implemented(): From e96ce0f3cf53a713ac56a02b162b11034457928e Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Wed, 13 Oct 2021 15:23:54 -0500 Subject: [PATCH 07/11] Add support for CombinedLoader, add docstring, etc --- pytorch_lightning/core/datamodule.py | 29 +++++++++++++++++----------- tests/core/test_datamodules.py | 13 ++++++++++++- tests/helpers/dataloaders.py | 5 ++++- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 89281f6b1f074..df6275681c3fd 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -24,7 +24,8 @@ from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.argparse import add_argparse_args, from_argparse_args, get_init_arguments_and_types -from pytorch_lightning.utilities.warnings import _warn +from pytorch_lightning.utilities.data import has_len +from pytorch_lightning.utilities.warnings import rank_zero_warn class LightningDataModule(CheckpointHooks, DataHooks, HyperparametersMixin): @@ -485,26 +486,32 @@ def __getstate__(self) -> dict: return d def __len__(self) -> int: - num_batches = None + """Returns the total number of batches in all dataloaders defined in the datamodule.""" + + from pytorch_lightning.trainer.supporters import CombinedLoader + + num_batches = 0 + not_implemented_count = 0 def get_num_batches(dataloader: DataLoader) -> None: nonlocal num_batches - L = len(dataloader) - if num_batches is None: - num_batches = L - else: - num_batches += L + if not has_len(dataloader): + rank_zero_warn( + "The number of batches for a dataloader is counted as 0 because it does not have `__len__` defined." + ) + num_batches += len(dataloader) for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"): dataloader_method = getattr(self, method_name) try: dataloader = dataloader_method() + if isinstance(dataloader, CombinedLoader): + dataloader = dataloader.loaders apply_to_collection(dataloader, DataLoader, get_num_batches) except NotImplementedError: - pass + not_implemented_count += 1 - if not num_batches: - _warn("You datamodule didn't find any valid dataloader and the `__len__` will be returned as 0.") - return 0 + if not_implemented_count == 4: + rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.") return num_batches diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 7b815b78e9474..2d136503ee184 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -25,6 +25,7 @@ from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities import AttributeDict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden @@ -646,6 +647,7 @@ def test_inconsistent_prepare_data_per_node(tmpdir): [{"foo": DATALOADER, "bar": DATALOADER}, 64], [{"foo": {"foo": DATALOADER}, "bar": {"foo": DATALOADER, "bar": DATALOADER}}, 96], [{"foo": [DATALOADER], "bar": [DATALOADER, DATALOADER]}, 96], + [CombinedLoader({"foo": DATALOADER, "bar": DATALOADER}), 64], ], ) def test_len_different_types(method_name, dataloader, expected): @@ -659,7 +661,8 @@ def test_len_dataloader_no_len(method_name): dataloader = CustomNotImplementedErrorDataloader(DATALOADER) dm = LightningDataModule() setattr(dm, method_name, lambda: dataloader) - assert len(dm) == 0 + with pytest.warns(UserWarning, match="The number of batches for a dataloader is counted as 0"): + assert len(dm) == 0 def test_len_all_dataloader_methods_implemented(): @@ -681,4 +684,12 @@ def predict_dataloader(self): return [self.dataloader, self.dataloader] dm = BoringDataModule(DATALOADER) + + # 6 dataloaders each producing 32 batches: 6 * 32 = 192 assert len(dm) == 192 + + +def test_len_no_dataloader_methods_implemented(): + dm = LightningDataModule() + with pytest.warns(UserWarning, match="You datamodule does not have any valid dataloader"): + assert len(dm) == 0 diff --git a/tests/helpers/dataloaders.py b/tests/helpers/dataloaders.py index 14dde1c8424b2..f3c60e150372e 100644 --- a/tests/helpers/dataloaders.py +++ b/tests/helpers/dataloaders.py @@ -13,9 +13,12 @@ # limitations under the License. """Custom dataloaders for testing.""" +from torch.utils.data import DataLoader -class CustomInfDataloader: + +class CustomInfDataloader(DataLoader): def __init__(self, dataloader): + super().__init__(dataloader.dataset) self.dataloader = dataloader self.iter = iter(dataloader) self.count = 0 From 07278beac75610a6d67ac26d618d62e8a6f95b1a Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Wed, 13 Oct 2021 15:30:21 -0500 Subject: [PATCH 08/11] Fix conditional statement --- pytorch_lightning/core/datamodule.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index df6275681c3fd..46936b352ebf6 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -499,7 +499,8 @@ def get_num_batches(dataloader: DataLoader) -> None: rank_zero_warn( "The number of batches for a dataloader is counted as 0 because it does not have `__len__` defined." ) - num_batches += len(dataloader) + else: + num_batches += len(dataloader) for method_name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"): dataloader_method = getattr(self, method_name) From 74345b760523d726f76bc8e49c6670c06c4b90e9 Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Wed, 13 Oct 2021 16:06:59 -0500 Subject: [PATCH 09/11] Fix unit tests --- tests/core/test_datamodules.py | 6 +++++- tests/helpers/dataloaders.py | 5 +---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 2d136503ee184..a71769c307526 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -658,7 +658,11 @@ def test_len_different_types(method_name, dataloader, expected): @pytest.mark.parametrize("method_name", ["train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"]) def test_len_dataloader_no_len(method_name): - dataloader = CustomNotImplementedErrorDataloader(DATALOADER) + class CustomNotImplementedErrorDataloader(DataLoader): + def __len__(self): + raise NotImplementedError + + dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32)) dm = LightningDataModule() setattr(dm, method_name, lambda: dataloader) with pytest.warns(UserWarning, match="The number of batches for a dataloader is counted as 0"): diff --git a/tests/helpers/dataloaders.py b/tests/helpers/dataloaders.py index f3c60e150372e..14dde1c8424b2 100644 --- a/tests/helpers/dataloaders.py +++ b/tests/helpers/dataloaders.py @@ -13,12 +13,9 @@ # limitations under the License. """Custom dataloaders for testing.""" -from torch.utils.data import DataLoader - -class CustomInfDataloader(DataLoader): +class CustomInfDataloader: def __init__(self, dataloader): - super().__init__(dataloader.dataset) self.dataloader = dataloader self.iter = iter(dataloader) self.count = 0 From d961df5eae8d5aa1dcbb18fa4b5e05550015c3bd Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Wed, 13 Oct 2021 16:31:54 -0500 Subject: [PATCH 10/11] Fix flake8 issue --- tests/core/test_datamodules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index a71769c307526..5890a9c49fd96 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -30,7 +30,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel, RandomDataset -from tests.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel From 1f581eb07be99562be9126056969962fb5fa5e71 Mon Sep 17 00:00:00 2001 From: King Yiu Suen Date: Thu, 14 Oct 2021 13:00:41 -0500 Subject: [PATCH 11/11] Add dataloader name in warning and wrap only dataloader_method in try --- pytorch_lightning/core/datamodule.py | 12 +++++++----- tests/core/test_datamodules.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index 46936b352ebf6..8131bbb896bbe 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -493,11 +493,12 @@ def __len__(self) -> int: num_batches = 0 not_implemented_count = 0 - def get_num_batches(dataloader: DataLoader) -> None: + def get_num_batches(dataloader: DataLoader, name: str) -> None: nonlocal num_batches if not has_len(dataloader): rank_zero_warn( - "The number of batches for a dataloader is counted as 0 because it does not have `__len__` defined." + f"The number of batches for a dataloader in `{name}` is counted as 0 " + "because it does not have `__len__` defined." ) else: num_batches += len(dataloader) @@ -506,11 +507,12 @@ def get_num_batches(dataloader: DataLoader) -> None: dataloader_method = getattr(self, method_name) try: dataloader = dataloader_method() - if isinstance(dataloader, CombinedLoader): - dataloader = dataloader.loaders - apply_to_collection(dataloader, DataLoader, get_num_batches) except NotImplementedError: not_implemented_count += 1 + continue + if isinstance(dataloader, CombinedLoader): + dataloader = dataloader.loaders + apply_to_collection(dataloader, DataLoader, get_num_batches, method_name) if not_implemented_count == 4: rank_zero_warn("You datamodule does not have any valid dataloader so `__len__` will be returned as 0.") diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 5890a9c49fd96..eaebb922a936d 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -664,7 +664,7 @@ def __len__(self): dataloader = CustomNotImplementedErrorDataloader(RandomDataset(1, 32)) dm = LightningDataModule() setattr(dm, method_name, lambda: dataloader) - with pytest.warns(UserWarning, match="The number of batches for a dataloader is counted as 0"): + with pytest.warns(UserWarning, match=f"The number of batches for a dataloader in `{method_name}` is counted as 0"): assert len(dm) == 0