Skip to content

Commit a0d0c4b

Browse files
committed
Refactor Trainer in advance of implementing Trainer.validate
* Replace the `Trainer.testing` attribute with `Trainer.evaluating`, which is currently set to `'test'` if the top-level function called by the user was `Trainer.test(…)` and `None` otherwise. In the next PR, it will be set to `'validation'` when the user calls `validate(…)`.
* Update the other components to use the new attribute instead of `Trainer.testing`.
* Disable the `EarlyStopping` and `ModelCheckpoint` callbacks when `evaluating`. This has no effect when evaluating on the test set, since they were already disabled, but it will be necessary for the validation set.
* Rename a few other attributes of `Trainer` to clarify that they will be used by both `test(…)` and `validate(…)`.
1 parent add387c commit a0d0c4b

File tree

10 files changed

+80
-56
lines changed

10 files changed

+80
-56
lines changed

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def broadcast(self, obj, src=0):
6060
return obj
6161

6262
def train_or_test(self):
63-
if self.trainer.testing:
63+
if self.trainer.evaluating:
6464
results = self.trainer.run_test()
6565
else:
6666
results = self.trainer.train()
@@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module):
160160
return self.trainer.should_stop
161161

162162
def setup_optimizers(self, model):
163-
if self.trainer.testing is True:
163+
if self.trainer.evaluating:
164164
return
165165

166166
optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state):
134134
self.patience = checkpointed_state['patience']
135135

136136
def on_validation_end(self, trainer, pl_module):
137-
if trainer.running_sanity_check:
137+
if trainer.running_sanity_check or trainer.evaluating:
138138
return
139139

140140
self._run_early_stopping_check(trainer, pl_module)
141141

142142
def on_validation_epoch_end(self, trainer, pl_module):
143-
if trainer.running_sanity_check:
143+
if trainer.running_sanity_check or trainer.evaluating:
144144
return
145145

