@@ -48,9 +48,8 @@ class PeakSignalNoiseRatio(Metric):
4848
4949 Args:
5050 data_range:
51- the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
52- the range is calculated as the difference and input is clamped between the values.
53- The ``data_range`` must be given when ``dim`` is not None.
51+ the range of the data. If a tuple is provided, then the range is calculated as the difference and
52+ input is clamped between the values.
5453 base: a base of a logarithm to use.
5554 reduction: a method to reduce metric score over labels.
5655
@@ -63,13 +62,9 @@ class PeakSignalNoiseRatio(Metric):
6362 None meaning scores will be reduced across all dimensions and all batches.
6463 kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
6564
66- Raises:
67- ValueError:
68- If ``dim`` is not ``None`` and ``data_range`` is not given.
69-
7065 Example:
7166 >>> from torchmetrics.image import PeakSignalNoiseRatio
72- >>> psnr = PeakSignalNoiseRatio()
67+ >>> psnr = PeakSignalNoiseRatio(data_range=3.0 )
7368 >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
7469 >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
7570 >>> psnr(preds, target)
@@ -81,13 +76,11 @@ class PeakSignalNoiseRatio(Metric):
8176 higher_is_better : bool = True
8277 full_state_update : bool = False
8378 plot_lower_bound : float = 0.0
84-
85- min_target : Tensor
86- max_target : Tensor
79+ data_range : Tensor
8780
8881 def __init__ (
8982 self ,
90- data_range : Optional [ Union [float , tuple [float , float ]]] = None ,
83+ data_range : Union [float , tuple [float , float ]],
9184 base : float = 10.0 ,
9285 reduction : Literal ["elementwise_mean" , "sum" , "none" , None ] = "elementwise_mean" ,
9386 dim : Optional [Union [int , tuple [int , ...]]] = None ,
@@ -106,20 +99,12 @@ def __init__(
10699 self .add_state ("total" , default = [], dist_reduce_fx = "cat" )
107100
108101 self .clamping_fn = None
109- if data_range is None :
110- if dim is not None :
111- # Maybe we could use `torch.amax(target, dim=dim) - torch.amin(target, dim=dim)` in PyTorch 1.7 to
112- # calculate `data_range` in the future.
113- raise ValueError ("The `data_range` must be given when `dim` is not None." )
114-
115- self .data_range = None
116- self .add_state ("min_target" , default = tensor (0.0 ), dist_reduce_fx = torch .min )
117- self .add_state ("max_target" , default = tensor (0.0 ), dist_reduce_fx = torch .max )
118- elif isinstance (data_range , tuple ):
102+ if isinstance (data_range , tuple ):
119103 self .add_state ("data_range" , default = tensor (data_range [1 ] - data_range [0 ]), dist_reduce_fx = "mean" )
120104 self .clamping_fn = partial (torch .clamp , min = data_range [0 ], max = data_range [1 ])
121105 else :
122106 self .add_state ("data_range" , default = tensor (float (data_range )), dist_reduce_fx = "mean" )
107+
123108 self .base = base
124109 self .reduction = reduction
125110 self .dim = tuple (dim ) if isinstance (dim , Sequence ) else dim
@@ -132,11 +117,6 @@ def update(self, preds: Tensor, target: Tensor) -> None:
132117
133118 sum_squared_error , num_obs = _psnr_update (preds , target , dim = self .dim )
134119 if self .dim is None :
135- if self .data_range is None :
136- # keep track of min and max target values
137- self .min_target = torch .minimum (target .min (), self .min_target )
138- self .max_target = torch .maximum (target .max (), self .max_target )
139-
140120 if not isinstance (self .sum_squared_error , Tensor ):
141121 raise TypeError (
142122 f"Expected `self.sum_squared_error` to be a Tensor, but got { type (self .sum_squared_error )} "
@@ -158,8 +138,6 @@ def update(self, preds: Tensor, target: Tensor) -> None:
158138
159139 def compute (self ) -> Tensor :
160140 """Compute peak signal-to-noise ratio over state."""
161- data_range = self .data_range if self .data_range is not None else self .max_target - self .min_target
162-
163141 if isinstance (self .sum_squared_error , torch .Tensor ):
164142 sum_squared_error = self .sum_squared_error
165143 elif isinstance (self .sum_squared_error , list ):
@@ -174,7 +152,7 @@ def compute(self) -> Tensor:
174152 else :
175153 raise TypeError ("Expected total to be a Tensor or a list of Tensors" )
176154
177- return _psnr_compute (sum_squared_error , total , data_range , base = self .base , reduction = self .reduction )
155+ return _psnr_compute (sum_squared_error , total , self . data_range , base = self .base , reduction = self .reduction )
178156
179157 def plot (
180158 self , val : Optional [Union [Tensor , Sequence [Tensor ]]] = None , ax : Optional [_AX_TYPE ] = None
@@ -199,7 +177,7 @@ def plot(
199177 >>> # Example plotting a single value
200178 >>> import torch
201179 >>> from torchmetrics.image import PeakSignalNoiseRatio
202- >>> metric = PeakSignalNoiseRatio()
180+ >>> metric = PeakSignalNoiseRatio(data_range=1.0 )
203181 >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
204182 >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
205183 >>> metric.update(preds, target)
@@ -211,7 +189,7 @@ def plot(
211189 >>> # Example plotting multiple values
212190 >>> import torch
213191 >>> from torchmetrics.image import PeakSignalNoiseRatio
214- >>> metric = PeakSignalNoiseRatio()
192+ >>> metric = PeakSignalNoiseRatio(data_range=1.0 )
215193 >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
216194 >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
217195 >>> values = [ ]
0 commit comments