@@ -383,12 +383,14 @@ def prepare_data(self):
383
383
model.test_dataloader()
384
384
"""
385
385
386
- def train_dataloader(self) -> DataLoader:
386
+ def train_dataloader(self) -> Any:
387
387
"""
388
- Implement a PyTorch DataLoader for training.
388
+ Implement one or more PyTorch DataLoaders for training.
389
389
390
390
Return:
391
- Single PyTorch :class:`~torch.utils.data.DataLoader`.
391
+ Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
392
+ (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
393
+ this :ref:`page <multiple-training-dataloaders>`
392
394
393
395
The dataloader you return will not be called every epoch unless you set
394
396
:paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:
414
416
415
417
Example::
416
418
419
+ # single dataloader
417
420
def train_dataloader(self):
418
421
transform = transforms.Compose([transforms.ToTensor(),
419
422
transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
426
429
)
427
430
return loader
428
431
432
+ # multiple dataloaders, return as list
433
+ def train_dataloader(self):
434
+ mnist = MNIST(...)
435
+ cifar = CIFAR(...)
436
+ mnist_loader = torch.utils.data.DataLoader(
437
+ dataset=mnist, batch_size=self.batch_size, shuffle=True
438
+ )
439
+ cifar_loader = torch.utils.data.DataLoader(
440
+ dataset=cifar, batch_size=self.batch_size, shuffle=True
441
+ )
442
+ # each batch will be a list of tensors: [batch_mnist, batch_cifar]
443
+ return [mnist_loader, cifar_loader]
444
+
445
+ # multiple dataloaders, return as dict
446
+ def train_dataloader(self):
447
+ mnist = MNIST(...)
448
+ cifar = CIFAR(...)
449
+ mnist_loader = torch.utils.data.DataLoader(
450
+ dataset=mnist, batch_size=self.batch_size, shuffle=True
451
+ )
452
+ cifar_loader = torch.utils.data.DataLoader(
453
+ dataset=cifar, batch_size=self.batch_size, shuffle=True
454
+ )
455
+ # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
456
+ return {'mnist': mnist_loader, 'cifar': cifar_loader}
457
+
429
458
"""
430
459
rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
431
460
0 commit comments