
Commit 0dd2dee

Remove legacy support for the magic log/progress_bar keys in dict returns (#6734)
1 parent f9bb7c6 commit 0dd2dee

24 files changed (+141, -943 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -165,6 +165,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))


+- Removed legacy code to log or include metrics in the progress bar by returning them in a dict with the `"log"/"progress_bar"` magic keys. Use `self.log` instead ([#6734](https://github.com/PyTorchLightning/pytorch-lightning/pull/6734))
+
+
 - Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))

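For anyone migrating, the two styles compare roughly as follows. This is a minimal, illustrative sketch of a user-defined `training_step` (the model and the `_shared_step` helper are hypothetical), not code from this commit:

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        # __init__, forward, configure_optimizers, etc. elided

        def training_step(self, batch, batch_idx):
            loss, acc = self._shared_step(batch)  # hypothetical helper
            # old style (removed): return {"loss": loss, "log": {"train_loss": loss},
            #                              "progress_bar": {"train_acc": acc}}
            # new style: log explicitly and return only the loss
            self.log("train_loss", loss, on_step=True, on_epoch=True)
            self.log("train_acc", acc, prog_bar=True)
            return loss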

docs/source/ecosystem/asr_nlp_tts.rst

Lines changed: 2 additions & 7 deletions
@@ -751,13 +751,8 @@ be customized with PyTorch Lightning since every NeMo model is a LightningModule

         l_mle, l_length, logdet, loss, _ = self.step(y, y_lengths, x, x_lengths)

-        output = {
-            "loss": loss,  # required
-            "progress_bar": {"l_mle": l_mle, "l_length": l_length, "logdet": logdet},
-            "log": {"loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet},
-        }
-
-        return output
+        self.log_dict({"l_mle": l_mle, "l_length": l_length, "logdet": logdet}, prog_bar=True)
+        return loss
         ...

 Neural Types in NeMo TTS
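Note that a single call with `prog_bar=True` both logs the values and shows them in the progress bar, whereas the old dict kept separate `"progress_bar"` and `"log"` entries. If that split is needed, the `logger` flag can be turned off per call; a minimal sketch reusing the metric names above (illustrative, not part of the NeMo docs):

    # inside a LightningModule training_step
    # progress bar only (roughly the old "progress_bar" key)
    self.log_dict({"l_mle": l_mle, "l_length": l_length, "logdet": logdet}, prog_bar=True, logger=False)
    # logger only (roughly the old "log" key)
    self.log_dict({"loss": loss, "l_mle": l_mle, "l_length": l_length, "logdet": logdet})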

pytorch_lightning/core/step_result.py

Lines changed: 0 additions & 4 deletions
@@ -526,10 +526,6 @@ def reduce_across_time(cls, time_outputs):
         # auto-reduce across time for tbptt
         meta = time_outputs[0]['meta']

-        # in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this
-        if 'extra' in time_outputs[0]:
-            [x.pop('extra', None) for x in time_outputs]
-
         result = cls()
         result = recursive_gather(time_outputs, result)
         recursive_stack(result)

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 0 additions & 4 deletions
@@ -394,12 +394,10 @@ def run_batch_from_func_name(self, func_name) -> Dict:

     def get_latest_batch_log_metrics(self) -> Dict:
         batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
-        batch_log_metrics.update(self.legacy_batch_log_metrics)
         return batch_log_metrics

     def get_latest_batch_pbar_metrics(self) -> Dict:
         batch_pbar_metrics = self.run_batch_from_func_name("get_batch_pbar_metrics")
-        batch_pbar_metrics.update(self.legacy_batch_pbar_metrics)
         return batch_pbar_metrics

     @property

@@ -451,8 +449,6 @@ def reset(self):
         self._opt_idx: Optional[int] = None
         self._batch_size: Optional[int] = None
         self._has_batch_loop_finished = False
-        self.legacy_batch_log_metrics = {}
-        self.legacy_batch_pbar_metrics = {}

     def __call__(
         self,

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 0 additions & 3 deletions
@@ -191,10 +191,7 @@ def cache_training_step_metrics(self, opt_closure_result):
            self.add_progress_bar_metrics(pbar_metrics_tmp)

        self._callback_metrics.update(callback_metrics_tmp)
-
-        # save legacy log metrics
        self._logged_metrics.update(logged_metrics_tmp)
-        self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        """Logs the metric dict passed in.

pytorch_lightning/trainer/logging.py

Lines changed: 0 additions & 135 deletions
@@ -12,25 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import inspect
 from abc import ABC
-from collections import Mapping

 import torch

-from pytorch_lightning.utilities import DistributedType
-from pytorch_lightning.utilities.distributed import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach


 class TrainerLoggingMixin(ABC):

-    # this is just a summary on variables used in this abstract class,
-    # the proper values/initialisation should be done in child class
-    _distrib_type: DistributedType
-    num_gpus: int
-
     def metrics_to_scalars(self, metrics):
         new_metrics = {}
         # TODO: this is duplicated in MetricsHolder. should be unified
@@ -49,128 +39,3 @@ def metrics_to_scalars(self, metrics):
             new_metrics[k] = v

         return new_metrics
-
-    def process_dict_result(self, output, train=False):
-        """Reduces output according to the training mode.
-
-        Separates loss from logging and progress bar metrics
-        """
-        # --------------------
-        # WARN DEPRECATED KEYS
-        # --------------------
-        # TODO: 1.0.0 remove
-        if isinstance(output, dict):
-            for k, v in output.items():
-                if k in ['log', 'progress_bar']:
-                    m = inspect.cleandoc(
-                        f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n"
-                        " Please use self.log(...) inside the lightningModule instead.\n"
-                        " # log on a step or aggregate epoch metric to the logger and/or progress bar"
-                        " (inside LightningModule)\n"
-                        " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)"
-                    )
-                    rank_zero_warn(m)
-
-        # --------------------------
-        # handle single scalar only
-        # --------------------------
-        # single scalar returned from a xx_step
-        if isinstance(output, torch.Tensor):
-            return output, {}, {}, None
-
-        # ---------------
-        # EXTRACT PROGRESS BAR KEYS
-        # ---------------
-        try:
-            progress_output = output['progress_bar']
-
-            # reduce progress metrics for progress bar when using dp
-            if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                num_gpus = self.num_gpus
-                progress_output = self.reduce_distributed_output(progress_output, num_gpus)
-
-            progress_bar_metrics = progress_output
-        # todo: specify the possible exception
-        except Exception:
-            progress_bar_metrics = {}
-
-        # ---------------
-        # EXTRACT LOGGING KEYS
-        # ---------------
-        # extract metrics to log to experiment
-        try:
-            log_output = output['log']
-
-            # reduce progress metrics for progress bar when using dp
-            if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                num_gpus = self.num_gpus
-                log_output = self.reduce_distributed_output(log_output, num_gpus)
-
-            log_metrics = log_output
-        # todo: specify the possible exception
-        except Exception:
-            log_metrics = {}
-
-        # ---------------
-        # EXTRACT LOSS
-        # ---------------
-        # if output dict doesn't have the keyword loss
-        # then assume the output=loss if scalar
-        loss = None
-        if train:
-            try:
-                loss = output['loss']
-            # todo: specify the possible exception
-            except Exception as exp:
-                if isinstance(output, torch.Tensor):
-                    loss = output
-                else:
-                    raise RuntimeError(
-                        'No `loss` value in the dictionary returned from `model.training_step()`.'
-                    ) from exp
-
-            # when using dp need to reduce the loss
-            if self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-                loss = self.reduce_distributed_output(loss, self.num_gpus)
-
-        # ---------------
-        # EXTRACT HIDDEN
-        # ---------------
-        hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None
-        if hiddens is not None:
-            hiddens = hiddens.detach()
-
-        # detach all metrics for callbacks to prevent memory leaks
-        # no .item() because it will slow things down
-        progress_bar_metrics = recursive_detach(progress_bar_metrics)
-        log_metrics = recursive_detach(log_metrics)
-
-        return loss, progress_bar_metrics, log_metrics, hiddens
-
-    def reduce_distributed_output(self, output, num_gpus):
-        if num_gpus <= 1:
-            return output
-
-        # when using DP, we get one output per gpu
-        # average outputs and return
-        if isinstance(output, torch.Tensor):
-            return output.mean()
-
-        for k, v in output.items():
-            # recurse on nested dics
-            if isinstance(output[k], dict):
-                output[k] = self.reduce_distributed_output(output[k], num_gpus)
-
-            # compute the average of scalars
-            elif isinstance(output[k], list):
-                output[k] = sum(output[k]) / len(output[k])

-            # do nothing when there's a scalar
-            elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
-                pass
-
-            # do not reduce metrics that have batch size > num gpus
-            elif output[k].size(0) <= num_gpus:
-                output[k] = torch.mean(output[k])
-
-        return output
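The removed helpers also mean-reduced the returned metric dicts across GPUs when running with DP/DDP2. Under the `self.log` path, cross-process reduction is requested per call instead; a minimal sketch assuming the standard `sync_dist` flag (the model and the `_compute_loss` helper are hypothetical):

    # inside a LightningModule subclass
    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)  # hypothetical helper
        # reduce the logged value across processes rather than relying on the
        # removed reduce_distributed_output() helper
        self.log("train_loss", loss, sync_dist=True)
        return loss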

pytorch_lightning/trainer/training_loop.py

Lines changed: 7 additions & 51 deletions
@@ -28,7 +28,6 @@
 from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType, parsing
 from pytorch_lightning.utilities.distributed import rank_zero_info
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.parsing import AttributeDict
 from pytorch_lightning.utilities.warnings import WarningCache

@@ -242,12 +241,7 @@ def get_optimizers_iterable(self):
         return [[opt_idx, self.trainer.optimizers[opt_idx]]]

     def on_after_backward(self, training_step_output, batch_idx, untouched_loss):
-        is_result_obj = isinstance(training_step_output, Result)
-
-        if is_result_obj:
-            training_step_output = training_step_output.detach()
-        else:
-            training_step_output.batch_loss = training_step_output.batch_loss.detach()
+        training_step_output.detach()

         # insert after step hook
         self.trainer.call_hook("on_after_backward")

@@ -284,24 +278,16 @@ def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
         training_step_output_for_epoch_end, training_step_output = self._process_training_step_output(
             training_step_output, split_batch
         )
-        is_result_obj = isinstance(training_step_output, Result)
-
         if training_step_output_for_epoch_end is None:
-            return None
+            return

         # enable empty loss when using manual opt
         closure_loss = None
         untouched_loss = None

         if self.automatic_optimization:
-            # accumulate loss
-            # (if accumulate_grad_batches = 1 no effect)
-            if is_result_obj:
-                closure_loss = training_step_output.minimize
-            else:
-                closure_loss = training_step_output.batch_loss
-
-            closure_loss = closure_loss / self.trainer.accumulate_grad_batches
+            # accumulate loss. if accumulate_grad_batches==1, no effect
+            closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches

         # the loss will get scaled for amp. avoid any modifications to it
         untouched_loss = closure_loss.detach().clone()

@@ -322,35 +308,6 @@ def _process_training_step_output(self, training_step_output, split_batch):
         if training_step_output_for_epoch_end is None:
             return None, None

-        # -----------------------------------------
-        # process hybrid (1.0)
-        # -----------------------------------------
-        # no need for these checks in 1.0.0
-        # TODO: remove checks in 1.0.0
-        is_tensor = isinstance(training_step_output_for_epoch_end, torch.Tensor)
-        is_1_0_output = is_tensor or ("log" not in training_step_output and "progress_bar" not in training_step_output)
-        if is_1_0_output:
-            return self._process_training_step_output_1_0(training_step_output, split_batch)
-
-        # -----------------------------------------
-        # process old dict (deprecate 1.0)
-        # -----------------------------------------
-        training_step_output = self.trainer.process_dict_result(training_step_output, train=True)
-
-        training_step_output = AttributeDict(
-            batch_loss=training_step_output[0],
-            pbar_on_batch_end=training_step_output[1],
-            log_metrics=training_step_output[2],
-        )
-        # if the user decides to finally reduce things in epoch_end, save raw output without graphs
-        if isinstance(training_step_output_for_epoch_end, torch.Tensor):
-            training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
-        else:
-            training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
-
-        return training_step_output_for_epoch_end, training_step_output
-
-    def _process_training_step_output_1_0(self, training_step_output, split_batch):
         result = self.trainer.lightning_module._results

         loss = None

@@ -361,6 +318,8 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
         if isinstance(training_step_output, dict):
             loss = training_step_output.pop("loss", None)
             hiddens = training_step_output.pop("hiddens", None)
+            if hiddens is not None:
+                hiddens = hiddens.detach()
             result["extra"] = training_step_output

         # handle scalar return

@@ -380,10 +339,7 @@ def _process_training_step_output_1_0(self, training_step_output, split_batch):
         if self.trainer.move_metrics_to_cpu:
             training_step_output_for_epoch_end = training_step_output_for_epoch_end.cpu()

-        # what flows back into the system
-        training_step_output = result
-
-        return training_step_output_for_epoch_end, training_step_output
+        return training_step_output_for_epoch_end, result

     def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure):
         model_ref = self.trainer.lightning_module
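After this change, the only keys in a dict returned from `training_step` that the loop treats specially are `"loss"` and `"hiddens"` (the latter is now detached here for truncated backpropagation through time); any remaining entries land under the result's `"extra"` key. A minimal, illustrative sketch of such a step (the RNN call and loss helper are hypothetical):

    # inside a LightningModule subclass, with truncated_bptt_steps enabled
    def training_step(self, batch, batch_idx, hiddens):
        # truncated BPTT: hidden state is carried between batch splits
        output, new_hiddens = self.rnn(batch, hiddens)  # hypothetical model call
        loss = self._loss(output)                       # hypothetical loss helper
        self.log("train_loss", loss)
        # "hiddens" is detached by the training loop before the next split
        return {"loss": loss, "hiddens": new_hiddens}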

tests/checkpointing/test_model_checkpoint.py

Lines changed: 0 additions & 2 deletions
@@ -876,7 +876,6 @@ def validation_epoch_end(self, outputs):
     assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_checkpoint_repeated_strategy(tmpdir):
     """
     This test validates that the checkpoint can be called when provided to callbacks list

@@ -923,7 +922,6 @@ def validation_step(self, batch, batch_idx):
     assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}


-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 def test_checkpoint_repeated_strategy_extended(tmpdir):
     """
     This test validates checkpoint can be called several times without
