Skip to content

Commit d27cb7c

Browse files
authored
Ensure input type of normalize is float. (#3621)
1 parent 226126b commit d27cb7c

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

test/test_transforms_tensor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -446,12 +446,15 @@ def test_to_grayscale(self):
446446
)
447447

448448
def test_normalize(self):
449+
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
449450
tensor, _ = self._create_data(26, 34, device=self.device)
450-
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
451451

452+
with self.assertRaisesRegex(TypeError, r"Input tensor should be a float tensor"):
453+
fn(tensor)
454+
455+
batch_tensors = torch.rand(4, 3, 44, 56, device=self.device)
452456
tensor = tensor.to(dtype=torch.float32) / 255.0
453457
# test for class interface
454-
fn = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
455458
scripted_fn = torch.jit.script(fn)
456459

457460
self._test_transform_vs_scripted(fn, scripted_fn, tensor)

torchvision/transforms/functional.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def to_pil_image(pic, mode=None):
297297

298298

299299
def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
300-
"""Normalize a tensor image with mean and standard deviation.
300+
"""Normalize a float tensor image with mean and standard deviation.
301301
This transform does not support PIL Image.
302302
303303
.. note::
@@ -306,7 +306,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
306306
See :class:`~torchvision.transforms.Normalize` for more details.
307307
308308
Args:
309-
tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
309+
tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
310310
mean (sequence): Sequence of means for each channel.
311311
std (sequence): Sequence of standard deviations for each channel.
312312
inplace(bool,optional): Bool to make this operation inplace.
@@ -317,6 +317,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
317317
if not isinstance(tensor, torch.Tensor):
318318
raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
319319

320+
if not tensor.is_floating_point():
321+
raise TypeError('Input tensor should be a float tensor. Got {}.'.format(tensor.dtype))
322+
320323
if tensor.ndim < 3:
321324
raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
322325
'{}.'.format(tensor.size()))

0 commit comments

Comments
 (0)