Hack to improve performance of resize op with nearest mode on 2D #6661


Merged · 6 commits · Sep 29, 2022
41 changes: 34 additions & 7 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -14,7 +14,12 @@
     pil_to_tensor,
     to_pil_image,
 )
-from torchvision.transforms.functional_tensor import _parse_pad_padding
+from torchvision.transforms.functional_tensor import (
+    _cast_squeeze_in,
+    _cast_squeeze_out,
+    _parse_pad_padding,
+    interpolate,
+)
 
 from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor

@@ -104,12 +109,34 @@ def resize_image_tensor(
     extra_dims = image.shape[:-3]
 
     if image.numel() > 0:
-        image = _FT.resize(
-            image.view(-1, num_channels, old_height, old_width),
-            size=[new_height, new_width],
-            interpolation=interpolation.value,
-            antialias=antialias,
-        )
+        image = image.view(-1, num_channels, old_height, old_width)
+
+        # This is a perf hack to avoid the slow channels_last upsample code path.
+        # Related issue: https://github.com/pytorch/pytorch/issues/83840
+        # We expand (N, 1, H, W) into (N, 2, H, W) to force the channels_first path.
+        if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
+            # The code below is copied from _FT.resize because the hack must be
+            # applied to the already-casted image; otherwise the cast to float
+            # copies the image and interpolate works on twice as much data.
+            image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
+
+            shape = (image.shape[0], 2, image.shape[2], image.shape[3])
+            image = image.expand(shape)
+
+            image = interpolate(
+                image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
+            )
+
+            image = image[:, 0, ...]
+            image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
+
+        else:
+            image = _FT.resize(
+                image,
+                size=[new_height, new_width],
+                interpolation=interpolation.value,
+                antialias=antialias,
+            )
 
     return image.view(extra_dims + (num_channels, new_height, new_width))

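For context, the whole trick rests on `Tensor.expand` returning a zero-copy view: duplicating the single channel costs nothing, but `interpolate` then sees two channels and dispatches to the faster channels_first kernel instead of the slow channels_last path tracked in pytorch/pytorch#83840. The sketch below is a standalone illustration of the same idea using only public torch APIs; the helper name `resize_nearest_1ch` and the shapes are illustrative assumptions, not part of this PR.

    import torch
    import torch.nn.functional as F

    def resize_nearest_1ch(image, size):
        # image: (N, 1, H, W) float tensor; size: [new_H, new_W]
        n, c, h, w = image.shape
        assert c == 1, "this sketch only handles single-channel input"
        # expand() returns a view (no copy): the second channel aliases the
        # first, which is enough to steer interpolate onto the faster path.
        expanded = image.expand(n, 2, h, w)
        out = F.interpolate(expanded, size=size, mode="nearest")
        # Both output channels are identical; keep the first, preserving the dim.
        return out[:, :1, ...]

    img = torch.rand(8, 1, 512, 512).to(memory_format=torch.channels_last)
    out = resize_nearest_1ch(img, [256, 256])
    # Nearest-neighbor just copies source pixels, so both paths agree exactly.
    assert torch.equal(out, F.interpolate(img, size=[256, 256], mode="nearest"))

Note the order of operations in the PR itself: `_cast_squeeze_in` runs before the expand, so the cast to float copies only the original single-channel data. Casting after the expand would materialize the fake second channel and copy twice as much.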