
Commit 4f90455

SkafteNicki, Borda, rohitgr7, and awaelchli authored
Update docs on arg train_dataloader in fit (#6076)
* add to docs
* update docs
* Apply suggestions from code review
* Update pytorch_lightning/core/hooks.py
  Co-authored-by: Rohit Gupta <[email protected]>
* nested loaders
* Apply suggestions from code review
  Co-authored-by: Adrian Wälchli <[email protected]>
* shorten text length
* Update pytorch_lightning/core/hooks.py
  Co-authored-by: Jirka Borovec <[email protected]>

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 5d7388d commit 4f90455

3 files changed: +60 -7 lines changed

docs/source/advanced/multiple_loaders.rst

Lines changed: 23 additions & 0 deletions
@@ -16,6 +16,8 @@ Lightning supports multiple dataloaders in a few ways.

----------

+.. _multiple-training-dataloaders:
+
Multiple training dataloaders
-----------------------------
For training, the usual way to use multiple dataloaders is to create a ``DataLoader`` class
@@ -86,6 +88,27 @@ For more details please have a look at :attr:`~pytorch_lightning.trainer.trainer

            return loaders

+Furthermore, Lightning also supports returning nested lists and dicts (or a
+combination):
+
+.. testcode::
+
+    class LitModel(LightningModule):
+
+        def train_dataloader(self):
+
+            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
+            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
+            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
+            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)
+
+            # pass loaders as a nested dict. This will create batches like this:
+            # {'loaders_a_b': {'a': batch from loader_a, 'b': batch from loader_b},
+            #  'loaders_c_d': {'c': batch from loader_c, 'd': batch from loader_d}}
+            loaders = {'loaders_a_b': {'a': loader_a, 'b': loader_b},
+                       'loaders_c_d': {'c': loader_c, 'd': loader_d}}
+            return loaders
+
----------

Test/Val dataloaders
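
As a complement to the testcode block added above, a minimal sketch of how the resulting batches could be consumed in training_step: Lightning delivers each batch in the same nested structure that train_dataloader returned (as the comment in the example notes). The class name and the placeholder loss below are illustrative assumptions, not part of the commit.

    import torch
    from pytorch_lightning import LightningModule

    class LitNestedLoaders(LightningModule):  # illustrative name, not from the commit
        def train_dataloader(self):
            loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
            loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
            loader_c = torch.utils.data.DataLoader(range(32), batch_size=4)
            loader_d = torch.utils.data.DataLoader(range(64), batch_size=4)
            return {'loaders_a_b': {'a': loader_a, 'b': loader_b},
                    'loaders_c_d': {'c': loader_c, 'd': loader_d}}

        def training_step(self, batch, batch_idx):
            # the batch mirrors the nested dict returned by train_dataloader
            batch_a = batch['loaders_a_b']['a']  # one batch drawn from loader_a
            batch_d = batch['loaders_c_d']['d']  # one batch drawn from loader_d
            # placeholder loss, only to keep the sketch self-contained
            return batch_a.float().mean() + batch_d.float().mean()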

pytorch_lightning/core/hooks.py

Lines changed: 32 additions & 3 deletions
@@ -383,12 +383,14 @@ def prepare_data(self):
            model.test_dataloader()
        """

-    def train_dataloader(self) -> DataLoader:
+    def train_dataloader(self) -> Any:
        """
-        Implement a PyTorch DataLoader for training.
+        Implement one or more PyTorch DataLoaders for training.

        Return:
-            Single PyTorch :class:`~torch.utils.data.DataLoader`.
+            Either a single PyTorch :class:`~torch.utils.data.DataLoader` or a collection of these
+            (list, dict, nested lists and dicts). In the case of multiple dataloaders, please see
+            this :ref:`page <multiple-training-dataloaders>`.

        The dataloader you return will not be called every epoch unless you set
        :paramref:`~pytorch_lightning.trainer.Trainer.reload_dataloaders_every_epoch` to ``True``.
@@ -414,6 +416,7 @@ def train_dataloader(self) -> DataLoader:

        Example::

+            # single dataloader
            def train_dataloader(self):
                transform = transforms.Compose([transforms.ToTensor(),
                                                transforms.Normalize((0.5,), (1.0,))])
@@ -426,6 +429,32 @@ def train_dataloader(self):
                )
                return loader

+            # multiple dataloaders, return as list
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a list of tensors: [batch_mnist, batch_cifar]
+                return [mnist_loader, cifar_loader]
+
+            # multiple dataloaders, return as dict
+            def train_dataloader(self):
+                mnist = MNIST(...)
+                cifar = CIFAR(...)
+                mnist_loader = torch.utils.data.DataLoader(
+                    dataset=mnist, batch_size=self.batch_size, shuffle=True
+                )
+                cifar_loader = torch.utils.data.DataLoader(
+                    dataset=cifar, batch_size=self.batch_size, shuffle=True
+                )
+                # each batch will be a dict of tensors: {'mnist': batch_mnist, 'cifar': batch_cifar}
+                return {'mnist': mnist_loader, 'cifar': cifar_loader}
+
        """
        rank_zero_warn("`train_dataloader` must be implemented to be used with the Lightning Trainer")
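
The docstring examples above use MNIST(...) and CIFAR(...) as stand-ins. Below is a self-contained sketch of the list-return variant, with tiny random TensorDatasets substituted for those datasets so it can actually run; everything apart from the [loader, loader] return shape and the list-shaped batch is an illustrative assumption.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import LightningModule

    class LitPairedLoaders(LightningModule):  # illustrative name
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)
            self.batch_size = 4

        def train_dataloader(self):
            # tiny random datasets standing in for MNIST(...) / CIFAR(...)
            ds_a = TensorDataset(torch.randn(32, 8), torch.randn(32, 1))
            ds_b = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))
            loader_a = DataLoader(ds_a, batch_size=self.batch_size, shuffle=True)
            loader_b = DataLoader(ds_b, batch_size=self.batch_size, shuffle=True)
            # list return: each training_step batch arrives as [batch_a, batch_b]
            return [loader_a, loader_b]

        def training_step(self, batch, batch_idx):
            batch_a, batch_b = batch  # one batch per loader, in list order
            x_a, y_a = batch_a
            x_b, y_b = batch_b
            loss_a = torch.nn.functional.mse_loss(self.layer(x_a), y_a)
            loss_b = torch.nn.functional.mse_loss(self.layer(x_b), y_b)
            return loss_a + loss_b

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)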

pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 4 deletions
@@ -16,7 +16,7 @@
import warnings
from itertools import count
from pathlib import Path
-from typing import Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union

import torch
from torch.utils.data import DataLoader
@@ -399,7 +399,7 @@ def setup_trainer(self, model: LightningModule):
    def fit(
        self,
        model: LightningModule,
-        train_dataloader: Optional[DataLoader] = None,
+        train_dataloader: Any = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
@@ -411,8 +411,9 @@ def fit(

            model: Model to fit.

-            train_dataloader: A Pytorch DataLoader with training samples. If the model has
-                a predefined train_dataloader method this will be skipped.
+            train_dataloader: Either a single PyTorch DataLoader or a collection of these
+                (list, dict, nested lists and dicts). In the case of multiple dataloaders, please
+                see this :ref:`page <multiple-training-dataloaders>`.

            val_dataloaders: Either a single Pytorch Dataloader or a list of them, specifying validation samples.
                If the model has a predefined val_dataloaders method this will be skipped
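
With the signature relaxed to train_dataloader: Any, a collection of loaders can also be passed directly to fit. A minimal sketch, assuming model is a LightningModule whose training_step expects a dict batch with 'a'/'b' keys (as in the sketches above); the loaders here are illustrative.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from pytorch_lightning import Trainer

    # illustrative loaders; model is assumed to be defined elsewhere
    loader_a = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4)
    loader_b = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=4)

    trainer = Trainer(max_epochs=1)
    # fit now accepts a single DataLoader or a collection (list, dict, nested lists/dicts)
    trainer.fit(model, train_dataloader={'a': loader_a, 'b': loader_b})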
