Skip to content

Commit ea026ae

Browse files
SkafteNickiBordapre-commit-ci[bot]
authored
Make data_range required parameter (#3178)
* change psnr and psnrb implementations * fix tests * changelog * fix deprecated * typing ? --------- Co-authored-by: Jirka B <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 533fe95 commit ea026ae

File tree

9 files changed

+100
-91
lines changed

9 files changed

+100
-91
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3939

4040
### Changed
4141

42-
-
42+
- Changed `data_range` argument in `PSNR` metric to be a required argument ([3178](https://github.com/Lightning-AI/torchmetrics/pull/3178))
4343

4444

4545
### Deprecated

src/torchmetrics/functional/image/_deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _image_gradients(img: Tensor) -> tuple[Tensor, Tensor]:
8080
def _peak_signal_noise_ratio(
8181
preds: Tensor,
8282
target: Tensor,
83-
data_range: Optional[Union[float, tuple[float, float]]] = None,
83+
data_range: Union[float, tuple[float, float]] = 3.0,
8484
base: float = 10.0,
8585
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
8686
dim: Optional[Union[int, tuple[int, ...]]] = None,

src/torchmetrics/functional/image/psnr.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def _psnr_update(
9595
def peak_signal_noise_ratio(
9696
preds: Tensor,
9797
target: Tensor,
98-
data_range: Optional[Union[float, tuple[float, float]]] = None,
98+
data_range: Union[float, tuple[float, float]],
9999
base: float = 10.0,
100100
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
101101
dim: Optional[Union[int, tuple[int, ...]]] = None,
@@ -106,9 +106,8 @@ def peak_signal_noise_ratio(
106106
preds: estimated signal
107107
target: groun truth signal
108108
data_range:
109-
the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
110-
the range is calculated as the difference and input is clamped between the values.
111-
The ``data_range`` must be given when ``dim`` is not None.
109+
the range of the data. If a tuple is provided then the range is calculated as the difference and
110+
input is clamped between the values.
112111
base: a base of a logarithm to use
113112
reduction: a method to reduce metric score over labels.
114113
@@ -123,15 +122,11 @@ def peak_signal_noise_ratio(
123122
Return:
124123
Tensor with PSNR score
125124
126-
Raises:
127-
ValueError:
128-
If ``dim`` is not ``None`` and ``data_range`` is not provided.
129-
130125
Example:
131126
>>> from torchmetrics.functional.image import peak_signal_noise_ratio
132127
>>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
133128
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
134-
>>> peak_signal_noise_ratio(pred, target)
129+
>>> peak_signal_noise_ratio(pred, target, data_range=3.0)
135130
tensor(2.5527)
136131
137132
.. attention::
@@ -141,19 +136,12 @@ def peak_signal_noise_ratio(
141136
if dim is None and reduction != "elementwise_mean":
142137
rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")
143138

144-
if data_range is None:
145-
if dim is not None:
146-
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to calculate
147-
# `data_range` in the future.
148-
raise ValueError("The `data_range` must be given when `dim` is not None.")
149-
150-
data_range = target.max() - target.min() # type: ignore[assignment]
151-
elif isinstance(data_range, tuple):
139+
if isinstance(data_range, tuple):
152140
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
153141
target = torch.clamp(target, min=data_range[0], max=data_range[1])
154-
data_range = tensor(data_range[1] - data_range[0]) # type: ignore[assignment]
142+
data_range_val = tensor(data_range[1] - data_range[0])
155143
else:
156-
data_range = tensor(float(data_range)) # type: ignore[assignment]
144+
data_range_val = tensor(float(data_range))
157145

158146
sum_squared_error, num_obs = _psnr_update(preds, target, dim=dim)
159-
return _psnr_compute(sum_squared_error, num_obs, data_range, base=base, reduction=reduction) # type: ignore[arg-type]
147+
return _psnr_compute(sum_squared_error, num_obs, data_range_val, base=base, reduction=reduction)

src/torchmetrics/functional/image/psnrb.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15+
from typing import Union
1516

1617
import torch
1718
from torch import Tensor, tensor
@@ -76,13 +77,11 @@ def _psnrb_compute(
7677
sum_squared_error: Sum of square of errors over all observations
7778
bef: block effect
7879
num_obs: Number of predictions or observations
79-
data_range: the range of the data. If None, it is determined from the data (max - min).
80+
data_range: the range of the data.
8081
8182
"""
8283
sum_squared_error = sum_squared_error / num_obs + bef
83-
if data_range > 2:
84-
return 10 * torch.log10(data_range**2 / sum_squared_error)
85-
return 10 * torch.log10(1.0 / sum_squared_error)
84+
return 10 * torch.log10(data_range**2 / sum_squared_error)
8685

8786

8887
def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[Tensor, Tensor, Tensor]:
@@ -103,6 +102,7 @@ def _psnrb_update(preds: Tensor, target: Tensor, block_size: int = 8) -> tuple[T
103102
def peak_signal_noise_ratio_with_blocked_effect(
104103
preds: Tensor,
105104
target: Tensor,
105+
data_range: Union[float, tuple[float, float]],
106106
block_size: int = 8,
107107
) -> Tensor:
108108
r"""Computes `Peak Signal to Noise Ratio With Blocked Effect` (PSNRB) metrics.
@@ -114,7 +114,9 @@ def peak_signal_noise_ratio_with_blocked_effect(
114114
115115
Args:
116116
preds: estimated signal
117-
target: groun truth signal
117+
target: ground truth signal
118+
data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and
119+
input is clamped between the values.
118120
block_size: integer indication the block size
119121
120122
Return:
@@ -125,10 +127,16 @@ def peak_signal_noise_ratio_with_blocked_effect(
125127
>>> from torchmetrics.functional.image import peak_signal_noise_ratio_with_blocked_effect
126128
>>> preds = rand(1, 1, 28, 28)
127129
>>> target = rand(1, 1, 28, 28)
128-
>>> peak_signal_noise_ratio_with_blocked_effect(preds, target)
130+
>>> peak_signal_noise_ratio_with_blocked_effect(preds, target, data_range=1.0)
129131
tensor(7.8402)
130132
131133
"""
132-
data_range = target.max() - target.min()
134+
if isinstance(data_range, tuple):
135+
preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
136+
target = torch.clamp(target, min=data_range[0], max=data_range[1])
137+
data_range_val = tensor(data_range[1] - data_range[0])
138+
else:
139+
data_range_val = tensor(float(data_range))
140+
133141
sum_squared_error, bef, num_obs = _psnrb_update(preds, target, block_size=block_size)
134-
return _psnrb_compute(sum_squared_error, bef, num_obs, data_range)
142+
return _psnrb_compute(sum_squared_error, bef, num_obs, data_range_val)

src/torchmetrics/image/_deprecated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class _PeakSignalNoiseRatio(PeakSignalNoiseRatio):
9191

9292
def __init__(
9393
self,
94-
data_range: Optional[Union[float, tuple[float, float]]] = None,
94+
data_range: Union[float, tuple[float, float]] = 3.0,
9595
base: float = 10.0,
9696
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
9797
dim: Optional[Union[int, tuple[int, ...]]] = None,

src/torchmetrics/image/psnr.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ class PeakSignalNoiseRatio(Metric):
4848
4949
Args:
5050
data_range:
51-
the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
52-
the range is calculated as the difference and input is clamped between the values.
53-
The ``data_range`` must be given when ``dim`` is not None.
51+
the range of the data. If a tuple is provided, then the range is calculated as the difference and
52+
input is clamped between the values.
5453
base: a base of a logarithm to use.
5554
reduction: a method to reduce metric score over labels.
5655
@@ -63,13 +62,9 @@ class PeakSignalNoiseRatio(Metric):
6362
None meaning scores will be reduced across all dimensions and all batches.
6463
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
6564
66-
Raises:
67-
ValueError:
68-
If ``dim`` is not ``None`` and ``data_range`` is not given.
69-
7065
Example:
7166
>>> from torchmetrics.image import PeakSignalNoiseRatio
72-
>>> psnr = PeakSignalNoiseRatio()
67+
>>> psnr = PeakSignalNoiseRatio(data_range=3.0)
7368
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
7469
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
7570
>>> psnr(preds, target)
@@ -81,13 +76,11 @@ class PeakSignalNoiseRatio(Metric):
8176
higher_is_better: bool = True
8277
full_state_update: bool = False
8378
plot_lower_bound: float = 0.0
84-
85-
min_target: Tensor
86-
max_target: Tensor
79+
data_range: Tensor
8780

8881
def __init__(
8982
self,
90-
data_range: Optional[Union[float, tuple[float, float]]] = None,
83+
data_range: Union[float, tuple[float, float]],
9184
base: float = 10.0,
9285
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
9386
dim: Optional[Union[int, tuple[int, ...]]] = None,
@@ -106,20 +99,12 @@ def __init__(
10699
self.add_state("total", default=[], dist_reduce_fx="cat")
107100

108101
self.clamping_fn = None
109-
if data_range is None:
110-
if dim is not None:
111-
# Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
112-
# calculate `data_range` in the future.
113-
raise ValueError("The `data_range` must be given when `dim` is not None.")
114-
115-
self.data_range = None
116-
self.add_state("min_target", default=tensor(0.0), dist_reduce_fx=torch.min)
117-
self.add_state("max_target", default=tensor(0.0), dist_reduce_fx=torch.max)
118-
elif isinstance(data_range, tuple):
102+
if isinstance(data_range, tuple):
119103
self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
120104
self.clamping_fn = partial(torch.clamp, min=data_range[0], max=data_range[1])
121105
else:
122106
self.add_state("data_range", default=tensor(float(data_range)), dist_reduce_fx="mean")
107+
123108
self.base = base
124109
self.reduction = reduction
125110
self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
@@ -132,11 +117,6 @@ def update(self, preds: Tensor, target: Tensor) -> None:
132117

133118
sum_squared_error, num_obs = _psnr_update(preds, target, dim=self.dim)
134119
if self.dim is None:
135-
if self.data_range is None:
136-
# keep track of min and max target values
137-
self.min_target = torch.minimum(target.min(), self.min_target)
138-
self.max_target = torch.maximum(target.max(), self.max_target)
139-
140120
if not isinstance(self.sum_squared_error, Tensor):
141121
raise TypeError(
142122
f"Expected `self.sum_squared_error` to be a Tensor, but got {type(self.sum_squared_error)}"
@@ -158,8 +138,6 @@ def update(self, preds: Tensor, target: Tensor) -> None:
158138

159139
def compute(self) -> Tensor:
160140
"""Compute peak signal-to-noise ratio over state."""
161-
data_range = self.data_range if self.data_range is not None else self.max_target - self.min_target
162-
163141
if isinstance(self.sum_squared_error, torch.Tensor):
164142
sum_squared_error = self.sum_squared_error
165143
elif isinstance(self.sum_squared_error, list):
@@ -174,7 +152,7 @@ def compute(self) -> Tensor:
174152
else:
175153
raise TypeError("Expected total to be a Tensor or a list of Tensors")
176154

177-
return _psnr_compute(sum_squared_error, total, data_range, base=self.base, reduction=self.reduction)
155+
return _psnr_compute(sum_squared_error, total, self.data_range, base=self.base, reduction=self.reduction)
178156

179157
def plot(
180158
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
@@ -199,7 +177,7 @@ def plot(
199177
>>> # Example plotting a single value
200178
>>> import torch
201179
>>> from torchmetrics.image import PeakSignalNoiseRatio
202-
>>> metric = PeakSignalNoiseRatio()
180+
>>> metric = PeakSignalNoiseRatio(data_range=1.0)
203181
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
204182
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
205183
>>> metric.update(preds, target)
@@ -211,7 +189,7 @@ def plot(
211189
>>> # Example plotting multiple values
212190
>>> import torch
213191
>>> from torchmetrics.image import PeakSignalNoiseRatio
214-
>>> metric = PeakSignalNoiseRatio()
192+
>>> metric = PeakSignalNoiseRatio(data_range=1.0)
215193
>>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
216194
>>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
217195
>>> values = [ ]

src/torchmetrics/image/psnrb.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):
4848
- ``psnrb`` (:class:`~torch.Tensor`): float scalar tensor with aggregated PSNRB value
4949
5050
Args:
51+
data_range: the range of the data. If a tuple is provided then the range is calculated as the difference and
52+
input is clamped between the values.
5153
block_size: integer indication the block size
5254
kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
5355
5456
Example:
5557
>>> from torch import rand
56-
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
58+
>>> metric = PeakSignalNoiseRatioWithBlockedEffect(data_range=1.0)
5759
>>> preds = rand(2, 1, 10, 10)
5860
>>> target = rand(2, 1, 10, 10)
5961
>>> metric(preds, target)
@@ -72,6 +74,7 @@ class PeakSignalNoiseRatioWithBlockedEffect(Metric):
7274

7375
def __init__(
7476
self,
77+
data_range: Union[float, tuple[float, float]],
7578
block_size: int = 8,
7679
**kwargs: Any,
7780
) -> None:
@@ -83,15 +86,24 @@ def __init__(
8386
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
8487
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
8588
self.add_state("bef", default=tensor(0.0), dist_reduce_fx="sum")
86-
self.add_state("data_range", default=tensor(0), dist_reduce_fx="max")
89+
90+
if isinstance(data_range, tuple):
91+
self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
92+
self.clamping_fn = lambda x: torch.clamp(x, min=data_range[0], max=data_range[1])
93+
else:
94+
self.add_state("data_range", default=tensor(float(data_range)), dist_reduce_fx="mean")
95+
self.clamping_fn = None # type: ignore[assignment]
8796

8897
def update(self, preds: Tensor, target: Tensor) -> None:
8998
"""Update state with predictions and targets."""
99+
if self.clamping_fn is not None:
100+
preds = self.clamping_fn(preds)
101+
target = self.clamping_fn(target)
102+
90103
sum_squared_error, bef, num_obs = _psnrb_update(preds, target, block_size=self.block_size)
91104
self.sum_squared_error += sum_squared_error
92105
self.bef += bef
93106
self.total += num_obs
94-
self.data_range = torch.maximum(self.data_range, torch.max(target) - torch.min(target))
95107

96108
def compute(self) -> Tensor:
97109
"""Compute peak signal-to-noise ratio over state."""
@@ -120,7 +132,7 @@ def plot(
120132
>>> # Example plotting a single value
121133
>>> import torch
122134
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
123-
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
135+
>>> metric = PeakSignalNoiseRatioWithBlockedEffect(data_range=1.0)
124136
>>> metric.update(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10))
125137
>>> fig_, ax_ = metric.plot()
126138
@@ -130,7 +142,7 @@ def plot(
130142
>>> # Example plotting multiple values
131143
>>> import torch
132144
>>> from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect
133-
>>> metric = PeakSignalNoiseRatioWithBlockedEffect()
145+
>>> metric = PeakSignalNoiseRatioWithBlockedEffect(data_range=1.0)
134146
>>> values = [ ]
135147
>>> for _ in range(10):
136148
... values.append(metric(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10)))

tests/unittests/image/test_psnr.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -154,19 +154,10 @@ def test_reduction_for_dim_none(reduction):
154154
"""Test that warnings are raised when then reduction parameter is combined with no dim provided arg."""
155155
match = f"The `reduction={reduction}` will not have any effect when `dim` is None."
156156
with pytest.warns(UserWarning, match=match):
157-
PeakSignalNoiseRatio(reduction=reduction, dim=None)
157+
PeakSignalNoiseRatio(data_range=10.0, reduction=reduction, dim=None)
158158

159159
with pytest.warns(UserWarning, match=match):
160-
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, reduction=reduction, dim=None)
161-
162-
163-
def test_missing_data_range():
164-
"""Check that error is raised if data range is not provided."""
165-
with pytest.raises(ValueError, match="The `data_range` must be given when `dim` is not None."):
166-
PeakSignalNoiseRatio(data_range=None, dim=0)
167-
168-
with pytest.raises(ValueError, match="The `data_range` must be given when `dim` is not None."):
169-
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=None, dim=0)
160+
peak_signal_noise_ratio(_inputs[0].preds, _inputs[0].target, data_range=10.0, reduction=reduction, dim=None)
170161

171162

172163
def test_psnr_uint_dtype():
@@ -177,6 +168,6 @@ def test_psnr_uint_dtype():
177168
"""
178169
preds = torch.randint(0, 255, _input_size, dtype=torch.uint8)
179170
target = torch.randint(0, 255, _input_size, dtype=torch.uint8)
180-
psnr = peak_signal_noise_ratio(preds, target)
181-
prnr2 = peak_signal_noise_ratio(preds.float(), target.float())
171+
psnr = peak_signal_noise_ratio(preds, target, data_range=255.0)
172+
prnr2 = peak_signal_noise_ratio(preds.float(), target.float(), data_range=255.0)
182173
assert torch.allclose(psnr, prnr2)

0 commit comments

Comments
 (0)