Description
On uint8 tensors, Resize() currently converts the input image to and from torch.float to pass it to interpolate(), because interpolate() didn't support native uint8 inputs in the past. This is suboptimal.
@vfdev-5 and I have recently implemented native uint8 support for interpolate(mode="bilinear") in pytorch/pytorch#90771 and pytorch/pytorch#96848.
We should integrate this native uint8 support into torchvision's Resize(). The benchmarks below show that such an integration could make Resize() at least 3X faster, which saves about 1 ms per image and cuts the total pipeline time by roughly 30% for a typical classification pipeline (including auto-augment, which is the next bottleneck). This would make the Tensor / DataPoint backend significantly faster than PIL.
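The headline numbers can be re-derived from the benchmark output in this issue. The short sketch below only does the arithmetic on the reported medians for input_type='Tensor', api_version='v2'; the variable names are illustrative and not part of the benchmark script.

```python
# Medians in µs for input_type='Tensor', api_version='v2',
# copied from the two benchmark runs reported in this issue.
resize_float_us = 1226  # "Resize" row, run without native uint8 support
resize_uint8_us = 402   # "ResizeUint8" row, run with native uint8 support
total_float_us = 3057   # pipeline total, run without native uint8 support
total_uint8_us = 2099   # pipeline total, run with native uint8 support

# Speedup and saving of the resize step alone.
resize_speedup = resize_float_us / resize_uint8_us
resize_saved_us = resize_float_us - resize_uint8_us

# Saving over the whole pipeline, absolute and relative.
total_saved_us = total_float_us - total_uint8_us
total_reduction_pct = 100 * total_saved_us / total_float_us

print(f"Resize speedup:       {resize_speedup:.2f}x")
print(f"Resize time saved:    {resize_saved_us} µs per image")
print(f"Pipeline time saved:  {total_saved_us} µs per image")
print(f"Pipeline reduction:   {total_reduction_pct:.1f}%")
```

Note that the pipeline total drops by more than the resize step alone, because other steps (for example AutoAugment) also report lower medians in the second run.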
Some current challenges before integration are:
- Improvements for native uint8 are mostly for AVX2 architectures. Compared to the current Resize() implementation (float), is the performance still OK on architectures that don't support AVX2? First, we need to identify whether those non-AVX2 targets are critical or not.
- BC: Although more strictly correct, the native uint8 path may have off-by-one differences with the current float path. Mitigation: only integrate native uint8 into V2 Resize(), where BC commitments are looser.
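The BC concern can be checked empirically by resizing the same image through both paths and looking at the largest per-pixel difference. The following is a minimal sketch, assuming a PyTorch build that includes the uint8 bilinear support from the PRs referenced above. The helper names resize_via_float and resize_native_uint8 are hypothetical, and the float helper only approximates the float round trip described in this issue (cast to float, interpolate, round, cast back), not torchvision's exact code.

```python
import torch
import torch.nn.functional as F


def resize_via_float(img_u8, size):
    """Approximation of the float path: uint8 -> float32 -> interpolate -> uint8."""
    batch = img_u8.unsqueeze(0).to(torch.float32)
    out = F.interpolate(batch, size=size, mode="bilinear", antialias=True, align_corners=False)
    # Round before casting back so values are not simply truncated.
    return out.round().clamp(0, 255).to(torch.uint8).squeeze(0)


def resize_native_uint8(img_u8, size):
    """Native path: pass the uint8 tensor to interpolate() directly."""
    batch = img_u8.unsqueeze(0)
    out = F.interpolate(batch, size=size, mode="bilinear", antialias=True, align_corners=False)
    return out.squeeze(0)


torch.manual_seed(0)
# Random CHW uint8 image; real images may show a different difference pattern.
img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)

out_float = resize_via_float(img, [223, 223])
out_uint8 = resize_native_uint8(img, [223, 223])

# Widen to a signed dtype before subtracting to avoid uint8 wrap-around.
diff = (out_float.to(torch.int16) - out_uint8.to(torch.int16)).abs()
print("max abs difference:", diff.max().item())
print("fraction of differing pixels:", (diff > 0).float().mean().item())
```

Running this on a representative set of real images, rather than random noise, would give a better picture of how often the two paths disagree.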
Benchmarks were made with @pmeier's pmeier/detection-reference-benchmark@0ae9027 and the following patch:
+class ResizeUint8(torch.nn.Module):
+    def __init__(self, force_channels_last):
+        super().__init__()
+        self.force_channels_last = force_channels_last
+
+    def forward(self, img):
+        img = img.unsqueeze(0)
+        if self.force_channels_last:
+            img = img.contiguous(memory_format=torch.channels_last)
+        return torch.nn.functional.interpolate(img, size=[223, 223], mode="bilinear", antialias=True, align_corners=None).squeeze(0)
+
 def classification_complex_pipeline_builder(*, input_type, api_version):
     if input_type == "Datapoint" and api_version == "v1":
         return None
@@ -94,9 +106,15 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
     if api_version == "v1":
         transforms = transforms_v1
         RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV1
+        resize = transforms.Resize(223, antialias=True)
     elif api_version == "v2":
         transforms = transforms_v2
         RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV2
+        if input_type in ("Datapoint", "Tensor"):
+            # resize = ResizeUint8(force_channels_last=False)
+            resize = transforms.Resize(223, antialias=True)
+        else:
+            resize = transforms.Resize(223, antialias=True)
     else:
         raise RuntimeError(f"Got {api_version=}")
@@ -106,11 +124,14 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
         pipeline.append(transforms.PILToTensor())
     elif input_type == "Datapoint":
         pipeline.append(transforms.ToImageTensor())
+
+
     pipeline.extend(
         [
             RandomResizedCropWithoutResize(224),
-            transforms.Resize(224, antialias=True),
+            # transforms.Resize(223, antialias=True),
+            resize,
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
         ]
Without uint8 native support:
############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'
Results computed for 1_000 samples
median std
PILToTensor 258 µs +- 24 µs
RandomResizedCropWithoutResizeV1 111 µs +- 22 µs
Resize 1238 µs +- 311 µs
RandomHorizontalFlip 53 µs +- 21 µs
AutoAugment 1281 µs +- 840 µs
RandomErasing 31 µs +- 66 µs
ConvertImageDtype 120 µs +- 13 µs
Normalize 186 µs +- 23 µs
total 3278 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'
Results computed for 1_000 samples
median std
PILToTensor 271 µs +- 21 µs
RandomResizedCropWithoutResizeV2 113 µs +- 17 µs
Resize 1226 µs +- 304 µs
RandomHorizontalFlip 64 µs +- 24 µs
AutoAugment 1099 µs +- 738 µs
RandomErasing 39 µs +- 68 µs
ConvertDtype 96 µs +- 12 µs
Normalize 150 µs +- 17 µs
total 3057 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV1 162 µs +- 27 µs
Resize 787 µs +- 186 µs
RandomHorizontalFlip 53 µs +- 29 µs
AutoAugment 585 µs +- 342 µs
PILToTensor 96 µs +- 9 µs
RandomErasing 32 µs +- 65 µs
ConvertImageDtype 125 µs +- 14 µs
Normalize 850 µs +- 83 µs
total 2688 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV2 166 µs +- 26 µs
Resize 783 µs +- 185 µs
RandomHorizontalFlip 61 µs +- 33 µs
AutoAugment 489 µs +- 355 µs
PILToTensor 115 µs +- 9 µs
RandomErasing 37 µs +- 65 µs
ConvertDtype 101 µs +- 11 µs
Normalize 825 µs +- 84 µs
total 2577 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'
Results computed for 1_000 samples
median std
ToImageTensor 284 µs +- 22 µs
RandomResizedCropWithoutResizeV2 119 µs +- 17 µs
Resize 1223 µs +- 302 µs
RandomHorizontalFlip 62 µs +- 29 µs
AutoAugment 1100 µs +- 625 µs
RandomErasing 39 µs +- 72 µs
ConvertDtype 106 µs +- 13 µs
Normalize 155 µs +- 16 µs
total 3089 µs
------------------------------------------------------------
Summaries
                     v2 / v1
Tensor               0.93
PIL                  0.96
                      [a]   [b]   [c]   [d]   [e]
Tensor, v1, [a]      1.00  1.07  1.22  1.27  1.06
Tensor, v2, [b]      0.93  1.00  1.14  1.19  0.99
PIL, v1, [c]         0.82  0.88  1.00  1.04  0.87
PIL, v2, [d]         0.79  0.84  0.96  1.00  0.83
Datapoint, v2, [e]   0.94  1.01  1.15  1.20  1.00
Slowdown as row / col
With uint8 native support for TensorV2 and DatapointV2:
############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'
Results computed for 1_000 samples
median std
PILToTensor 255 µs +- 21 µs
RandomResizedCropWithoutResizeV1 110 µs +- 22 µs
Resize 1230 µs +- 315 µs
RandomHorizontalFlip 47 µs +- 24 µs
AutoAugment 1269 µs +- 870 µs
RandomErasing 31 µs +- 66 µs
ConvertImageDtype 121 µs +- 13 µs
Normalize 186 µs +- 23 µs
total 3249 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'
Results computed for 1_000 samples
median std
PILToTensor 270 µs +- 20 µs
RandomResizedCropWithoutResizeV2 110 µs +- 17 µs
ResizeUint8 402 µs +- 109 µs
RandomHorizontalFlip 66 µs +- 24 µs
AutoAugment 996 µs +- 539 µs
RandomErasing 39 µs +- 64 µs
ConvertDtype 81 µs +- 10 µs
Normalize 134 µs +- 14 µs
total 2099 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV1 161 µs +- 28 µs
Resize 779 µs +- 186 µs
RandomHorizontalFlip 53 µs +- 29 µs
AutoAugment 576 µs +- 339 µs
PILToTensor 93 µs +- 8 µs
RandomErasing 31 µs +- 64 µs
ConvertImageDtype 123 µs +- 13 µs
Normalize 843 µs +- 82 µs
total 2661 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'
Results computed for 1_000 samples
median std
RandomResizedCropWithoutResizeV2 163 µs +- 26 µs
Resize 788 µs +- 180 µs
RandomHorizontalFlip 62 µs +- 33 µs
AutoAugment 492 µs +- 355 µs
PILToTensor 112 µs +- 9 µs
RandomErasing 37 µs +- 64 µs
ConvertDtype 100 µs +- 11 µs
Normalize 826 µs +- 86 µs
total 2580 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'
Results computed for 1_000 samples
median std
ToImageTensor 284 µs +- 22 µs
RandomResizedCropWithoutResizeV2 118 µs +- 17 µs
ResizeUint8 410 µs +- 109 µs
RandomHorizontalFlip 68 µs +- 23 µs
AutoAugment 994 µs +- 542 µs
RandomErasing 38 µs +- 63 µs
ConvertDtype 81 µs +- 10 µs
Normalize 133 µs +- 14 µs
total 2127 µs
------------------------------------------------------------
Summaries
                     v2 / v1
Tensor               0.65
PIL                  0.97
                      [a]   [b]   [c]   [d]   [e]
Tensor, v1, [a]      1.00  1.55  1.22  1.26  1.53
Tensor, v2, [b]      0.65  1.00  0.79  0.81  0.99
PIL, v1, [c]         0.82  1.27  1.00  1.03  1.25
PIL, v2, [d]         0.79  1.23  0.97  1.00  1.21
Datapoint, v2, [e]   0.65  1.01  0.80  0.82  1.00
Slowdown as row / col
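For readers unsure how to read the "Slowdown as row / col" table: each entry appears to be the row configuration's median total divided by the column configuration's median total, so values below 1.00 mean the row is faster. The sketch below recomputes the table from the totals of the run with native uint8 support. It is an assumption about how the benchmark script derives the table, checked only against the printed values, and slowdown_matrix is a hypothetical helper name.

```python
# Median pipeline totals in µs, copied from the run with native uint8 support.
totals_us = {
    "Tensor, v1": 3249,
    "Tensor, v2": 2099,
    "PIL, v1": 2661,
    "PIL, v2": 2580,
    "Datapoint, v2": 2127,
}


def slowdown_matrix(totals):
    """Return {row: {col: row_total / col_total}}, rounded to two decimals."""
    return {
        row: {col: round(t_row / t_col, 2) for col, t_col in totals.items()}
        for row, t_row in totals.items()
    }


matrix = slowdown_matrix(totals_us)
for row, cols in matrix.items():
    values = "  ".join(f"{v:.2f}" for v in cols.values())
    print(f"{row:<15}{values}")
```

Read this way, the Tensor v2 row shows 0.79 against PIL v1 and 0.81 against PIL v2, which is the basis for the claim that the tensor backend becomes faster than PIL once native uint8 resizing is used.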