Skip to content

Commit 61f2032

Browse files
authored
Hack to improve performance of resize op with nearest mode on 2D (#6661)
* Hack to improve performance of resize op with nearest mode on 2D
* Moved hack to prototype
* Moved hack into proto and reused code from stable resize
* updates
* More updates
1 parent 30b879f commit 61f2032

File tree

1 file changed

+34
-7
lines changed

1 file changed

+34
-7
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 34 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
pil_to_tensor,
1515
to_pil_image,
1616
)
17-
from torchvision.transforms.functional_tensor import _parse_pad_padding
17+
from torchvision.transforms.functional_tensor import (
18+
_cast_squeeze_in,
19+
_cast_squeeze_out,
20+
_parse_pad_padding,
21+
interpolate,
22+
)
1823

1924
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
2025

@@ -104,12 +109,34 @@ def resize_image_tensor(
104109
extra_dims = image.shape[:-3]
105110

106111
if image.numel() > 0:
107-
image = _FT.resize(
108-
image.view(-1, num_channels, old_height, old_width),
109-
size=[new_height, new_width],
110-
interpolation=interpolation.value,
111-
antialias=antialias,
112-
)
112+
image = image.view(-1, num_channels, old_height, old_width)
113+
114+
# This is a perf hack to avoid slow channels_last upsample code path
115+
# Related issue: https://github.com/pytorch/pytorch/issues/83840
116+
# We expand (N, 1, H, W) to (N, 2, H, W) to force interpolate to take the channels_first path
117+
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
118+
# Below code is copied from _FT.resize
119+
# This is because the hack must be applied to the already cast image, not before the cast
120+
# Otherwise, the image will be copied during the cast to float and interpolate will work on twice as much data
121+
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
122+
123+
shape = (image.shape[0], 2, image.shape[2], image.shape[3])
124+
image = image.expand(shape)
125+
126+
image = interpolate(
127+
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
128+
)
129+
130+
image = image[:, 0, ...]
131+
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
132+
133+
else:
134+
image = _FT.resize(
135+
image,
136+
size=[new_height, new_width],
137+
interpolation=interpolation.value,
138+
antialias=antialias,
139+
)
113140

114141
return image.view(extra_dims + (num_channels, new_height, new_width))
115142

0 commit comments

Comments
 (0)