@@ -297,7 +297,7 @@ def to_pil_image(pic, mode=None):
297
297
298
298
299
299
def normalize (tensor : Tensor , mean : List [float ], std : List [float ], inplace : bool = False ) -> Tensor :
300
- """Normalize a tensor image with mean and standard deviation.
300
+ """Normalize a float tensor image with mean and standard deviation.
301
301
This transform does not support PIL Image.
302
302
303
303
.. note::
@@ -306,7 +306,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
306
306
See :class:`~torchvision.transforms.Normalize` for more details.
307
307
308
308
Args:
309
- tensor (Tensor): Tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
309
+ tensor (Tensor): Float tensor image of size (C, H, W) or (B, C, H, W) to be normalized.
310
310
mean (sequence): Sequence of means for each channel.
311
311
std (sequence): Sequence of standard deviations for each channel.
312
312
inplace(bool,optional): Bool to make this operation inplace.
@@ -317,6 +317,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
317
317
if not isinstance (tensor , torch .Tensor ):
318
318
raise TypeError ('Input tensor should be a torch tensor. Got {}.' .format (type (tensor )))
319
319
320
+ if not tensor .is_floating_point ():
321
+ raise TypeError ('Input tensor should be a float tensor. Got {}.' .format (tensor .dtype ))
322
+
320
323
if tensor .ndim < 3 :
321
324
raise ValueError ('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
322
325
'{}.' .format (tensor .size ()))
0 commit comments