Skip to content

Commit f93414d

Browse files
authored
Prune metyrics: regression 9/n (#6637)
* psnr * r2score * ssim * chlog
1 parent efce2b7 commit f93414d

File tree

11 files changed

+80
-989
lines changed

11 files changed

+80
-989
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
9292

9393
[#6636](https://github.com/PyTorchLightning/pytorch-lightning/pull/6636),
9494

95+
[#6637](https://github.com/PyTorchLightning/pytorch-lightning/pull/6637),
96+
9597
)
9698

9799

pytorch_lightning/metrics/functional/psnr.py

Lines changed: 5 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -14,46 +14,12 @@
1414
from typing import Optional, Tuple, Union
1515

1616
import torch
17-
from torchmetrics.utilities import reduce
17+
from torchmetrics.functional import psnr as _psnr
1818

19-
from pytorch_lightning.utilities import rank_zero_warn
20-
21-
22-
def _psnr_compute(
23-
sum_squared_error: torch.Tensor,
24-
n_obs: torch.Tensor,
25-
data_range: torch.Tensor,
26-
base: float = 10.0,
27-
reduction: str = 'elementwise_mean',
28-
) -> torch.Tensor:
29-
psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / n_obs)
30-
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
31-
return reduce(psnr, reduction=reduction)
32-
33-
34-
def _psnr_update(preds: torch.Tensor,
35-
target: torch.Tensor,
36-
dim: Optional[Union[int, Tuple[int, ...]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
37-
if dim is None:
38-
sum_squared_error = torch.sum(torch.pow(preds - target, 2))
39-
n_obs = torch.tensor(target.numel(), device=target.device)
40-
return sum_squared_error, n_obs
41-
42-
sum_squared_error = torch.sum(torch.pow(preds - target, 2), dim=dim)
43-
44-
if isinstance(dim, int):
45-
dim_list = [dim]
46-
else:
47-
dim_list = list(dim)
48-
if not dim_list:
49-
n_obs = torch.tensor(target.numel(), device=target.device)
50-
else:
51-
n_obs = torch.tensor(target.size(), device=target.device)[dim_list].prod()
52-
n_obs = n_obs.expand_as(sum_squared_error)
53-
54-
return sum_squared_error, n_obs
19+
from pytorch_lightning.utilities.deprecation import deprecated
5520

5621

22+
@deprecated(target=_psnr, ver_deprecate="1.3.0", ver_remove="1.5.0")
5723
def psnr(
5824
preds: torch.Tensor,
5925
target: torch.Tensor,
@@ -63,50 +29,6 @@ def psnr(
6329
dim: Optional[Union[int, Tuple[int, ...]]] = None,
6430
) -> torch.Tensor:
6531
"""
66-
Computes the peak signal-to-noise ratio
67-
68-
Args:
69-
preds: estimated signal
70-
target: groun truth signal
71-
data_range:
72-
the range of the data. If None, it is determined from the data (max - min). ``data_range`` must be given
73-
when ``dim`` is not None.
74-
base: a base of a logarithm to use (default: 10)
75-
reduction: a method to reduce metric score over labels.
76-
77-
- ``'elementwise_mean'``: takes the mean (default)
78-
- ``'sum'``: takes the sum
79-
- ``'none'``: no reduction will be applied
80-
81-
dim:
82-
Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
83-
None meaning scores will be reduced across all dimensions.
84-
Return:
85-
Tensor with PSNR score
86-
87-
Raises:
88-
ValueError:
89-
If ``dim`` is not ``None`` and ``data_range`` is not provided.
90-
91-
Example:
92-
>>> from pytorch_lightning.metrics.functional import psnr
93-
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
94-
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
95-
>>> psnr(pred, target)
96-
tensor(2.5527)
97-
32+
.. deprecated::
33+
Use :func:`torchmetrics.functional.psnr`. Will be removed in v1.5.0.
9834
"""
99-
if dim is None and reduction != 'elementwise_mean':
100-
rank_zero_warn(f'The `reduction={reduction}` will not have any effect when `dim` is None.')
101-
102-
if data_range is None:
103-
if dim is not None:
104-
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
105-
# `data_range` in the future.
106-
raise ValueError("The `data_range` must be given when `dim` is not None.")
107-
108-
data_range = target.max() - target.min()
109-
else:
110-
data_range = torch.tensor(float(data_range))
111-
sum_squared_error, n_obs = _psnr_update(preds, target, dim=dim)
112-
return _psnr_compute(sum_squared_error, n_obs, data_range, base=base, reduction=reduction)

pytorch_lightning/metrics/functional/r2score.py

Lines changed: 6 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -11,133 +11,21 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Tuple
1514

1615
import torch
17-
from torchmetrics.utilities.checks import _check_same_shape
16+
from torchmetrics.functional import r2score as _r2score
1817

19-
from pytorch_lightning.utilities import rank_zero_warn
20-
21-
22-
def _r2score_update(
23-
preds: torch.tensor,
24-
target: torch.Tensor,
25-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
26-
_check_same_shape(preds, target)
27-
if preds.ndim > 2:
28-
raise ValueError(
29-
'Expected both prediction and target to be 1D or 2D tensors,'
30-
f' but recevied tensors with dimension {preds.shape}'
31-
)
32-
if len(preds) < 2:
33-
raise ValueError('Needs atleast two samples to calculate r2 score.')
34-
35-
sum_error = torch.sum(target, dim=0)
36-
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
37-
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
38-
total = target.size(0)
39-
40-
return sum_squared_error, sum_error, residual, total
41-
42-
43-
def _r2score_compute(
44-
sum_squared_error: torch.Tensor,
45-
sum_error: torch.Tensor,
46-
residual: torch.Tensor,
47-
total: torch.Tensor,
48-
adjusted: int = 0,
49-
multioutput: str = "uniform_average"
50-
) -> torch.Tensor:
51-
mean_error = sum_error / total
52-
diff = sum_squared_error - sum_error * mean_error
53-
raw_scores = 1 - (residual / diff)
54-
55-
if multioutput == "raw_values":
56-
r2score = raw_scores
57-
elif multioutput == "uniform_average":
58-
r2score = torch.mean(raw_scores)
59-
elif multioutput == "variance_weighted":
60-
diff_sum = torch.sum(diff)
61-
r2score = torch.sum(diff / diff_sum * raw_scores)
62-
else:
63-
raise ValueError(
64-
'Argument `multioutput` must be either `raw_values`,'
65-
f' `uniform_average` or `variance_weighted`. Received {multioutput}.'
66-
)
67-
68-
if adjusted < 0 or not isinstance(adjusted, int):
69-
raise ValueError('`adjusted` parameter should be an integer larger or' ' equal to 0.')
70-
71-
if adjusted != 0:
72-
if adjusted > total - 1:
73-
rank_zero_warn(
74-
"More independent regressions than datapoints in"
75-
" adjusted r2 score. Falls back to standard r2 score.", UserWarning
76-
)
77-
elif adjusted == total - 1:
78-
rank_zero_warn("Division by zero in adjusted r2 score. Falls back to" " standard r2 score.", UserWarning)
79-
else:
80-
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
81-
return r2score
18+
from pytorch_lightning.utilities.deprecation import deprecated
8219

8320

21+
@deprecated(target=_r2score, ver_deprecate="1.3.0", ver_remove="1.5.0")
8422
def r2score(
8523
preds: torch.Tensor,
8624
target: torch.Tensor,
8725
adjusted: int = 0,
8826
multioutput: str = "uniform_average",
8927
) -> torch.Tensor:
90-
r"""
91-
Computes r2 score also known as `coefficient of determination
92-
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
93-
94-
.. math:: R^2 = 1 - \frac{SS_res}{SS_tot}
95-
96-
where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
97-
:math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
98-
adjusted r2 score given by
99-
100-
.. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
101-
102-
where the parameter :math:`k` (the number of independent regressors) should
103-
be provided as the ``adjusted`` argument.
104-
105-
Args:
106-
preds: estimated labels
107-
target: ground truth labels
108-
adjusted: number of independent regressors for calculating adjusted r2 score.
109-
Default 0 (standard r2 score).
110-
multioutput: Defines aggregation in the case of multiple output scores. Can be one
111-
of the following strings (default is ``'uniform_average'``.):
112-
113-
* ``'raw_values'`` returns full set of scores
114-
* ``'uniform_average'`` scores are uniformly averaged
115-
* ``'variance_weighted'`` scores are weighted by their individual variances
116-
117-
Raises:
118-
ValueError:
119-
If both ``preds`` and ``targets`` are not ``1D`` or ``2D`` tensors.
120-
ValueError:
121-
If ``len(preds)`` is less than ``2``
122-
since at least ``2`` sampels are needed to calculate r2 score.
123-
ValueError:
124-
If ``multioutput`` is not one of ``raw_values``,
125-
``uniform_average`` or ``variance_weighted``.
126-
ValueError:
127-
If ``adjusted`` is not an ``integer`` greater than ``0``.
128-
129-
Example:
130-
131-
>>> from pytorch_lightning.metrics.functional import r2score
132-
>>> target = torch.tensor([3, -0.5, 2, 7])
133-
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
134-
>>> r2score(preds, target)
135-
tensor(0.9486)
136-
137-
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
138-
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
139-
>>> r2score(preds, target, multioutput='raw_values')
140-
tensor([0.9654, 0.9082])
14128
"""
142-
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
143-
return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput)
29+
.. deprecated::
30+
Use :func:`torchmetrics.functional.r2score`. Will be removed in v1.5.0.
31+
"""

0 commit comments

Comments
 (0)