13 | 13 | # limitations under the License.
14 | 14 | import torch
15 | 15 | import torch.nn.functional as F
| 16 | +from torch.utils.data import DataLoader
16 | 17 |
17 | 18 | import pytorch_lightning as pl
18 | 19 | import tests.helpers.pipelines as tpipes
19 | 20 | import tests.helpers.utils as tutils
20 | 21 | from pytorch_lightning.callbacks import EarlyStopping
21 | 22 | from pytorch_lightning.core import memory
22 | | -from tests.helpers import BoringModel
| 23 | +from tests.helpers import BoringModel, RandomDataset
23 | 24 | from tests.helpers.datamodules import ClassifDataModule
24 | 25 | from tests.helpers.runif import RunIf
25 | 26 | from tests.helpers.simple_models import ClassificationModel
@@ -125,19 +126,58 @@ def test_dp_test(tmpdir):
125 | 126 |     assert torch.all(torch.eq(old_weights, new_weights))
126 | 127 |
127 | 128 |
| 129 | +class ReductionTestModel(BoringModel):
| 130 | +
| 131 | +    def train_dataloader(self):
| 132 | +        return DataLoader(RandomDataset(32, 64), batch_size=2)
| 133 | +
| 134 | +    def val_dataloader(self):
| 135 | +        return DataLoader(RandomDataset(32, 64), batch_size=2)
| 136 | +
| 137 | +    def test_dataloader(self):
| 138 | +        return DataLoader(RandomDataset(32, 64), batch_size=2)
| 139 | +
| 140 | +    def add_outputs(self, output, device):
| 141 | +        output.update({
| 142 | +            "reduce_int": torch.tensor(device.index, dtype=torch.int, device=device),
| 143 | +            "reduce_float": torch.tensor(device.index, dtype=torch.float, device=device),
| 144 | +        })
| 145 | +
| 146 | +    def training_step(self, batch, batch_idx):
| 147 | +        output = super().training_step(batch, batch_idx)
| 148 | +        self.add_outputs(output, batch.device)
| 149 | +        return output
| 150 | +
| 151 | +    def validation_step(self, batch, batch_idx):
| 152 | +        output = super().validation_step(batch, batch_idx)
| 153 | +        self.add_outputs(output, batch.device)
| 154 | +        return output
| 155 | +
| 156 | +    def test_step(self, batch, batch_idx):
| 157 | +        output = super().test_step(batch, batch_idx)
| 158 | +        self.add_outputs(output, batch.device)
| 159 | +        return output
| 160 | +
| 161 | +    def training_epoch_end(self, outputs):
| 162 | +        assert outputs[0]["loss"].shape == torch.Size([])
| 163 | +        assert outputs[0]["reduce_int"].item() == 0  # mean([0, 1]) = 0
| 164 | +        assert outputs[0]["reduce_float"].item() == 0.5  # mean([0., 1.]) = 0.5
| 165 | +
| 166 | +
128 | 167 | @RunIf(min_gpus=2)
129 | 168 | def test_dp_training_step_dict(tmpdir):
130 | | -    """
131 | | -    This test verify dp properly reduce dictionaries
132 | | -    """
133 | | -
134 | | -    model = BoringModel()
| 169 | +    """ This test verifies that dp properly reduces dictionaries """
| 170 | +    model = ReductionTestModel()
135 | 171 |     model.training_step_end = None
| 172 | +    model.validation_step_end = None
| 173 | +    model.test_step_end = None
| 174 | +
136 | 175 |     trainer = pl.Trainer(
137 | 176 |         default_root_dir=tmpdir,
138 | 177 |         max_epochs=1,
139 | | -        limit_train_batches=2,
140 | | -        limit_val_batches=0,
| 178 | +        limit_train_batches=1,
| 179 | +        limit_val_batches=1,
| 180 | +        limit_test_batches=1,
141 | 181 |         gpus=2,
142 | 182 |         accelerator='dp',
143 | 183 |     )
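For context on the assertions in `training_epoch_end`: with two GPUs, each replica contributes its device index (0 and 1), and the expected values in the diff imply that DP mean-reduces each gathered scalar and casts an integer result back to its integer dtype, truncating the fractional part. Below is a minimal standalone sketch (not part of this diff) that mimics that behavior with plain PyTorch; the tensor values are the hypothetical per-GPU outputs the test relies on, not output of Lightning itself.

```python
import torch

# Hypothetical per-GPU scalar outputs gathered by DP from devices 0 and 1.
reduce_int = torch.tensor([0, 1])         # dtype: torch.int64
reduce_float = torch.tensor([0.0, 1.0])   # dtype: torch.float32

# Mean-reduce each entry; an integer result is cast back to its integer dtype,
# so the fractional part of the mean is dropped.
reduced_int = reduce_int.float().mean().to(reduce_int.dtype)
reduced_float = reduce_float.mean()

assert reduced_int.item() == 0      # mean([0, 1]) = 0.5 -> truncated to 0
assert reduced_float.item() == 0.5  # mean([0., 1.]) = 0.5
```

This mirrors the two asserts added in `training_epoch_end`, which is why the test expects `reduce_int` to come back as 0 while `reduce_float` comes back as 0.5.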