146146
if self._validate_condition_metric(trainer.logger_connector.callback_metrics):

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module):
220220
or self.period < 1 # no models are saved
221221
or (epoch + 1) % self.period # skip epoch
222222
or trainer.running_sanity_check # don't save anything during sanity check
223+
or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated
223224
or self.last_global_step_saved == global_step # already saved at the last step
224225
):
225226
return

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
3131
model: The model to check the configuration.
3232
3333
"""
34-
if not self.trainer.testing:
34+
if not self.trainer.evaluating:
3535
self.__verify_train_loop_configuration(model)
3636
self.__verify_eval_loop_configuration(model, 'validation')
3737
else:
38-
# check test loop configuration
39-
self.__verify_eval_loop_configuration(model, 'test')
38+
# check evaluation loop configurations
39+
self.__verify_eval_loop_configuration(model, self.trainer.evaluating)
4040

4141
def __verify_train_loop_configuration(self, model):
4242
# -----------------------------------

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ def prepare_eval_loop_results(self):
265265
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
266266
self.add_to_eval_loop_results(dl_idx, has_been_initialized)
267267

268-
def get_evaluate_epoch_results(self, test_mode):
268+
def get_evaluate_epoch_results(self):
269269
if not self.trainer.running_sanity_check:
270270
# log all the metrics as a single dict
271271
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self, test_mode):
274274

275275
self.prepare_eval_loop_results()
276276

277-
# log results of test
278-
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
277+
# log results of evaluation
278+
if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate:
279279
print('-' * 80)
280280
for result_idx, results in enumerate(self.eval_loop_results):
281-
print(f'DATALOADER:{result_idx} TEST RESULTS')
281+
print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS')
282282
pprint(results)
283283
print('-' * 80)
284284

pytorch_lightning/trainer/connectors/model_connector.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@ def copy_trainer_model_properties(self, model):
3636
m.use_ddp2 = self.trainer.use_ddp2
3737
m.use_ddp = self.trainer.use_ddp
3838
m.use_amp = self.trainer.amp_backend is not None
39-
m.testing = self.trainer.testing
39+
# Currently, the only users of m.testing appear to be DP and DDP,
40+
# which use it to determine whether the model is currently inside
41+
# the validation or test loop. For this reason it must check if
42+
# trainer.evaluating is equal to "test" specifically.
43+
m.testing = self.trainer.evaluating == 'test'
4044
m.use_single_gpu = self.trainer.use_single_gpu
4145
m.use_tpu = self.trainer.use_tpu
4246
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import torch
1515

16+
import pytorch_lightning as pl
1617
from pytorch_lightning.core.step_result import EvalResult, Result
1718
from pytorch_lightning.trainer.supporters import PredictionCollection
1819
from pytorch_lightning.utilities.distributed import rank_zero_warn
@@ -22,7 +23,7 @@
2223

2324

2425
class EvaluationLoop(object):
25-
def __init__(self, trainer):
26+
def __init__(self, trainer: 'pl.Trainer'):
2627
self.trainer = trainer
2728
self.testing = False
2829
self.outputs = []
@@ -39,13 +40,15 @@ def on_trainer_init(self):
3940
self.trainer.test_dataloaders = None
4041
self.trainer.val_dataloaders = None
4142
self.trainer.running_sanity_check = False
42-
self.trainer.testing = False
4343

44-
# when .test() is called, it sets this
45-
self.trainer.tested_ckpt_path = None
44+
# .validate() sets this to 'validation' and .test() sets this to 'test'
45+
self.trainer.evaluating = None
4646

47-
# when true, prints test results
48-
self.trainer.verbose_test = True
47+
# .validate() and .test() set this when they load a checkpoint
48+
self.trainer.evaluated_ckpt_path = None
49+
50+
# when true, print evaluation results in .validate() and .test()
51+
self.trainer.verbose_evaluate = True
4952

5053
def get_evaluation_dataloaders(self, max_batches):
5154
# select dataloaders
@@ -216,7 +219,7 @@ def evaluation_epoch_end(self):
216219

217220
def log_epoch_metrics_on_evaluation_end(self):
218221
# get the final loop results
219-
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing)
222+
eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results()
220223
return eval_loop_results
221224

222225
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):

pytorch_lightning/trainer/trainer.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,6 @@ def fit(
441441
# hook
442442
self.data_connector.prepare_data(model)
443443

444-
# bookkeeping
445-
# we reuse fit in .test() but change its behavior using this flag
446-
self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
447-
448444
# ----------------------------
449445
# SET UP TRAINING
450446
# ----------------------------
@@ -720,33 +716,31 @@ def test(
720716
datamodule: Optional[LightningDataModule] = None,
721717
):
722718
r"""
723-
724-
Separates from fit to make sure you never run on your test set until you want to.
719+
Perform one evaluation epoch over the test set. It's separated from
720+
fit to make sure you never run on your test set until you want to.
725721
726722
Args:
727723
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
728-
If ``None``, use the weights from the last epoch to test. Default to ``best``.
729-
724+
If ``None``, use the current weights of the model. Default to ``best``.
730725
datamodule: An instance of :class:`LightningDataModule`.
731-
732-
model: The model to test.
733-
734-
test_dataloaders: Either a single
735-
Pytorch Dataloader or a list of them, specifying validation samples.
736-
737-
verbose: If True, prints the test results
726+
model: The model to evaluate.
727+
test_dataloaders: Either a single PyTorch DataLoader or a list of them,
728+
specifying test samples.
729+
verbose: If True, prints the test results.
738730
739731
Returns:
740-
The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries
732+
The dictionary with final test results returned by test_epoch_end.
733+
If test_epoch_end is not defined, the output is a list of the dictionaries
734+
returned by test_step.
741735
"""
742736
# --------------------
743737
# SETUP HOOK
744738
# --------------------
745-
self.verbose_test = verbose
739+
self.verbose_evaluate = verbose
746740

747741
self.logger_connector.set_stage("test")
748742

749-
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
743+
# If you supply a datamodule you can't supply test_dataloaders
750744
if test_dataloaders and datamodule:
751745
raise MisconfigurationException(
752746
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
@@ -756,15 +750,15 @@ def test(
756750
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')
757751

758752
if model is not None:
759-
results = self.__test_given_model(model, test_dataloaders)
753+
results = self.__evaluate_given_model(model, test_dataloaders, 'test')
760754
else:
761-
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
755+
results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test')
762756

763757
self.teardown('test')
764758

765759
return results
766760

767-
def __test_using_best_weights(self, ckpt_path, test_dataloaders):
761+
def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str):
768762
model = self.get_model()
769763

770764
# if user requests the best checkpoint but we don't have it, error
@@ -796,40 +790,56 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
796790
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
797791

798792
# run tests
799-
self.tested_ckpt_path = ckpt_path
800-
self.testing = True
801-
os.environ['PL_TESTING_MODE'] = '1'
793+
self.evaluating = stage
794+
self.evaluated_ckpt_path = ckpt_path
802795
self.model = model
803796
results = self.fit(model)
804-
self.testing = False
805-
del os.environ['PL_TESTING_MODE']
797+
self.evaluating = None
806798

807799
# teardown
808800
if self.is_function_implemented('teardown'):
809801
model_ref = self.get_model()
810-
model_ref.teardown('test')
802+
model_ref.teardown(stage)
811803

812804
return results
813805

814-
def __test_given_model(self, model, test_dataloaders):
806+
def __evaluate_given_model(self, model, dataloaders, stage: str):
815807

816808
# attach data
817809
if test_dataloaders is not None:
818810
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
819811

820812
# run test
821813
# sets up testing so we short circuit to eval
822-
self.testing = True
814+
self.evaluating = stage
823815
self.model = model
824816
results = self.fit(model)
825-
self.testing = False
817+
self.evaluating = None
826818

827819
# teardown
828820
if self.is_function_implemented('teardown'):
829-
model.teardown('test')
821+
model.teardown(stage)
830822

831823
return results
832824

825+
@property
826+
def testing(self):
827+
warnings.warn(
828+
'Trainer.testing has been deprecated in v1.1 and will be removed '
829+
'in v1.3, use Trainer.evaluating instead.',
830+
DeprecationWarning, stacklevel=2
831+
)
832+
return bool(self.evaluating)
833+
834+
@property
835+
def tested_ckpt_path(self):
836+
warnings.warn(
837+
'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path '
838+
'in v1.1 and will be removed in v1.3.',
839+
DeprecationWarning, stacklevel=2
840+
)
841+
return self.evaluated_ckpt_path
842+
833843
def tune(
834844
self,
835845
model: LightningModule,
@@ -856,11 +866,17 @@ def tune(
856866

857867
def call_setup_hook(self, model):
858868
# call setup after the ddp process has connected
859-
stage_name = 'test' if self.testing else 'fit'
869+
stage_name = self.evaluating or 'fit'
870+
860871
if self.datamodule is not None:
861-
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
872+
called = {
873+
None: self.datamodule.has_setup_fit,
874+
'test': self.datamodule.has_setup_test,
875+
}[self.evaluating]
876+
862877
if not called:
863878
self.datamodule.setup(stage_name)
879+
864880
self.setup(model, stage_name)
865881
model.setup(stage_name)
866882

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def setup_training(self, model: LightningModule):
161161
ref_model.on_pretrain_routine_start()
162162

163163
# print model summary
164-
if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing:
164+
if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating:
165165
if self.trainer.weights_summary in ModelSummary.MODES:
166166
ref_model.summarize(mode=self.trainer.weights_summary)
167167
else:

tests/trainer/test_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -728,12 +728,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
728728
trainer.test(ckpt_path=ckpt_path)
729729
else:
730730
trainer.test(ckpt_path=ckpt_path)
731-
assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
731+
assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path
732732
elif ckpt_path is None:
733733
# ckpt_path is None, meaning we don't load any checkpoints and
734734
# use the weights from the end of training
735735
trainer.test(ckpt_path=ckpt_path)
736-
assert trainer.tested_ckpt_path is None
736+
assert trainer.evaluated_ckpt_path is None
737737
else:
738738
# specific checkpoint, pick one from saved ones
739739
if save_top_k == 0:
@@ -746,7 +746,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
746746
].absolute()
747747
)
748748
trainer.test(ckpt_path=ckpt_path)
749-
assert trainer.tested_ckpt_path == ckpt_path
749+
assert trainer.evaluated_ckpt_path == ckpt_path
750750

751751

752752
def test_disabled_training(tmpdir):

0 commit comments

Comments
 (0)