Skip to content

Commit 1e77c5f

Browse files
committed
add test
1 parent 321ce93 commit 1e77c5f

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -939,12 +939,6 @@ def test(
939939
# Attach dataloaders (if given)
940940
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
941941

942-
if not model_provided and ckpt_path == 'best' and self.fast_dev_run:
943-
raise MisconfigurationException(
944-
'You cannot execute testing when the model is not provided and `fast_dev_run=True`. '
945-
'Provide model with `trainer.test(model=...)` or `trainer.test(ckpt_path=...)`'
946-
)
947-
948942
if not model_provided:
949943
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
950944

@@ -962,10 +956,17 @@ def __load_ckpt_weights(
962956
ckpt_path: Optional[str] = None,
963957
) -> Optional[str]:
964958
# if user requests the best checkpoint but we don't have it, error
965-
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
966-
raise MisconfigurationException(
967-
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
968-
)
959+
if ckpt_path == 'best':
960+
if not self.checkpoint_callback.best_model_path and self.fast_dev_run:
961+
raise MisconfigurationException(
962+
'You cannot execute `trainer.test()` or trainer.validate()`'
963+
' when `fast_dev_run=True`.'
964+
)
965+
966+
if not self.checkpoint_callback.best_model_path:
967+
raise MisconfigurationException(
968+
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
969+
)
969970

970971
# load best weights
971972
if ckpt_path is not None:

tests/trainer/test_trainer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1777,3 +1777,19 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
17771777

17781778
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
17791779
trainer.fit(model, datamodule=dm)
1780+
1781+
1782+
@pytest.mark.parametrize("fast_dev_run", [True, False])
1783+
def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir, fast_dev_run):
1784+
model = BoringModel()
1785+
1786+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run)
1787+
1788+
trainer.fit(model)
1789+
1790+
if fast_dev_run:
1791+
with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"):
1792+
trainer.validate()
1793+
1794+
with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"):
1795+
trainer.test()

0 commit comments

Comments
 (0)