|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
| 14 | +import re |
| 15 | + |
14 | 16 | import pytest |
15 | 17 | import torch |
16 | 18 |
|
17 | 19 | from pytorch_lightning import seed_everything, Trainer |
| 20 | +from pytorch_lightning.utilities.exceptions import MisconfigurationException |
18 | 21 | from tests.helpers import BoringModel |
19 | 22 |
|
20 | 23 |
|
@@ -222,3 +225,56 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): |
222 | 225 | else: |
223 | 226 | assert trainer.batch_idx == batch_idx_ |
224 | 227 | assert trainer.global_step == batch_idx_ * max_epochs |
| 228 | + |
| 229 | + |
| 230 | +def test_should_stop_mid_epoch(tmpdir): |
| 231 | + """Test that training correctly stops mid epoch and that validation is still called at the right time""" |
| 232 | + |
| 233 | + class TestModel(BoringModel): |
| 234 | + |
| 235 | + def __init__(self): |
| 236 | + super().__init__() |
| 237 | + self.validation_called_at = None |
| 238 | + |
| 239 | + def training_step(self, batch, batch_idx): |
| 240 | + if batch_idx == 4: |
| 241 | + self.trainer.should_stop = True |
| 242 | + return super().training_step(batch, batch_idx) |
| 243 | + |
| 244 | + def validation_step(self, *args): |
| 245 | + self.validation_called_at = (self.trainer.current_epoch, self.trainer.global_step) |
| 246 | + return super().validation_step(*args) |
| 247 | + |
| 248 | + model = TestModel() |
| 249 | + trainer = Trainer( |
| 250 | + default_root_dir=tmpdir, |
| 251 | + max_epochs=1, |
| 252 | + limit_train_batches=10, |
| 253 | + limit_val_batches=1, |
| 254 | + ) |
| 255 | + trainer.fit(model) |
| 256 | + |
| 257 | + assert trainer.current_epoch == 0 |
| 258 | + assert trainer.global_step == 5 |
| 259 | + assert model.validation_called_at == (0, 4) |
| 260 | + |
| 261 | + |
| 262 | +@pytest.mark.parametrize(['output'], [(5., ), ({'a': 5}, )]) |
| 263 | +def test_warning_invalid_trainstep_output(tmpdir, output): |
| 264 | + |
| 265 | + class TestModel(BoringModel): |
| 266 | + |
| 267 | + def training_step(self, batch, batch_idx): |
| 268 | + return output |
| 269 | + |
| 270 | + model = TestModel() |
| 271 | + |
| 272 | + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1) |
| 273 | + with pytest.raises( |
| 274 | + MisconfigurationException, |
| 275 | + match=re.escape( |
| 276 | + "In automatic optimization, `training_step` must either return a Tensor, " |
| 277 | + "a dict with key 'loss' or None (where the step will be skipped)." |
| 278 | + ) |
| 279 | + ): |
| 280 | + trainer.fit(model) |
0 commit comments