Commit 53815e6

Fix overlapping samples in DDP when no global seed is set (#17713)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 41cfa33 commit 53815e6

File tree: 3 files changed (+43, −3 lines)


src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -138,6 +138,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `WandbLogger` ignoring the `WANDB_PROJECT` environment variable ([#16222](https://github.com/Lightning-AI/lightning/pull/16222))
 
 
+- Fixed an edge case causing overlapping samples in DDP when no global seed is set ([#17713](https://github.com/Lightning-AI/lightning/pull/17713))
+
+
 ## [2.0.1.post0] - 2023-04-11
 
 ### Fixed

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 6 additions & 3 deletions
@@ -16,7 +16,7 @@
 from dataclasses import dataclass, field
 from typing import Any, Iterable, Optional, Tuple, Union
 
-from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 import lightning.pytorch as pl
@@ -251,8 +251,11 @@ def _get_distributed_sampler(
     """This function is used to create the distributed sampler injected within the user DataLoader."""
     kwargs["shuffle"] = shuffle and not overfit_batches
     kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
-    cls = UnrepeatedDistributedSamplerWrapper if mode == RunningStage.PREDICTING else DistributedSamplerWrapper
-    return cls(dataloader.sampler, **kwargs)
+    if mode == RunningStage.PREDICTING:
+        return UnrepeatedDistributedSamplerWrapper(dataloader.sampler, **kwargs)
+    if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
+        return DistributedSampler(dataloader.dataset, **kwargs)
+    return DistributedSamplerWrapper(dataloader.sampler, **kwargs)
 
 
 def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None:
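Why the rewritten branch avoids the overlap: `DistributedSamplerWrapper` shards the indices produced by the wrapped sampler, so when no global seed is set, each process's `RandomSampler` draws a different permutation and the per-rank shards can overlap. Handing the dataset straight to `DistributedSampler` instead makes every rank shard one permutation derived from the shared `seed` kwarg. A minimal sketch of the two behaviors, assuming two ranks and arbitrary illustrative seeds (the `permutation` helper stands in for a sampler and is not Lightning code):

```python
import torch

def permutation(seed: int, n: int = 8) -> list:
    # Stand-in for a per-process RandomSampler: without a global seed,
    # each rank draws its permutation from a different RNG state.
    g = torch.Generator().manual_seed(seed)
    return torch.randperm(n, generator=g).tolist()

world_size = 2

# Wrapper-style sharding (the buggy case): each rank shards its OWN
# permutation, so the union of shards can repeat and miss samples.
shards = [permutation(seed=100 + rank)[rank::world_size] for rank in range(world_size)]
seen = [i for shard in shards for i in shard]
print("independent permutations:", shards, "-> unique:", len(set(seen)), "of", len(seen))

# DistributedSampler-style sharding (the fixed case): one shared permutation,
# derived from the shared seed, sharded across ranks -> always disjoint.
shared = permutation(seed=0)
shards = [shared[rank::world_size] for rank in range(world_size)]
seen = [i for shard in shards for i in shard]
print("shared permutation:", shards, "-> unique:", len(set(seen)), "of", len(seen))
```

Note that only plain `RandomSampler`/`SequentialSampler` instances are replaced; any custom sampler still goes through `DistributedSamplerWrapper`, since its behavior cannot be reproduced by a plain `DistributedSampler`.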

tests/tests_pytorch/trainer/test_dataloaders.py

Lines changed: 34 additions & 0 deletions
@@ -804,6 +804,40 @@ def test_dataloader_distributed_sampler(tmpdir):
     trainer.test(model)
 
 
+class TestModelUniqueDDPSampling(BoringModel):
+    def __init__(self):
+        super().__init__()
+        self.seen_samples = []
+
+    def training_step(self, batch):
+        self.seen_samples.extend(batch.tolist())
+
+    def on_train_end(self):
+        seen_samples = self.all_gather(self.seen_samples)
+        # The samples should be unique across all processes
+        assert set(torch.cat(seen_samples).view(-1).tolist()) == set(range(32))
+
+
+@RunIf(standalone=True)
+def test_distributed_sampler_without_global_seed(tmpdir):
+    """Test that the samples are non-overlapping in DDP when shuffling is enabled and no global seed is set."""
+    # This test must run without a global seed set (e.g. through `seed_everything`), to ensure that each process
+    # starts with a different initial state.
+    assert "PL_GLOBAL_SEED" not in os.environ
+    train_dataloader = DataLoader(range(32), shuffle=True, batch_size=4)
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=False,
+        logger=False,
+        enable_progress_bar=False,
+        accelerator="cpu",
+        devices=2,
+        strategy="ddp",
+        max_epochs=1,
+    )
+    trainer.fit(TestModelUniqueDDPSampling(), train_dataloader)
+
+
 class ModelWithDataLoaderDistributedSampler(BoringModel):
     def train_dataloader(self):
         dataloader = super().train_dataloader()
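For context on the property the new test exercises: with the fix, the shuffled DataLoader's plain `RandomSampler` is replaced by a `DistributedSampler` built from the dataset, so both ranks shard the same seeded permutation. A quick sketch of that property outside DDP (the explicit `num_replicas`/`rank` values below stand in for the two processes; this is an illustration, not the test itself):

```python
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler

dataloader = DataLoader(range(32), shuffle=True, batch_size=4)
# shuffle=True gives the DataLoader a plain RandomSampler, the case the fix targets.
assert isinstance(dataloader.sampler, RandomSampler)

# The fixed branch builds the sampler from the dataset itself; `seed` falls back
# to 0 when PL_GLOBAL_SEED is unset, so it is identical on every rank.
shards = [
    set(DistributedSampler(dataloader.dataset, num_replicas=2, rank=rank, shuffle=True, seed=0))
    for rank in range(2)
]
assert shards[0].isdisjoint(shards[1])          # no overlapping samples
assert shards[0] | shards[1] == set(range(32))  # full coverage of the dataset
```

The `@RunIf(standalone=True)` marker keeps the test out of the default pytest session, since it needs to spawn its own two-process DDP group.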
