|
14 | 14 | pil_to_tensor,
|
15 | 15 | to_pil_image,
|
16 | 16 | )
|
17 |
| -from torchvision.transforms.functional_tensor import _parse_pad_padding |
| 17 | +from torchvision.transforms.functional_tensor import ( |
| 18 | + _cast_squeeze_in, |
| 19 | + _cast_squeeze_out, |
| 20 | + _parse_pad_padding, |
| 21 | + interpolate, |
| 22 | +) |
18 | 23 |
|
19 | 24 | from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
|
20 | 25 |
|
@@ -104,12 +109,34 @@ def resize_image_tensor(
|
104 | 109 | extra_dims = image.shape[:-3]
|
105 | 110 |
|
106 | 111 | if image.numel() > 0:
|
107 |
| - image = _FT.resize( |
108 |
| - image.view(-1, num_channels, old_height, old_width), |
109 |
| - size=[new_height, new_width], |
110 |
| - interpolation=interpolation.value, |
111 |
| - antialias=antialias, |
112 |
| - ) |
| 112 | + image = image.view(-1, num_channels, old_height, old_width) |
| 113 | + |
| 114 | + # This is a perf hack to avoid the slow channels_last upsample code path |
| 115 | + # Related issue: https://github.com/pytorch/pytorch/issues/83840 |
| 116 | + # We expand (N, 1, H, W) to (N, 2, H, W) to force the channels_first path to be taken |
| 117 | + if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST: |
| 118 | + # The code below is copied from _FT.resize |
| 119 | + # This is because the hack must be applied to the already cast image, not before the cast |
| 120 | + # Otherwise, the image would be copied when cast to float and interpolate would work on twice as much data |
| 121 | + image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64]) |
| 122 | + |
| 123 | + shape = (image.shape[0], 2, image.shape[2], image.shape[3]) |
| 124 | + image = image.expand(shape) |
| 125 | + |
| 126 | + image = interpolate( |
| 127 | + image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False |
| 128 | + ) |
| 129 | + |
| 130 | + image = image[:, 0, ...] |
| 131 | + image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) |
| 132 | + |
| 133 | + else: |
| 134 | + image = _FT.resize( |
| 135 | + image, |
| 136 | + size=[new_height, new_width], |
| 137 | + interpolation=interpolation.value, |
| 138 | + antialias=antialias, |
| 139 | + ) |
113 | 140 |
|
114 | 141 | return image.view(extra_dims + (num_channels, new_height, new_width))
|
115 | 142 |
|
|
0 commit comments