 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy

 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -55,7 +55,7 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")

-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None

     @property
     def _should_reload_train_dl(self) -> bool:
@@ -230,7 +230,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 category=PossibleUserWarning,
             )

-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed
@@ -292,14 +292,18 @@ def _prepare_dataloader(

         return dataloader

-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )

             # update docs too once this is resolved
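
For context on this hunk: `Trainer.distributed_sampler_kwargs` is typed as an `Optional` mapping, so unpacking it directly with `**` does not pass strict type checking. Binding it to a local name and asserting it is not `None` lets the checker narrow the type before the call. A minimal standalone sketch of the same narrowing idiom (the function and argument names here are illustrative, not the Lightning API):

```python
from typing import Any, Dict, Optional


def build_sampler_kwargs(maybe_kwargs: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # `maybe_kwargs` is Optional, so a type checker rejects `**maybe_kwargs` as-is;
    # the assert narrows it to Dict[str, Any] for the line below.
    assert maybe_kwargs is not None
    return {"shuffle": False, **maybe_kwargs}


print(build_sampler_kwargs({"num_replicas": 2, "rank": 0}))
```
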
@@ -357,7 +361,7 @@ def _reset_eval_dataloader(
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)

         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]

         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -426,7 +430,7 @@ def _reset_eval_dataloader(

         return loader_num_batches, dataloaders

-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.

         Returns:
@@ -447,10 +451,12 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
         return dataloader

     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True

-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler
@@ -460,19 +466,23 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader):

         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )

         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )

         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)

         return dataloaders

     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
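
`apply_to_collection` is what lets `_resolve_overfit_batches` accept a single `DataLoader`, a list, or a dict of them: it walks the collection and calls `replace_sampler` on every `DataLoader` it finds, leaving the container structure intact. A rough standalone illustration of that traversal idea (a simplified stand-in, not the actual Lightning utility, which also filters by element type and handles more container kinds):

```python
from typing import Any, Callable


def apply_to_dataloaders(data: Any, fn: Callable[[Any], Any]) -> Any:
    # Simplified stand-in for the collection traversal: recurse into lists,
    # tuples, and dicts, and apply `fn` to every leaf element.
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_dataloaders(item, fn) for item in data)
    if isinstance(data, dict):
        return {key: apply_to_dataloaders(value, fn) for key, value in data.items()}
    return fn(data)


# Strings stand in for DataLoaders; each leaf gets the same "replace sampler" treatment.
print(apply_to_dataloaders({"train": ["dl_a", "dl_b"]}, lambda dl: f"sequential({dl})"))
```
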
@@ -506,18 +516,14 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:

         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)

-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()

+        assert self.instance is not None
         return self.instance

     def is_defined(self) -> bool:
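
The removed local imports in this hunk existed only to dodge a circular import at module load time; referring to the classes through the `pl` module alias achieves the same thing, because the attribute is looked up lazily when the check actually runs, not while the modules are still loading. A short sketch of that idiom, assuming `pytorch_lightning` is installed (the helper name is made up for illustration):

```python
# The package is imported once at module level; `pl.LightningModule` is only
# resolved when the function is called, so no import cycle is triggered.
import pytorch_lightning as pl


def _is_lightning_source(instance: object) -> bool:
    # Attribute lookup on `pl` happens here, at call time.
    return isinstance(instance, (pl.LightningModule, pl.LightningDataModule))
```
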
@@ -532,9 +538,7 @@ def is_module(self) -> bool:

         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))


 @dataclass
@@ -553,7 +557,7 @@ class _DataHookSelector:

     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
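
On the last hunk: in Python typing, `Tuple[str]` means a tuple of exactly one string, whereas `Tuple[str, ...]` means a homogeneous tuple of any length, which is what the three-hook default actually is. A small illustration of the distinction:

```python
from typing import Tuple

one_str: Tuple[str] = ("only",)               # exactly one element
many_str: Tuple[str, ...] = ("a", "b", "c")   # any number of str elements

# Annotating the three-element default with Tuple[str] would be flagged by a
# type checker, since the value's length does not match the annotation.
print(one_str, many_str)
```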