Skip to content

Commit 1744362

Browse files
committed
Resize relies on interpolate's native uint8 handling
Description: - Now that pytorch/pytorch#90771 is merged, let Resize() rely on interpolate()'s native uint8 handling instead of converting to and from float. - uint8 input is not casted to f32 for nearest mode and bilinear mode if the latter has AVX2. Context: pytorch#7217 Benchmarks: ``` [----------- Resize cpu torch.uint8 InterpolationMode.NEAREST -----------] | resize v2 | resize stable | resize nightly 1 threads: --------------------------------------------------------------- (3, 400, 400) | 457 | 461 | 480 (16, 3, 400, 400) | 6870 | 6850 | 10100 Times are in microseconds (us). [---------- Resize cpu torch.uint8 InterpolationMode.BILINEAR -----------] | resize v2 | resize stable | resize nightly 1 threads: --------------------------------------------------------------- (3, 400, 400) | 326 | 329 | 844 (16, 3, 400, 400) | 4380 | 4390 | 14800 Times are in microseconds (us). ``` [Source](https://gist.github.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2)
1 parent 7325517 commit 1744362

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

torchvision/transforms/_functional_tensor.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,17 @@ def resize(
459459
# now we don't as True is the default.
460460
antialias = False
461461

462-
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
462+
acceptable_dtypes = [torch.float32, torch.float64]
463+
if interpolation in ["nearest", "nearest-exact"]:
464+
# uint8 dtype can be included for cpu and cuda input if nearest mode
465+
acceptable_dtypes.append(torch.uint8)
466+
elif interpolation == "bilinear" and img.device.type == "cpu":
467+
# uint8 dtype support for bilinear mode is limited to cpu and
468+
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
469+
if "AVX2" in torch.backends.cpu.get_cpu_capability():
470+
acceptable_dtypes.append(torch.uint8)
471+
472+
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, acceptable_dtypes)
463473

464474
# Define align_corners to avoid warnings
465475
align_corners = False if interpolation in ["bilinear", "bicubic"] else None

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,17 @@ def resize_image_tensor(
185185
image = image.reshape(-1, num_channels, old_height, old_width)
186186

187187
dtype = image.dtype
188-
need_cast = dtype not in (torch.float32, torch.float64)
188+
acceptable_dtypes = [torch.float32, torch.float64]
189+
if interpolation in [InterpolationMode.NEAREST, InterpolationMode.NEAREST_EXACT]:
190+
# uint8 dtype can be included for cpu and cuda input if nearest mode
191+
acceptable_dtypes.append(torch.uint8)
192+
elif interpolation == InterpolationMode.BILINEAR and image.device.type == "cpu":
193+
# uint8 dtype support for bilinear mode is limited to cpu and
194+
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
195+
if "AVX2" in torch.backends.cpu.get_cpu_capability():
196+
acceptable_dtypes.append(torch.uint8)
197+
198+
need_cast = dtype not in acceptable_dtypes
189199
if need_cast:
190200
image = image.to(dtype=torch.float32)
191201

0 commit comments

Comments
 (0)