Remove remaining __torch_function__ calls from regular transform pipeline #6781

@pmeier

On our journey to micro-optimize transforms v2, we recently merged #6681. This significantly reduced the number of __torch_function__ calls inside the augmentation pipeline. However, we still have some unnecessary calls to it that originate from this idiom that we have on the feature methods:

def horizontal_flip(self) -> Image:
    output = self._F.horizontal_flip_image_tensor(self)

The kernel, here horizontal_flip_image_tensor, is called with self, i.e. a features.Image. Internally, the first operation inside the kernel will thus go through __torch_function__ and unwrap it. In other words, we go through the whole protocol for no reason, since the kernel by definition works with plain tensors.
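To make the overhead described above visible, here is a minimal sketch that counts protocol calls. CountingTensor and toy_kernel are hypothetical stand-ins for features.Image and a real kernel; the __torch_function__ below is a simplified assumption (it returns the plain result without re-wrapping) and not the actual torchvision implementation.

```python
import torch


class CountingTensor(torch.Tensor):
    """Hypothetical stand-in for features.Image that counts protocol calls."""

    calls = 0

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        cls.calls += 1
        # Assumption: run the op with the protocol disabled and hand back
        # whatever it returns, i.e. do not re-wrap the result.
        with torch._C.DisableTorchFunction():
            return func(*args, **(kwargs or {}))


def toy_kernel(image: torch.Tensor) -> torch.Tensor:
    # Hypothetical kernel made of two tensor operations.
    return image.flip(-1).add(1)


wrapped = torch.rand(3, 4, 4).as_subclass(CountingTensor)
plain = wrapped.as_subclass(torch.Tensor)

# Kernel called with the wrapped input: the protocol is involved.
CountingTensor.calls = 0
toy_kernel(wrapped)
calls_wrapped = CountingTensor.calls

# Kernel called with the already unwrapped input: no protocol calls.
CountingTensor.calls = 0
toy_kernel(plain)
calls_unwrapped = CountingTensor.calls

print(f"wrapped input: {calls_wrapped} call(s), plain input: {calls_unwrapped} call(s)")
```

Under these assumptions only the first operation inside the kernel should hit the protocol, since its result is already a plain tensor.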

We could simply unwrap, i.e. call self.as_subclass(torch.Tensor), before we call the kernel. That would make our implementation a little more verbose, but would prevent any __torch_function__ calls when using our transformations whatsoever. Of course, we could also add a self.unwrap() method to reduce the boilerplate. Another option is to move the kernel call under the DisableTorchFunction context manager, similar to what we did in #6681.
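For illustration, the two options could look roughly like this on a feature method. This is only a sketch: ToyImage, toy_flip_kernel, and both method names are hypothetical, and the simplified __torch_function__ is an assumption rather than the real features.Image behavior.

```python
import torch


def toy_flip_kernel(image: torch.Tensor) -> torch.Tensor:
    # Hypothetical kernel that expects a plain tensor.
    return image.flip(-1)


class ToyImage(torch.Tensor):
    """Hypothetical stand-in for features.Image."""

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Assumption: simplified protocol that never re-wraps results.
        with torch._C.DisableTorchFunction():
            return func(*args, **(kwargs or {}))

    def unwrap(self) -> torch.Tensor:
        # Helper that hides the as_subclass boilerplate.
        return self.as_subclass(torch.Tensor)

    def flip_unwrapped(self) -> "ToyImage":
        # Option 1: unwrap before calling the kernel.
        output = toy_flip_kernel(self.unwrap())
        return output.as_subclass(ToyImage)

    def flip_with_context_manager(self) -> "ToyImage":
        # Option 2: call the kernel with the protocol disabled.
        with torch._C.DisableTorchFunction():
            output = toy_flip_kernel(self)
        return output.as_subclass(ToyImage)


image = torch.arange(6.0).reshape(1, 2, 3).as_subclass(ToyImage)
flipped_a = image.flip_unwrapped()
flipped_b = image.flip_with_context_manager()
```

Both methods return a ToyImage holding the same data; they differ only in how the kernel is shielded from the protocol.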

from time import perf_counter_ns

import torch
from torch._C import DisableTorchFunction
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F


tensor = torch.rand(3, 256, 256)
image = features.Image(tensor)

arbitrary_unary_image_tensor_kernel = F.autocontrast_image_tensor


def minimum():
    arbitrary_unary_image_tensor_kernel(tensor)


def unwrap_boilerplate():
    arbitrary_unary_image_tensor_kernel(image.as_subclass(torch.Tensor))


def _unwrap_helper(feature):
    return feature.as_subclass(torch.Tensor)


def unwrap_with_helper_method():
    arbitrary_unary_image_tensor_kernel(_unwrap_helper(image))


def context_manager():
    with DisableTorchFunction():
        arbitrary_unary_image_tensor_kernel(image)


def baseline():
    arbitrary_unary_image_tensor_kernel(image)


for fn in [minimum, unwrap_boilerplate, unwrap_with_helper_method, context_manager, baseline]:
    for _ in range(1_000):
        fn()

    time_diffs_ns = []
    for _ in range(100_000):
        start = perf_counter_ns()
        fn()
        stop = perf_counter_ns()
        time_diffs_ns.append(stop - start)

    time_diffs = torch.tensor(time_diffs_ns, dtype=torch.float64) * 1e-9
    print(f"{fn.__name__}: median={time_diffs.median() * 1e6:.2f} µs, std={time_diffs.std() * 1e6:.2f} µs")

minimum: median=69.79 µs, std=3.61 µs
unwrap_boilerplate: median=74.35 µs, std=5.30 µs
unwrap_with_helper_method: median=74.65 µs, std=5.57 µs
context_manager: median=76.62 µs, std=3.58 µs
baseline: median=94.84 µs, std=4.42 µs

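The benchmark shows that unwrapping is slightly faster than the context manager (a median of roughly 74 µs vs. 77 µs, compared to 95 µs for the baseline). However, the difference is small enough that we could also opt for consistency with #6681 and use the context manager instead.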

cc @vfdev-5 @datumbox @bjuncek
