Skip to content

Commit b964418

Browse files
committed
add test
1 parent bec6d43 commit b964418

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -948,12 +948,6 @@ def test(
948948
# Attach dataloaders (if given)
949949
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
950950

951-
if not model_provided and ckpt_path == 'best' and self.fast_dev_run:
952-
raise MisconfigurationException(
953-
'You cannot execute testing when the model is not provided and `fast_dev_run=True`. '
954-
'Provide model with `trainer.test(model=...)` or `trainer.test(ckpt_path=...)`'
955-
)
956-
957951
if not model_provided:
958952
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)
959953

@@ -971,10 +965,17 @@ def __load_ckpt_weights(
971965
ckpt_path: Optional[str] = None,
972966
) -> Optional[str]:
973967
# if user requests the best checkpoint but we don't have it, error
974-
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
975-
raise MisconfigurationException(
976-
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
977-
)
968+
if ckpt_path == 'best':
969+
if not self.checkpoint_callback.best_model_path and self.fast_dev_run:
970+
raise MisconfigurationException(
971+
'You cannot execute `trainer.test()` or trainer.validate()`'
972+
' when `fast_dev_run=True`.'
973+
)
974+
975+
if not self.checkpoint_callback.best_model_path:
976+
raise MisconfigurationException(
977+
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
978+
)
978979

979980
# load best weights
980981
if ckpt_path is not None:

tests/trainer/test_trainer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1451,6 +1451,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
14511451

14521452

14531453
def test_trainer_predict_grad(tmpdir):
1454+
14541455
class CustomBoringModel(BoringModel):
14551456

14561457
def predict_step(self, batch, batch_idx, dataloader_idx=None):
@@ -1776,3 +1777,19 @@ def on_fit_start(self, trainer, pl_module: LightningModule) -> None:
17761777

17771778
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
17781779
trainer.fit(model, datamodule=dm)
1780+
1781+
1782+
@pytest.mark.parametrize("fast_dev_run", [True, False])
1783+
def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir, fast_dev_run):
1784+
model = BoringModel()
1785+
1786+
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=fast_dev_run)
1787+
1788+
trainer.fit(model)
1789+
1790+
if fast_dev_run:
1791+
with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"):
1792+
trainer.validate()
1793+
1794+
with pytest.raises(MisconfigurationException, match=".*when `fast_dev_run=True`*"):
1795+
trainer.test()

0 commit comments

Comments
 (0)