Let Resize() handle uint8 images natively for bilinear mode #7497

@NicolasHug

On uint8 tensors, Resize() currently converts the input image to torch.float and back in order to pass it to interpolate(), because interpolate() did not support native uint8 inputs in the past. This round trip is suboptimal: it adds two dtype conversions on top of the resize itself.

@vfdev-5 and I have recently implemented native uint8 support for interpolate(mode="bilinear") in pytorch/pytorch#90771 and pytorch/pytorch#96848.

We should integrate this native uint8 support into torchvision's Resize(). The benchmarks below show that doing so makes Resize() at least 3X faster, which saves close to 1ms per image and cuts total pipeline time by about 30% for a typical classification pipeline (including auto-augment, which is the next bottleneck). This would make the Tensor / Datapoint backend significantly faster than PIL.
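
The headline figures can be re-derived from the median timings in the benchmark output further down. The sketch below is only arithmetic on the Tensor / v2 rows of those tables; the variable names are illustrative and not part of any benchmark script.

```python
# Median timings (µs) copied from the Tensor / v2 benchmark runs in this issue.
resize_float_us = 1226   # Resize via the float round trip
resize_uint8_us = 402    # ResizeUint8 via native uint8 interpolate()
total_float_us = 3057    # whole pipeline, float path
total_uint8_us = 2099    # whole pipeline, native uint8 path

resize_speedup = resize_float_us / resize_uint8_us
saved_us = resize_float_us - resize_uint8_us
pipeline_reduction = 1 - total_uint8_us / total_float_us

print(f"Resize speedup: {resize_speedup:.2f}x")
print(f"Saved per image: {saved_us} µs")
print(f"Pipeline time reduction: {pipeline_reduction:.1%}")
```

With these inputs the speedup comes out at about 3.05x, the saving at 824 µs per image, and the pipeline reduction at about 31%.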

Open challenges before integration:

  • Improvements for native uint8 are mostly for AVX2 architectures. Compared to the current float implementation of Resize(), is the performance still OK on architectures that don't support AVX2? First step: identify whether those non-AVX2 targets are critical.
  • BC: although more strictly correct, the native uint8 path may produce off-by-one differences compared with the current float path. Mitigation: only integrate native uint8 into the V2 Resize(), where BC commitments are looser.
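
To get a feel for the BC concern, the two paths can be compared directly. This is a minimal sketch, not torchvision's actual implementation: it assumes a PyTorch build that includes the native uint8 bilinear support from the PRs above, the helper names `resize_via_float` and `resize_native_uint8` are hypothetical, and the input is synthetic noise rather than a real image. The size of the differences will depend on the build and the CPU, so no expected values are claimed.

```python
import torch
import torch.nn.functional as F


def resize_via_float(img_uint8, size):
    """Float round trip: uint8 -> float32 -> interpolate -> round -> uint8."""
    batch = img_uint8.unsqueeze(0).to(torch.float32)
    out = F.interpolate(batch, size=size, mode="bilinear", antialias=True)
    return out.round().clamp(0, 255).to(torch.uint8).squeeze(0)


def resize_native_uint8(img_uint8, size):
    """Native path: pass the uint8 tensor straight to interpolate()."""
    out = F.interpolate(img_uint8.unsqueeze(0), size=size, mode="bilinear", antialias=True)
    return out.squeeze(0)


torch.manual_seed(0)
img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)  # synthetic CHW image

out_float = resize_via_float(img, [223, 223])
out_native = resize_native_uint8(img, [223, 223])

# Compare in a signed dtype so the subtraction cannot wrap around.
diff = (out_float.to(torch.int16) - out_native.to(torch.int16)).abs()
print("max abs difference:", diff.max().item())
print("fraction of differing pixels:", (diff > 0).float().mean().item())
```

Running this on both an AVX2 and a non-AVX2 machine would also give a first data point for the architecture question in the first bullet.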

Benchmarks were made with @pmeier's pmeier/detection-reference-benchmark@0ae9027 and the following patch:

+class ResizeUint8(torch.nn.Module):
+    def __init__(self, force_channels_last):
+        super().__init__()
+        self.force_channels_last = force_channels_last
+
+    def forward(self, img):
+        img = img.unsqueeze(0)
+        if self.force_channels_last:
+            img = img.contiguous(memory_format=torch.channels_last)
+        return torch.nn.functional.interpolate(img, size=[223, 223], mode="bilinear", antialias=True, align_corners=None).squeeze(0)
+
 def classification_complex_pipeline_builder(*, input_type, api_version):
     if input_type == "Datapoint" and api_version == "v1":
         return None
@@ -94,9 +106,15 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
     if api_version == "v1":
         transforms = transforms_v1
         RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV1
+        resize = transforms.Resize(223, antialias=True)
     elif api_version == "v2":
         transforms = transforms_v2
         RandomResizedCropWithoutResize = RandomResizedCropWithoutResizeV2
+        if input_type in ("Datapoint", "Tensor"):
+            # resize = ResizeUint8(force_channels_last=False)
+            resize = transforms.Resize(223, antialias=True)
+        else:
+            resize = transforms.Resize(223, antialias=True)
     else:
         raise RuntimeError(f"Got {api_version=}")
 
@@ -106,11 +124,14 @@ def classification_complex_pipeline_builder(*, input_type, api_version):
         pipeline.append(transforms.PILToTensor())
     elif input_type == "Datapoint":
         pipeline.append(transforms.ToImageTensor())
+    
+
 
     pipeline.extend(
         [
             RandomResizedCropWithoutResize(224),
-            transforms.Resize(224, antialias=True),
+            # transforms.Resize(223, antialias=True),
+            resize,
             transforms.RandomHorizontalFlip(p=0.5),
             transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
         ]

Without uint8 native support:

############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'

Results computed for 1_000 samples

                                  median          std   
PILToTensor                          258 µs +-     24 µs
RandomResizedCropWithoutResizeV1     111 µs +-     22 µs
Resize                              1238 µs +-    311 µs
RandomHorizontalFlip                  53 µs +-     21 µs
AutoAugment                         1281 µs +-    840 µs
RandomErasing                         31 µs +-     66 µs
ConvertImageDtype                    120 µs +-     13 µs
Normalize                            186 µs +-     23 µs

total                               3278 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
PILToTensor                          271 µs +-     21 µs
RandomResizedCropWithoutResizeV2     113 µs +-     17 µs
Resize                              1226 µs +-    304 µs
RandomHorizontalFlip                  64 µs +-     24 µs
AutoAugment                         1099 µs +-    738 µs
RandomErasing                         39 µs +-     68 µs
ConvertDtype                          96 µs +-     12 µs
Normalize                            150 µs +-     17 µs

total                               3057 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'

Results computed for 1_000 samples

                                  median          std   
RandomResizedCropWithoutResizeV1     162 µs +-     27 µs
Resize                               787 µs +-    186 µs
RandomHorizontalFlip                  53 µs +-     29 µs
AutoAugment                          585 µs +-    342 µs
PILToTensor                           96 µs +-      9 µs
RandomErasing                         32 µs +-     65 µs
ConvertImageDtype                    125 µs +-     14 µs
Normalize                            850 µs +-     83 µs

total                               2688 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
RandomResizedCropWithoutResizeV2     166 µs +-     26 µs
Resize                               783 µs +-    185 µs
RandomHorizontalFlip                  61 µs +-     33 µs
AutoAugment                          489 µs +-    355 µs
PILToTensor                          115 µs +-      9 µs
RandomErasing                         37 µs +-     65 µs
ConvertDtype                         101 µs +-     11 µs
Normalize                            825 µs +-     84 µs

total                               2577 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
ToImageTensor                        284 µs +-     22 µs
RandomResizedCropWithoutResizeV2     119 µs +-     17 µs
Resize                              1223 µs +-    302 µs
RandomHorizontalFlip                  62 µs +-     29 µs
AutoAugment                         1100 µs +-    625 µs
RandomErasing                         39 µs +-     72 µs
ConvertDtype                         106 µs +-     13 µs
Normalize                            155 µs +-     16 µs

total                               3089 µs
------------------------------------------------------------

Summaries

           v2 / v1
Tensor        0.93
PIL           0.96

                     [a]   [b]   [c]   [d]   [e]
   Tensor, v1, [a]  1.00  1.07  1.22  1.27  1.06
   Tensor, v2, [b]  0.93  1.00  1.14  1.19  0.99
      PIL, v1, [c]  0.82  0.88  1.00  1.04  0.87
      PIL, v2, [d]  0.79  0.84  0.96  1.00  0.83
Datapoint, v2, [e]  0.94  1.01  1.15  1.20  1.00

Slowdown as row / col

With uint8 native support for TensorV2 and DatapointV2:

############################################################
classification-complex
############################################################
input_type='Tensor', api_version='v1'

Results computed for 1_000 samples

                                  median          std   
PILToTensor                          255 µs +-     21 µs
RandomResizedCropWithoutResizeV1     110 µs +-     22 µs
Resize                              1230 µs +-    315 µs
RandomHorizontalFlip                  47 µs +-     24 µs
AutoAugment                         1269 µs +-    870 µs
RandomErasing                         31 µs +-     66 µs
ConvertImageDtype                    121 µs +-     13 µs
Normalize                            186 µs +-     23 µs

total                               3249 µs
------------------------------------------------------------
input_type='Tensor', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
PILToTensor                          270 µs +-     20 µs
RandomResizedCropWithoutResizeV2     110 µs +-     17 µs
ResizeUint8                          402 µs +-    109 µs
RandomHorizontalFlip                  66 µs +-     24 µs
AutoAugment                          996 µs +-    539 µs
RandomErasing                         39 µs +-     64 µs
ConvertDtype                          81 µs +-     10 µs
Normalize                            134 µs +-     14 µs

total                               2099 µs
------------------------------------------------------------
input_type='PIL', api_version='v1'

Results computed for 1_000 samples

                                  median          std   
RandomResizedCropWithoutResizeV1     161 µs +-     28 µs
Resize                               779 µs +-    186 µs
RandomHorizontalFlip                  53 µs +-     29 µs
AutoAugment                          576 µs +-    339 µs
PILToTensor                           93 µs +-      8 µs
RandomErasing                         31 µs +-     64 µs
ConvertImageDtype                    123 µs +-     13 µs
Normalize                            843 µs +-     82 µs

total                               2661 µs
------------------------------------------------------------
input_type='PIL', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
RandomResizedCropWithoutResizeV2     163 µs +-     26 µs
Resize                               788 µs +-    180 µs
RandomHorizontalFlip                  62 µs +-     33 µs
AutoAugment                          492 µs +-    355 µs
PILToTensor                          112 µs +-      9 µs
RandomErasing                         37 µs +-     64 µs
ConvertDtype                         100 µs +-     11 µs
Normalize                            826 µs +-     86 µs

total                               2580 µs
------------------------------------------------------------
input_type='Datapoint', api_version='v2'

Results computed for 1_000 samples

                                  median          std   
ToImageTensor                        284 µs +-     22 µs
RandomResizedCropWithoutResizeV2     118 µs +-     17 µs
ResizeUint8                          410 µs +-    109 µs
RandomHorizontalFlip                  68 µs +-     23 µs
AutoAugment                          994 µs +-    542 µs
RandomErasing                         38 µs +-     63 µs
ConvertDtype                          81 µs +-     10 µs
Normalize                            133 µs +-     14 µs

total                               2127 µs
------------------------------------------------------------

Summaries

           v2 / v1
Tensor        0.65
PIL           0.97

                     [a]   [b]   [c]   [d]   [e]
   Tensor, v1, [a]  1.00  1.55  1.22  1.26  1.53
   Tensor, v2, [b]  0.65  1.00  0.79  0.81  0.99
      PIL, v1, [c]  0.82  1.27  1.00  1.03  1.25
      PIL, v2, [d]  0.79  1.23  0.97  1.00  1.21
Datapoint, v2, [e]  0.65  1.01  0.80  0.82  1.00

Slowdown as row / col
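
The summary matrix above appears consistent with dividing each row's pipeline total by each column's total. The sketch below recomputes it from the totals of the native uint8 run; `slowdown_matrix` is an illustrative helper, not a function from the benchmark repository.

```python
# Pipeline totals (µs) from the run with native uint8 support.
totals_us = {
    "Tensor, v1": 3249,
    "Tensor, v2": 2099,
    "PIL, v1": 2661,
    "PIL, v2": 2580,
    "Datapoint, v2": 2127,
}


def slowdown_matrix(totals):
    """Return matrix[row][col] = totals[row] / totals[col]."""
    return {row: {col: totals[row] / totals[col] for col in totals} for row in totals}


matrix = slowdown_matrix(totals_us)
for row, cols in matrix.items():
    print(f"{row:>14}", "  ".join(f"{value:.2f}" for value in cols.values()))
```

A value below 1.00 means the row configuration is faster than the column configuration, so the 0.65 in the "Tensor, v2" row against "Tensor, v1" corresponds to roughly a 35% shorter pipeline.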
