Skip to content

Commit e50653d

Browse files
Deprecate agg_key_funcs, agg_default_func, and update_agg_funcs from LightningLoggerBase (#11871)
Co-authored-by: Danielle Pintz <[email protected]>
1 parent 4c57155 commit e50653d

File tree

3 files changed

+74
-4
lines changed

3 files changed

+74
-4
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
408408
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`
409409

410410

411+
- Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))
412+
413+
414+
- Deprecated `LightningLoggerBase.update_agg_funcs` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))
415+
416+
411417
- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))
412418

413419

pytorch_lightning/loggers/base.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class LightningLoggerBase(ABC):
5555
is not presented in the `agg_key_funcs` dictionary, then the
5656
`agg_default_func` will be used for aggregation.
5757
58+
.. deprecated:: v1.6
59+
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
60+
in v1.6 and will be removed in v1.8.
61+
5862
Note:
5963
The `agg_key_funcs` and `agg_default_func` arguments are used only when
6064
one logs metrics with the :meth:`~LightningLoggerBase.agg_and_log_metrics` method.
@@ -63,12 +67,26 @@ class LightningLoggerBase(ABC):
6367
def __init__(
6468
self,
6569
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
66-
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
70+
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
6771
):
6872
self._prev_step: int = -1
6973
self._metrics_to_agg: List[Dict[str, float]] = []
70-
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
71-
self._agg_default_func = agg_default_func
74+
if agg_key_funcs:
75+
self._agg_key_funcs = agg_key_funcs
76+
rank_zero_deprecation(
77+
"The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
78+
" and will be removed in v1.8."
79+
)
80+
else:
81+
self._agg_key_funcs = {}
82+
if agg_default_func:
83+
self._agg_default_func = agg_default_func
84+
rank_zero_deprecation(
85+
"The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
86+
" and will be removed in v1.8."
87+
)
88+
else:
89+
self._agg_default_func = np.mean
7290

7391
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
7492
"""Called after model checkpoint callback saves a new checkpoint.
@@ -85,6 +103,9 @@ def update_agg_funcs(
85103
):
86104
"""Update aggregation methods.
87105
106+
.. deprecated:: v1.6
107+
`update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.
108+
88109
Args:
89110
agg_key_funcs:
90111
Dictionary which maps a metric name to a function, which will
@@ -98,6 +119,9 @@ def update_agg_funcs(
98119
self._agg_key_funcs.update(agg_key_funcs)
99120
if agg_default_func:
100121
self._agg_default_func = agg_default_func
122+
rank_zero_deprecation(
123+
"`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
124+
)
101125

102126
def _aggregate_metrics(
103127
self, metrics: Dict[str, float], step: Optional[int] = None

tests/deprecated_api/test_remove_1-8.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
"""Test deprecated functionality which will be removed in v1.8.0."""
1515
from unittest.mock import Mock
1616

17+
import numpy as np
1718
import pytest
1819
import torch
1920
from torch import optim
2021

2122
from pytorch_lightning import Callback, Trainer
22-
from pytorch_lightning.loggers import CSVLogger
23+
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
2324
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
2425
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
2526
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
@@ -503,6 +504,45 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
503504
trainer.fit(model)
504505

505506

507+
def test_v1_8_0_logger_agg_parameters():
508+
class CustomLogger(LightningLoggerBase):
509+
@rank_zero_only
510+
def log_hyperparams(self, params):
511+
pass
512+
513+
@rank_zero_only
514+
def log_metrics(self, metrics, step):
515+
pass
516+
517+
@property
518+
def name(self):
519+
pass
520+
521+
@property
522+
def version(self):
523+
pass
524+
525+
with pytest.deprecated_call(
526+
match="The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
527+
" and will be removed in v1.8."
528+
):
529+
CustomLogger(agg_key_funcs={"mean", np.mean})
530+
531+
with pytest.deprecated_call(
532+
match="The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
533+
" and will be removed in v1.8."
534+
):
535+
CustomLogger(agg_default_func=np.mean)
536+
537+
# Should have no deprecation warning
538+
logger = CustomLogger()
539+
540+
with pytest.deprecated_call(
541+
match="`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
542+
):
543+
logger.update_agg_funcs()
544+
545+
506546
def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
507547
class AggregationOverrideLogger(CSVLogger):
508548
@rank_zero_only

0 commit comments

Comments
 (0)