CHANGELOG.md (3 additions, 0 deletions)
@@ -197,6 +197,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed torch distributed not available in setup hook for DDP ([#6506](https://github.com/PyTorchLightning/pytorch-lightning/pull/6506))
 
 
+- Fixed an issue with `IterableDataset` when `__len__` is not defined ([#6828](https://github.com/PyTorchLightning/pytorch-lightning/pull/6828))
+
+
 ## [1.2.6] - 2021-03-30
 
 ### Changed
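For context on the changelog entry above: when an `IterableDataset` does not define `__len__`, any truthiness check on a `DataLoader` built from it raises, because `bool()` falls back to `len()`. Below is a minimal sketch of that failure mode in plain PyTorch; `StreamingDataset` is an illustrative name, not part of this PR.

```python
# Minimal reproduction sketch of the reported issue; `StreamingDataset` is an
# illustrative name, not taken from Lightning or this PR.
from torch.utils.data import DataLoader, IterableDataset


class StreamingDataset(IterableDataset):
    """An iterable-style dataset that deliberately does not define __len__."""

    def __iter__(self):
        return iter(range(10))


loader = DataLoader(StreamingDataset(), batch_size=2)

try:
    # `if loader:` falls back to `len(loader)`, which calls `len(loader.dataset)`
    # and raises because the dataset has no __len__.
    if loader:
        pass
except TypeError as err:
    print(f"truthiness check failed: {err}")

# An explicit None comparison never touches __len__, so it is always safe:
assert loader is not None
```

This is exactly the shape of the old guards in trainer.py below, which evaluated `if val_dataloaders and datamodule:` on a possibly length-less loader.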
pytorch_lightning/trainer/trainer.py (3 additions, 3 deletions)
@@ -866,7 +866,7 @@ def validate(
         self.validating = True
 
         # If you supply a datamodule you can't supply val_dataloaders
-        if val_dataloaders and datamodule:
+        if val_dataloaders is not None and datamodule:
             raise MisconfigurationException(
                 'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
             )
@@ -928,7 +928,7 @@ def test(
         self.testing = True
 
         # If you supply a datamodule you can't supply test_dataloaders
-        if test_dataloaders and datamodule:
+        if test_dataloaders is not None and datamodule:
             raise MisconfigurationException('You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`')
 
         model_provided = model is not None
@@ -1024,7 +1024,7 @@ def predict(
         self.state = TrainerState.PREDICTING
         self.predicting = True
 
-        if dataloaders and datamodule:
+        if dataloaders is not None and datamodule:
             raise MisconfigurationException(
                 'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
             )
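All three hunks make the same change: the mutual-exclusion check between user-supplied dataloaders and a datamodule now compares the dataloaders against `None` instead of relying on their truthiness. A condensed sketch of the pattern as a standalone helper; the helper name, signature, and the stand-in exception class are made up for illustration:

```python
# Illustrative helper showing the guard pattern used in the three hunks above;
# `_check_exclusive` and its signature are invented for this sketch.
from typing import Any, Optional


class MisconfigurationException(Exception):
    """Stand-in for pytorch_lightning.utilities.exceptions.MisconfigurationException."""


def _check_exclusive(dataloaders: Optional[Any], datamodule: Optional[Any], entry_point: str) -> None:
    # `is not None` avoids calling `bool()` (and therefore `len()`) on the
    # dataloader, which fails for IterableDatasets without `__len__`.
    if dataloaders is not None and datamodule is not None:
        raise MisconfigurationException(
            f'You cannot pass both `trainer.{entry_point}(dataloaders=..., datamodule=...)`'
        )
```

Note that the real hunks keep the plain truthiness check for `datamodule`; only the dataloader side needed the explicit comparison.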
tests/trainer/test_dataloaders.py (20 additions, 6 deletions)
@@ -636,28 +636,42 @@ def test_warning_with_few_workers_multi_loader(_, tmpdir, ckpt_path, stage):

 def test_warning_with_iterable_dataset_and_len(tmpdir):
     """ Tests that a warning message is shown when an IterableDataset defines `__len__`. """
-    model = EvalModelTemplate()
+    model = BoringModel()
     original_dataset = model.train_dataloader().dataset
 
-    class IterableWithLen(IterableDataset):
+    class IterableWithoutLen(IterableDataset):
 
         def __iter__(self):
             return iter(original_dataset)
 
+    class IterableWithLen(IterableWithoutLen):
+
         def __len__(self):
             return len(original_dataset)
 
+    # with __len__ defined
     dataloader = DataLoader(IterableWithLen(), batch_size=16)
     assert has_len(dataloader)
     assert has_iterable_dataset(dataloader)
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_steps=3,
-    )
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.validate(model, val_dataloaders=[dataloader])
     with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
         trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
     with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
         trainer.test(model, test_dataloaders=[dataloader])
+    with pytest.warns(UserWarning, match='Your `IterableDataset` has `__len__` defined.'):
+        trainer.predict(model, dataloaders=[dataloader])
+
+    # without __len__ defined
+    dataloader = DataLoader(IterableWithoutLen(), batch_size=16)
+    assert not has_len(dataloader)
+    assert has_iterable_dataset(dataloader)
+    trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
+    trainer.validate(model, val_dataloaders=dataloader)
+    trainer.fit(model, train_dataloader=dataloader, val_dataloaders=[dataloader])
+    trainer.test(model, test_dataloaders=dataloader)
+    trainer.predict(model, dataloaders=dataloader)
 
 
 @RunIf(min_gpus=2)
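The updated test asserts on `has_len` and `has_iterable_dataset` before exercising `validate`, `fit`, `test`, and `predict` with and without `__len__`. For readers without the Lightning source at hand, here is a rough approximation of what such checks can look like (a sketch only, not the actual utility implementations):

```python
# Rough, illustrative approximations of `has_len` / `has_iterable_dataset`;
# the real Lightning utilities may differ in details and error handling.
from torch.utils.data import DataLoader, IterableDataset


def has_iterable_dataset(dataloader: DataLoader) -> bool:
    """True when the loader wraps an iterable-style dataset."""
    return isinstance(getattr(dataloader, "dataset", None), IterableDataset)


def has_len(dataloader: DataLoader) -> bool:
    """True when `len(dataloader)` can be evaluated without raising."""
    try:
        len(dataloader)
        return True
    except TypeError:
        # e.g. the dataset is an IterableDataset that does not implement __len__
        return False
```

Under these definitions, the first half of the test (dataset with `__len__`) passes both checks and is expected to warn, while the second half only passes `has_iterable_dataset` and must run without `len()` ever being called, which is what the `is not None` guards above make possible.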