|
5 | 5 | from pytorch_lightning import Trainer |
6 | 6 | from pytorch_lightning.utilities.data import ( |
7 | 7 | _replace_dataloader_init_method, |
| 8 | + _update_dataloader, |
8 | 9 | extract_batch_size, |
9 | 10 | get_len, |
10 | 11 | has_iterable_dataset, |
@@ -115,6 +116,38 @@ def test_has_len_all_rank(): |
115 | 116 | assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model) |
116 | 117 |
|
117 | 118 |
|
| 119 | +def test_update_dataloader_typerror_custom_exception(): |
| 120 | + class BadImpl(DataLoader): |
| 121 | + def __init__(self, foo, *args, **kwargs): |
| 122 | + self.foo = foo |
| 123 | + # positional conflict with `dataset` |
| 124 | + super().__init__(foo, *args, **kwargs) |
| 125 | + |
| 126 | + dataloader = BadImpl([1, 2, 3]) |
| 127 | + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): |
| 128 | + _update_dataloader(dataloader, dataloader.sampler) |
| 129 | + |
| 130 | + class BadImpl2(DataLoader): |
| 131 | + def __init__(self, randomize, *args, **kwargs): |
| 132 | + self.randomize = randomize |
| 133 | + # keyword conflict with `shuffle` |
| 134 | + super().__init__(*args, shuffle=randomize, **kwargs) |
| 135 | + |
| 136 | + dataloader = BadImpl2(False, []) |
| 137 | + with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"): |
| 138 | + _update_dataloader(dataloader, dataloader.sampler) |
| 139 | + |
| 140 | + class GoodImpl(DataLoader): |
| 141 | + def __init__(self, randomize, *args, **kwargs): |
| 142 | + # fixed implementation, kwargs are filtered |
| 143 | + self.randomize = randomize or kwargs.pop("shuffle", False) |
| 144 | + super().__init__(*args, shuffle=randomize, **kwargs) |
| 145 | + |
| 146 | + dataloader = GoodImpl(False, []) |
| 147 | + new_dataloader = _update_dataloader(dataloader, dataloader.sampler) |
| 148 | + assert isinstance(new_dataloader, GoodImpl) |
| 149 | + |
| 150 | + |
118 | 151 | def test_replace_dataloader_init_method(): |
119 | 152 | """Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and |
120 | 153 | sets them as attributes.""" |
|
0 commit comments