Skip to content

Commit 8543a62

Browse files
Yosua Michael Maranathafacebook-github-bot
Yosua Michael Maranatha
authored andcommitted
[fbsync] Remove performance workaround for mask resize (#6729)
Summary: * Remove performance workaround for mask resize * Fix linter * bug fixes * remove unnecessary import * Fixing linter Reviewed By: NicolasHug Differential Revision: D40427473 fbshipit-source-id: e7e632069c62cc4e53c31571f0c177d68ee61c3e
1 parent 85feefc commit 8543a62

File tree

2 files changed

+7
-43
lines changed

2 files changed

+7
-43
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
make_video_loaders,
2020
mark_framework_limitation,
2121
TestMark,
22-
VALID_EXTRA_DIMS,
2322
)
2423
from torchvision.prototype import features
2524
from torchvision.transforms.functional_tensor import _max_value as get_max_value
@@ -215,16 +214,6 @@ def sample_inputs_resize_image_tensor():
215214
):
216215
yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation)
217216

218-
# We have a speed hack in place for nearest interpolation and single channel images (grayscale)
219-
for image_loader in make_image_loaders(
220-
sizes=["random"],
221-
color_spaces=[features.ColorSpace.GRAY],
222-
extra_dims=VALID_EXTRA_DIMS,
223-
):
224-
yield ArgsKwargs(
225-
image_loader, size=[min(image_loader.image_size) + 1], interpolation=F.InterpolationMode.NEAREST
226-
)
227-
228217
yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)
229218

230219

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,7 @@
1414
pil_to_tensor,
1515
to_pil_image,
1616
)
17-
from torchvision.transforms.functional_tensor import (
18-
_cast_squeeze_in,
19-
_cast_squeeze_out,
20-
_parse_pad_padding,
21-
interpolate,
22-
)
17+
from torchvision.transforms.functional_tensor import _parse_pad_padding
2318

2419
from ._meta import (
2520
convert_format_bounding_box,
@@ -130,32 +125,12 @@ def resize_image_tensor(
130125
if image.numel() > 0:
131126
image = image.view(-1, num_channels, old_height, old_width)
132127

133-
# This is a perf hack to avoid slow channels_last upsample code path
134-
# Related issue: https://github.com/pytorch/pytorch/issues/83840
135-
# We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path
136-
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
137-
# Below code is copied from _FT.resize
138-
# This is due to the fact that we need to apply the hack on casted image and not before
139-
# Otherwise, image will be copied while cast to float and interpolate will work on twice more data
140-
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
141-
142-
shape = (image.shape[0], 2, image.shape[2], image.shape[3])
143-
image = image.expand(shape)
144-
145-
image = interpolate(
146-
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
147-
)
148-
149-
image = image[:, 0, ...]
150-
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
151-
152-
else:
153-
image = _FT.resize(
154-
image,
155-
size=[new_height, new_width],
156-
interpolation=interpolation.value,
157-
antialias=antialias,
158-
)
128+
image = _FT.resize(
129+
image,
130+
size=[new_height, new_width],
131+
interpolation=interpolation.value,
132+
antialias=antialias,
133+
)
159134

160135
return image.view(extra_dims + (num_channels, new_height, new_width))
161136

0 commit comments

Comments
 (0)