diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c6b76c8a797..80c274f4616 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -14,7 +14,12 @@ pil_to_tensor, to_pil_image, ) -from torchvision.transforms.functional_tensor import _parse_pad_padding +from torchvision.transforms.functional_tensor import ( + _cast_squeeze_in, + _cast_squeeze_out, + _parse_pad_padding, + interpolate, +) from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor @@ -104,12 +109,34 @@ def resize_image_tensor( extra_dims = image.shape[:-3] if image.numel() > 0: - image = _FT.resize( - image.view(-1, num_channels, old_height, old_width), - size=[new_height, new_width], - interpolation=interpolation.value, - antialias=antialias, - ) + image = image.view(-1, num_channels, old_height, old_width) + + # This is a perf hack to avoid slow channels_last upsample code path + # Related issue: https://github.com/pytorch/pytorch/issues/83840 + # We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path + if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST: + # Below code is copied from _FT.resize + # This is due to the fact that we need to apply the hack on casted image and not before + # Otherwise, image will be copied while cast to float and interpolate will work on twice more data + image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64]) + + shape = (image.shape[0], 2, image.shape[2], image.shape[3]) + image = image.expand(shape) + + image = interpolate( + image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False + ) + + image = image[:, 0, ...] + image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) + + else: + image = _FT.resize( + image, + size=[new_height, new_width], + interpolation=interpolation.value, + antialias=antialias, + ) return image.view(extra_dims + (num_channels, new_height, new_width))