diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 21458b3bdab..69445e6a231 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -59,8 +59,7 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) - if image.dtype == dtype: return image - # TODO: replace with image.dtype.is_floating_point when torchscript supports it - if torch.empty(0, dtype=image.dtype).is_floating_point(): + if image.is_floating_point(): # TODO: replace with dtype.is_floating_point when torchscript supports it if torch.tensor(0, dtype=dtype).is_floating_point():