2121import pytest
2222import torch
2323from omegaconf import OmegaConf
24- from torch .utils .data import DataLoader
2524
2625from pytorch_lightning import LightningDataModule , Trainer
2726from pytorch_lightning .callbacks import ModelCheckpoint
28- from pytorch_lightning .trainer .supporters import CombinedLoader
2927from pytorch_lightning .utilities import AttributeDict
3028from pytorch_lightning .utilities .exceptions import MisconfigurationException
3129from pytorch_lightning .utilities .model_helpers import is_overridden
32- from tests .helpers import BoringDataModule , BoringModel , RandomDataset
30+ from tests .helpers import BoringDataModule , BoringModel
3331from tests .helpers .datamodules import ClassifDataModule
3432from tests .helpers .runif import RunIf
3533from tests .helpers .simple_models import ClassificationModel
@@ -566,14 +564,13 @@ class BoringDataModule1(LightningDataModule):
566564 batch_size : int
567565 dims : int = 2
568566
569- def train_dataloader (self ):
570- return DataLoader ( torch . randn ( self . batch_size * 2 , 10 ), batch_size = self .batch_size )
567+ def __post_init__ (self ):
568+ super (). __init__ ( dims = self .dims )
571569
572570 # asserts for the different dunder methods added by dataclass, when __init__ is implemented, i.e.
573571 # __repr__, __eq__, __lt__, __le__, etc.
574572 assert BoringDataModule1 (batch_size = 64 ).dims == 2
575573 assert BoringDataModule1 (batch_size = 32 )
576- assert len (BoringDataModule1 (batch_size = 32 )) == 2
577574 assert hasattr (BoringDataModule1 , "__repr__" )
578575 assert BoringDataModule1 (batch_size = 32 ) == BoringDataModule1 (batch_size = 32 )
579576
@@ -584,9 +581,7 @@ class BoringDataModule2(LightningDataModule):
584581
585582 # asserts for the different dunder methods added by dataclass, when super class is inherently initialized, i.e.
586583 # __init__, __repr__, __eq__, __lt__, __le__, etc.
587- assert BoringDataModule2 (batch_size = 32 ) is not None
588- assert BoringDataModule2 (batch_size = 32 ).batch_size == 32
589- assert len (BoringDataModule2 (batch_size = 32 )) == 0
584+ assert BoringDataModule2 (batch_size = 32 )
590585 assert hasattr (BoringDataModule2 , "__repr__" )
591586 assert BoringDataModule2 (batch_size = 32 ).prepare_data () is None
592587 assert BoringDataModule2 (batch_size = 32 ) == BoringDataModule2 (batch_size = 32 )
@@ -630,76 +625,3 @@ def test_inconsistent_prepare_data_per_node(tmpdir):
630625 trainer .model = model
631626 trainer .datamodule = dm
632627 trainer ._data_connector .prepare_data ()
633-
634-
635- DATALOADER = DataLoader (RandomDataset (1 , 32 ))
636-
637-
638- @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
639- @pytest .mark .parametrize (
640- ["dataloader" , "expected" ],
641- [
642- [DATALOADER , 32 ],
643- [[DATALOADER , DATALOADER ], 64 ],
644- [[[DATALOADER ], [DATALOADER , DATALOADER ]], 96 ],
645- [[{"foo" : DATALOADER }, {"foo" : DATALOADER , "bar" : DATALOADER }], 96 ],
646- [{"foo" : DATALOADER , "bar" : DATALOADER }, 64 ],
647- [{"foo" : {"foo" : DATALOADER }, "bar" : {"foo" : DATALOADER , "bar" : DATALOADER }}, 96 ],
648- [{"foo" : [DATALOADER ], "bar" : [DATALOADER , DATALOADER ]}, 96 ],
649- [CombinedLoader ({"foo" : DATALOADER , "bar" : DATALOADER }), 64 ],
650- ],
651- )
652- def test_len_different_types (method_name , dataloader , expected ):
653- dm = LightningDataModule ()
654- setattr (dm , method_name , lambda : dataloader )
655- assert len (dm ) == expected
656-
657-
658- @pytest .mark .parametrize ("method_name" , ["train_dataloader" , "val_dataloader" , "test_dataloader" , "predict_dataloader" ])
659- def test_len_dataloader_no_len (method_name ):
660- class CustomNotImplementedErrorDataloader (DataLoader ):
661- def __len__ (self ):
662- raise NotImplementedError
663-
664- dataloader = CustomNotImplementedErrorDataloader (RandomDataset (1 , 32 ))
665- dm = LightningDataModule ()
666- setattr (dm , method_name , lambda : dataloader )
667- with pytest .warns (UserWarning , match = f"The number of batches for a dataloader in `{ method_name } ` is counted as 0" ):
668- assert len (dm ) == 0
669-
670-
671- def test_len_all_dataloader_methods_implemented ():
672- class BoringDataModule (LightningDataModule ):
673- def __init__ (self , dataloader ):
674- super ().__init__ ()
675- self .dataloader = dataloader
676-
677- def train_dataloader (self ):
678- return {"foo" : self .dataloader , "bar" : self .dataloader }
679-
680- def val_dataloader (self ):
681- return self .dataloader
682-
683- def test_dataloader (self ):
684- return [self .dataloader ]
685-
686- def predict_dataloader (self ):
687- return [self .dataloader , self .dataloader ]
688-
689- dm = BoringDataModule (DATALOADER )
690-
691- # 6 dataloaders each producing 32 batches: 6 * 32 = 192
692- assert len (dm ) == 192
693-
694-
695- def test_len_no_dataloader_methods_implemented ():
696- dm = LightningDataModule ()
697- with pytest .warns (UserWarning , match = "You datamodule does not have any valid dataloader" ):
698- assert len (dm ) == 0
699-
700- dm .train_dataloader = None
701- dm .val_dataloader = None
702- dm .test_dataloader = None
703- dm .predict_dataloader = None
704- with pytest .warns (UserWarning , match = "You datamodule does not have any valid dataloader" ):
705- assert len (dm ) == 0
0 commit comments