2121import pytest
2222import torch
2323from omegaconf import OmegaConf
24+ from torch .utils .data import DataLoader
2425
2526from pytorch_lightning import LightningDataModule , Trainer
2627from pytorch_lightning .callbacks import ModelCheckpoint
28+ from pytorch_lightning .trainer .supporters import CombinedLoader
2729from pytorch_lightning .utilities import AttributeDict
2830from pytorch_lightning .utilities .exceptions import MisconfigurationException
2931from pytorch_lightning .utilities .model_helpers import is_overridden
30- from tests .helpers import BoringDataModule , BoringModel
32+ from tests .helpers import BoringDataModule , BoringModel , RandomDataset
3133from tests .helpers .datamodules import ClassifDataModule
3234from tests .helpers .runif import RunIf
3335from tests .helpers .simple_models import ClassificationModel
@@ -564,13 +566,14 @@ class BoringDataModule1(LightningDataModule):
564566 batch_size : int
565567 dims : int = 2
566568
567- def __post_init__ (self ):
568- super (). __init__ ( dims = self .dims )
569+ def train_dataloader (self ):
570+ return DataLoader ( torch . randn ( self . batch_size * 2 , 10 ), batch_size = self .batch_size )
569571
570572 # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
571573 # __repr__, __eq__, __lt__, __le__, etc.
572574 assert BoringDataModule1 (batch_size = 64 ).dims == 2
573575 assert BoringDataModule1 (batch_size = 32 )
576+ assert len (BoringDataModule1 (batch_size = 32 )) == 2
574577 assert hasattr (BoringDataModule1 , "__repr__" )
575578 assert BoringDataModule1 (batch_size = 32 ) == BoringDataModule1 (batch_size = 32 )
576579
@@ -581,7 +584,9 @@ class BoringDataModule2(LightningDataModule):
581584
582585 # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
583586 # __init__, __repr__, __eq__, __lt__, __le__, etc.
584- assert BoringDataModule2 (batch_size = 32 )
587+ assert BoringDataModule2 (batch_size = 32 ) is not None
588+ assert BoringDataModule2 (batch_size = 32 ).batch_size == 32
589+ assert len (BoringDataModule2 (batch_size = 32 )) == 0
585590 assert hasattr (BoringDataModule2 , "__repr__" )
586591 assert BoringDataModule2 (batch_size = 32 ).prepare_data () is None
587592 assert BoringDataModule2 (batch_size = 32 ) == BoringDataModule2 (batch_size = 32 )
@@ -625,3 +630,69 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
625630 trainer .model = model
626631 trainer .datamodule = dm
627632 trainer .data_connector .prepare_data ()
633+
634+
635+ DATALOADER = DataLoader (RandomDataset (1 , 32 ))
636+
637+
638+ @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
639+ @pytest .mark .parametrize (
640+ ["dataloader" , "expected" ],
641+ [
642+ [DATALOADER , 32 ],
643+ [[DATALOADER , DATALOADER ], 64 ],
644+ [[[DATALOADER ], [DATALOADER , DATALOADER ]], 96 ],
645+ [[{"foo" : DATALOADER }, {"foo" : DATALOADER , "bar" : DATALOADER }], 96 ],
646+ [{"foo" : DATALOADER , "bar" : DATALOADER }, 64 ],
647+ [{"foo" : {"foo" : DATALOADER }, "bar" : {"foo" : DATALOADER , "bar" : DATALOADER }}, 96 ],
648+ [{"foo" : [DATALOADER ], "bar" : [DATALOADER , DATALOADER ]}, 96 ],
649+ [CombinedLoader ({"foo" : DATALOADER , "bar" : DATALOADER }), 64 ],
650+ ],
651+ )
652+ def test_len_different_types (method_name , dataloader , expected ):
653+ dm = LightningDataModule ()
654+ setattr (dm , method_name , lambda : dataloader )
655+ assert len (dm ) == expected
656+
657+
658+ @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
659+ def test_len_dataloader_no_len (method_name ):
660+ class CustomNotImplementedErrorDataloader (DataLoader ):
661+ def __len__ (self ):
662+ raise NotImplementedError
663+
664+ dataloader = CustomNotImplementedErrorDataloader (RandomDataset (1 , 32 ))
665+ dm = LightningDataModule ()
666+ setattr (dm , method_name , lambda : dataloader )
667+ with pytest .warns (UserWarning , match = f"The number of batches for a dataloader in `{ method_name } ` is counted as 0" ):
668+ assert len (dm ) == 0
669+
670+
671+ def test_len_all_dataloader_methods_implemented ():
672+ class BoringDataModule (LightningDataModule ):
673+ def __init__ (self , dataloader ):
674+ super ().__init__ ()
675+ self .dataloader = dataloader
676+
677+ def train_dataloader (self ):
678+ return {"foo" : self .dataloader , "bar" : self .dataloader }
679+
680+ def val_dataloader (self ):
681+ return self .dataloader
682+
683+ def test_dataloader (self ):
684+ return [self .dataloader ]
685+
686+ def predict_dataloader (self ):
687+ return [self .dataloader , self .dataloader ]
688+
689+ dm = BoringDataModule (DATALOADER )
690+
691+ # 6 dataloaders each producing 32 batches: 6 * 32 = 192
692+ assert len (dm ) == 192
693+
694+
695+ def test_len_no_dataloader_methods_implemented ():
696+ dm = LightningDataModule ()
697+ with pytest .warns (UserWarning , match = "You datamodule does not have any valid dataloader" ):
698+ assert len (dm ) == 0
0 commit comments