Skip to content

Commit f8b2d5b

Browse files
authored
Improve error message on TypeError during DataLoader reconstruction (#10719)
1 parent 0066ff0 commit f8b2d5b

File tree

3 files changed

+53
-2
lines changed

3 files changed

+53
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2525
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
2626

2727

28-
-
28+
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
2929

3030

3131
-

pytorch_lightning/utilities/data.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
180180
def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
181181
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
182182
dl_cls = type(dataloader)
183-
dataloader = dl_cls(**dl_kwargs)
183+
try:
184+
dataloader = dl_cls(**dl_kwargs)
185+
except TypeError as e:
186+
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
187+
# `__init__` arguments map to one `DataLoader.__init__` argument
188+
import re
189+
190+
match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
191+
if not match:
192+
# an unexpected `TypeError`, continue failure
193+
raise
194+
argument = match.groups()[0]
195+
message = (
196+
f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
197+
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
198+
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
199+
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
200+
)
201+
raise MisconfigurationException(message) from e
184202
return dataloader
185203

186204

tests/utilities/test_data.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytorch_lightning import Trainer
66
from pytorch_lightning.utilities.data import (
77
_replace_dataloader_init_method,
8+
_update_dataloader,
89
extract_batch_size,
910
get_len,
1011
has_iterable_dataset,
@@ -115,6 +116,38 @@ def test_has_len_all_rank():
115116
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)
116117

117118

119+
def test_update_dataloader_typerror_custom_exception():
120+
class BadImpl(DataLoader):
121+
def __init__(self, foo, *args, **kwargs):
122+
self.foo = foo
123+
# positional conflict with `dataset`
124+
super().__init__(foo, *args, **kwargs)
125+
126+
dataloader = BadImpl([1, 2, 3])
127+
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
128+
_update_dataloader(dataloader, dataloader.sampler)
129+
130+
class BadImpl2(DataLoader):
131+
def __init__(self, randomize, *args, **kwargs):
132+
self.randomize = randomize
133+
# keyword conflict with `shuffle`
134+
super().__init__(*args, shuffle=randomize, **kwargs)
135+
136+
dataloader = BadImpl2(False, [])
137+
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
138+
_update_dataloader(dataloader, dataloader.sampler)
139+
140+
class GoodImpl(DataLoader):
141+
def __init__(self, randomize, *args, **kwargs):
142+
# fixed implementation, kwargs are filtered
143+
self.randomize = randomize or kwargs.pop("shuffle", False)
144+
super().__init__(*args, shuffle=randomize, **kwargs)
145+
146+
dataloader = GoodImpl(False, [])
147+
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
148+
assert isinstance(new_dataloader, GoodImpl)
149+
150+
118151
def test_replace_dataloader_init_method():
119152
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
120153
sets them as attributes."""

0 commit comments

Comments
 (0)