
Commit bc61361

Do not add return dict items to callback_metrics (#6682)
1 parent 6b990f3 commit bc61361

File tree

16 files changed: +84, -341 lines


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -150,6 +150,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


+- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))
+
+
 - Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
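A minimal sketch of the migration this changelog entry describes, using the public `self.log_dict` API; the module and metric names below are illustrative and not taken from this commit:

import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical module, for illustration only

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # placeholder for the real loss computation

        # Before: items of the returned dict were also copied into callback_metrics
        # return {'loss': loss, 'log': {'train_loss': loss}}

        # After this change: log explicitly; only self.log / self.log_dict feed callback_metrics
        self.log_dict({'train_loss': loss})
        return loss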

docs/source/ecosystem/asr_nlp_tts.rst

Lines changed: 5 additions & 5 deletions

@@ -270,12 +270,12 @@ with PyTorch Lightning since every NeMo model is a Lightning Module.
         log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
     )
     wer_num, wer_denom = self._wer(predictions, transcript, transcript_len)
-    tensorboard_logs = {
+    self.log_dict({
         'train_loss': loss_value,
         'training_batch_wer': wer_num / wer_denom,
         'learning_rate': self._optimizer.param_groups[0]['lr'],
-    }
-    return {'loss': loss_value, 'log': tensorboard_logs}
+    })
+    return loss_value

 Neural Types in NeMo ASR
 ------------------------
@@ -539,8 +539,8 @@ since every NeMo model is a Lightning Module.
     logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)

     loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
-    tensorboard_logs = {'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}
-    return {'loss': loss, 'log': tensorboard_logs}
+    self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']})
+    return loss
     ...

 Neural Types in NeMo NLP

docs/source/ecosystem/bolts.rst

Lines changed: 2 additions & 2 deletions

@@ -68,8 +68,8 @@ you can trust the implementations and use them to bootstrap your research much f

     loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())

-    logs = {"loss": loss}
-    return {"loss": loss, "log": logs}
+    self.log("loss", loss)
+    return loss

 ----------

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion

@@ -590,7 +590,7 @@ def _validate_monitor_key(self, trainer):
         m = (
             f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
             f" {list(metrics.keys())}. "
-            f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+            f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
         )
         raise MisconfigurationException(m)
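As the updated hint suggests, `ModelCheckpoint` can only monitor keys that were logged through `self.log` or `self.log_dict`. A hedged sketch of the expected wiring; the `val_loss` key is an illustrative example, not taken from this commit:

from pytorch_lightning.callbacks import ModelCheckpoint

# In the LightningModule, e.g. in validation_step:
#     self.log('val_loss', loss)
# Without such a call, the monitored key is missing and the exception above is raised.
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')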

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 0 additions & 1 deletion

@@ -346,7 +346,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

         # update callback_metrics
         logger_connector._callback_metrics.update(callback_metrics)
-        logger_connector._callback_metrics.pop("epoch", None)

         batch_pbar_metrics.pop("debug_epoch", None)
         return batch_pbar_metrics, batch_log_metrics

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 17 additions & 50 deletions

@@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
-        return self._cached_results.get(self.trainer._running_stage)  # type: ignore
+        return self._cached_results.get(self.trainer._running_stage)

     def get_metrics(self, key: str) -> Dict:
         metrics_holder: MetricsHolder = getattr(self, f"_{key}")
@@ -121,8 +121,6 @@ def cache_logged_metrics(self):
     def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
         # logging
         self.configure_logger(logger)
-        # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders
-        # and assign here the desired value
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
         self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
@@ -185,9 +183,6 @@ def cache_training_step_metrics(self, opt_closure_result):
         batch_log_metrics = opt_closure_result.training_step_output.log_metrics
         logged_metrics_tmp.update(batch_log_metrics)

-        callback_metrics = opt_closure_result.training_step_output.callback_metrics
-        callback_metrics_tmp.update(callback_metrics)
-
         batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
         pbar_metrics_tmp.update(batch_pbar_metrics)

@@ -210,9 +205,6 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
             metrics (dict): Metric values
             grad_norm_dic (dict): Gradient norms
             step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
-            log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
-                steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
-                and global_step for the rest.
         """
         # add gpu memory
         if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
@@ -348,27 +340,6 @@ def _track_callback_metrics(self, eval_results):
             if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
                 self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

-    def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
-        # eval loop returns all metrics
-        dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-        # add metrics to prog bar
-        self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-        # log metrics
-        if len(log_metrics) > 0:
-            self.trainer.logger_connector.log_metrics(log_metrics, {})
-
-        # track metrics for callbacks (all prog bar, logged and callback metrics)
-        callback_metrics.update(log_metrics)
-        callback_metrics.update(prog_bar_metrics)
-        self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-        if self.trainer.state in (TrainerState.TESTING, TrainerState.VALIDATING):
-            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
-
-        if len(dataloader_result_metrics) > 0:
-            self.eval_loop_results.append(dataloader_result_metrics)
-
     def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if self.trainer.sanity_checking:
             return
@@ -379,21 +350,21 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if not isinstance(eval_results, list):
             eval_results = [eval_results]

-        num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
-        prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}
-
         for result_idx, result in enumerate(eval_results):
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
+            _, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)
+
+            # eval loop returns all metrics
+            dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
+
+            # add metrics to prog bar
+            self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

-            if num_loaders > 1:
-                self.__process_eval_epoch_end_results_and_log_legacy_update(
-                    prog_bar_metrics, log_metrics, callback_metrics
-                )
+            # log metrics
+            if len(log_metrics) > 0:
+                self.trainer.logger_connector.log_metrics(log_metrics, {})

-        if num_loaders == 1:
-            self.__process_eval_epoch_end_results_and_log_legacy_update(
-                prog_bar_metrics, log_metrics, callback_metrics
-            )
+            if len(dataloader_result_metrics) > 0:
+                self.eval_loop_results.append(dataloader_result_metrics)

     def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
@@ -446,10 +417,9 @@ def log_train_epoch_end_metrics(

         # TODO: deprecate 1.0
         else:
-            out = self.__run_legacy_training_epoch_end(
-                num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
+            epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end(
+                num_optimizers, epoch_output, model, is_result_obj
             )
-            epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

         # it will perform reduction over epoch and return log metrics
         cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
@@ -501,9 +471,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
         # capture logging
         self.trainer.logger_connector.cache_logged_metrics()

-    def __run_legacy_training_epoch_end(
-        self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
-    ):
+    def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj):

         epoch_log_metrics = {}
         epoch_progress_bar_metrics = {}
@@ -534,15 +502,14 @@ def __run_legacy_training_epoch_end(
                 _processed_outputs = self.trainer.process_dict_result(epoch_output)
                 epoch_progress_bar_metrics = _processed_outputs[1]
                 epoch_log_metrics = _processed_outputs[2]
-                epoch_callback_metrics = _processed_outputs[3]

         # --------------------------
         # Structured Result (auto epoch end)
         # --------------------------
         elif is_result_obj:
             epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

-        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics
+        return epoch_log_metrics, epoch_progress_bar_metrics

     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}

pytorch_lightning/trainer/logging.py

Lines changed: 9 additions & 22 deletions

@@ -20,6 +20,7 @@

 from pytorch_lightning.utilities import DistributedType
 from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach


@@ -32,8 +33,14 @@ class TrainerLoggingMixin(ABC):

     def metrics_to_scalars(self, metrics):
         new_metrics = {}
+        # TODO: this is duplicated in MetricsHolder. should be unified
         for k, v in metrics.items():
             if isinstance(v, torch.Tensor):
+                if v.numel() != 1:
+                    raise MisconfigurationException(
+                        f"The metric `{k}` does not contain a single element"
+                        f" thus it cannot be converted to float. Found `{v}`"
+                    )
                 v = v.item()

             if isinstance(v, dict):
@@ -71,23 +78,8 @@ def process_dict_result(self, output, train=False):
         if isinstance(output, torch.Tensor):
             progress_bar_metrics = {}
             log_metrics = {}
-            callback_metrics = {}
             hiddens = None
-            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
-
-        # ---------------
-        # EXTRACT CALLBACK KEYS
-        # ---------------
-        # all keys not progress_bar or log are candidates for callbacks
-        callback_metrics = {}
-        if isinstance(output, Mapping):
-            for k, v in output.items():
-                if k not in ['progress_bar', 'log', 'hiddens']:
-                    callback_metrics[k] = v
-
-        if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-            num_gpus = self.num_gpus
-            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
+            return output, progress_bar_metrics, log_metrics, hiddens

         # ---------------
         # EXTRACT PROGRESS BAR KEYS
@@ -149,17 +141,12 @@ def process_dict_result(self, output, train=False):
         # ---------------
         hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None

-        # use every metric passed in as a candidate for callback
-        callback_metrics.update(progress_bar_metrics)
-        callback_metrics.update(log_metrics)
-
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
-        callback_metrics = recursive_detach(callback_metrics)
         progress_bar_metrics = recursive_detach(progress_bar_metrics)
         log_metrics = recursive_detach(log_metrics)

-        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+        return loss, progress_bar_metrics, log_metrics, hiddens

     def reduce_distributed_output(self, output, num_gpus):
         if num_gpus <= 1:
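The new `numel()` guard means only single-element tensors can be converted to scalar metrics here. A rough standalone illustration of the rule (not code from this repository):

import torch

torch.tensor(0.123).item()          # fine: a single-element tensor converts to a Python float
# torch.tensor([0.1, 0.2]).item()   # would fail; such a metric now raises the MisconfigurationException above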

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 9 deletions

@@ -823,15 +823,6 @@ def run_sanity_check(self, ref_model):
             # run eval step
             _, eval_results = self.run_evaluation()

-            # allow no returns from eval
-            if eval_results is not None and len(eval_results) > 0:
-                # when we get a list back, used only the last item
-                if isinstance(eval_results, list):
-                    eval_results = eval_results[-1]
-
-                _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
-                self.logger_connector.callback_metrics = callback_metrics
-
             self.on_sanity_check_end()

         self._running_stage = stage

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 2 deletions

@@ -348,8 +348,7 @@ def _process_training_step_output(self, training_step_output, split_batch):
             batch_loss=training_step_output[0],
             pbar_on_batch_end=training_step_output[1],
             log_metrics=training_step_output[2],
-            callback_metrics=training_step_output[3],
-            hiddens=training_step_output[4],
+            hiddens=training_step_output[3],
         )
         # if the user decides to finally reduce things in epoch_end, save raw output without graphs
         if isinstance(training_step_output_for_epoch_end, torch.Tensor):

tests/base/model_valid_epoch_ends.py

Lines changed: 2 additions & 3 deletions

@@ -43,9 +43,8 @@ def _mean(res, key):
         val_loss_mean = val_loss_mean.item()
         val_acc_mean = val_acc_mean.item()

-        metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
-        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
-        return results
+        self.log('early_stop_on', val_loss_mean, prog_bar=True)
+        self.log('val_acc', val_acc_mean, prog_bar=True)

     def validation_epoch_end__multiple_dataloaders(self, outputs):
         """
