Description
On our journey to micro-optimize transforms v2, we recently merged #6681. This significantly reduced the number of __torch_function__ calls inside the augmentation pipeline. However, some unnecessary calls remain. They originate from this idiom in the feature methods:
vision/torchvision/prototype/features/_image.py
Lines 125 to 126 in 149edda
def horizontal_flip(self) -> Image:
    output = self._F.horizontal_flip_image_tensor(self)
The kernel, here horizontal_flip_image_tensor, is called with self, i.e. a features.Image. Internally, the first operation will thus go through __torch_function__ and unwrap it. This means we go through the whole protocol for no reason, since the kernel by definition works with plain tensors.
We could simply unwrap, i.e. call self.as_subclass(torch.Tensor), before we call the kernel. That would make our implementation a little more verbose, but would prevent any __torch_function__ calls when using our transformations whatsoever. Of course, we could also add a self.unwrap() method to reduce the boilerplate. Another option is to move the kernel call under the DisableTorchFunction context manager, similar to what we did in #6681.
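To see the protocol overhead in isolation, here is a minimal sketch of the unwrapping option. It assumes a hypothetical CountingTensor subclass as a stand-in for features.Image and a hypothetical one-op toy_kernel in place of a real kernel; neither is part of torchvision. It also relies on the default torch.Tensor.__torch_function__, which re-wraps outputs, whereas the prototype features unwrap them, so treat it as an illustration rather than a model of our implementation.

```python
import torch


class CountingTensor(torch.Tensor):
    """Hypothetical stand-in for features.Image that counts protocol hits."""

    calls = 0

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Record the hit, then defer to the default implementation.
        cls.calls += 1
        return super().__torch_function__(func, types, args, kwargs)


def toy_kernel(t):
    # Hypothetical kernel: a single op that only needs a plain tensor.
    return t.flip(-1)


wrapped = torch.rand(3, 4, 4).as_subclass(CountingTensor)

# Calling the kernel with the subclass goes through __torch_function__.
CountingTensor.calls = 0
out_wrapped = toy_kernel(wrapped)
calls_wrapped = CountingTensor.calls

# Unwrapping first means the kernel itself never touches the protocol.
plain = wrapped.as_subclass(torch.Tensor)
CountingTensor.calls = 0
out_plain = toy_kernel(plain)
calls_plain = CountingTensor.calls

print(f"wrapped input: {calls_wrapped} __torch_function__ call(s) inside the kernel")
print(f"plain input:   {calls_plain} __torch_function__ call(s) inside the kernel")
```

The counter is reset after unwrapping on purpose, so the sketch only measures what happens inside the kernel and makes no claim about the cost of as_subclass itself.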
from time import perf_counter_ns

import torch
from torch._C import DisableTorchFunction

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

tensor = torch.rand(3, 256, 256)
image = features.Image(tensor)

arbitrary_unary_image_tensor_kernel = F.autocontrast_image_tensor


def minimum():
    arbitrary_unary_image_tensor_kernel(tensor)


def unwrap_boilerplate():
    arbitrary_unary_image_tensor_kernel(image.as_subclass(torch.Tensor))


def _unwrap_helper(feature):
    return feature.as_subclass(torch.Tensor)


def unwrap_with_helper_method():
    arbitrary_unary_image_tensor_kernel(_unwrap_helper(image))


def context_manager():
    with DisableTorchFunction():
        arbitrary_unary_image_tensor_kernel(image)


def baseline():
    arbitrary_unary_image_tensor_kernel(image)


for fn in [minimum, unwrap_boilerplate, unwrap_with_helper_method, context_manager, baseline]:
    for _ in range(1_000):
        fn()

    time_diffs_ns = []
    for _ in range(100_000):
        start = perf_counter_ns()
        fn()
        stop = perf_counter_ns()
        time_diffs_ns.append(stop - start)

    time_diffs = torch.tensor(time_diffs_ns, dtype=torch.float64) * 1e-9
    print(f"{fn.__name__}: median={time_diffs.median() * 1e6:.2f} µs, std={time_diffs.std() * 1e6:.2f} µs")
minimum: median=69.79 µs, std=3.61 µs
unwrap_boilerplate: median=74.35 µs, std=5.30 µs
unwrap_with_helper_method: median=74.65 µs, std=5.57 µs
context_manager: median=76.62 µs, std=3.58 µs
baseline: median=94.84 µs, std=4.42 µs
Benchmarking shows that unwrapping is slightly faster than the context manager (by roughly 2 µs per call in this run), but the difference is small enough that we could also opt for consistency with #6681 and use the context manager instead.
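To make the comparison easier to read, the overhead of each variant over the plain-tensor floor can be computed directly from the medians reported above. This is only a small arithmetic sketch; the overhead helper is a hypothetical name introduced here, and the numbers are copied from the single run shown in this issue.

```python
# Median timings in µs, copied from the benchmark output above.
medians_us = {
    "minimum": 69.79,
    "unwrap_boilerplate": 74.35,
    "unwrap_with_helper_method": 74.65,
    "context_manager": 76.62,
    "baseline": 94.84,
}

# Calling the kernel with a plain tensor is the floor we compare against.
floor = medians_us["minimum"]


def overhead(name):
    """Return (extra µs over the floor, extra as a fraction of the floor)."""
    extra = medians_us[name] - floor
    return extra, extra / floor


for name in medians_us:
    extra, rel = overhead(name)
    print(f"{name}: +{extra:.2f} µs ({rel:.1%})")
```

On these numbers, the baseline pays about 25 µs on top of the floor, while unwrapping and the context manager pay about 5 µs and 7 µs respectively.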