
Unwrap features before passing them into a kernel #6807


Merged · 8 commits · Oct 21, 2022
Conversation

@pmeier (Collaborator) commented Oct 20, 2022

Closes #6781 by going with the `self.as_subclass(torch.Tensor)` proposal. After this PR, there are no more `__torch_function__` calls in the whole pipeline.
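For illustration, the unwrapping pattern might look like this minimal sketch (the `Image` stand-in and `flip_kernel` are hypothetical placeholders, not the actual torchvision feature classes or kernels):

```python
import torch

class Image(torch.Tensor):
    # Minimal stand-in for a prototype feature subclass.
    pass

def flip_kernel(t: torch.Tensor) -> torch.Tensor:
    # A plain-tensor kernel: if t were still an Image, every op in
    # here would go through __torch_function__.
    return t.flip(-1)

img = torch.rand(3, 8, 8).as_subclass(Image)

# Unwrap to a plain torch.Tensor before calling the kernel, so the
# ops inside the kernel skip __torch_function__ entirely.
out = flip_kernel(img.as_subclass(torch.Tensor))
print(type(out) is torch.Tensor)  # True
```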

In fact, it almost renders #6681 obsolete: `ndim`, `device`, and `dtype` are no longer accessed. Even `shape` is only accessed 4 times:

```python
elif isinstance(image, (features.Image, features.Video)):
    channels = image.num_channels
    height, width = image.spatial_size
    return [channels, height, width]
```

```python
elif isinstance(image, (features.Image, features.Video)):
    return image.num_channels
```

```python
elif isinstance(inpt, (features.Image, features.Video, features.BoundingBox, features.Mask)):
    return list(inpt.spatial_size)
```

```python
elif isinstance(inpt, features.Video):
    return inpt.num_frames
```

Since they all only access `.shape` indirectly, through properties we have anyway

```python
@property
def spatial_size(self) -> Tuple[int, int]:
    return cast(Tuple[int, int], tuple(self.shape[-2:]))

@property
def num_channels(self) -> int:
    return self.shape[-3]
```

```python
@property
def spatial_size(self) -> Tuple[int, int]:
    return cast(Tuple[int, int], tuple(self.shape[-2:]))

@property
def num_channels(self) -> int:
    return self.shape[-3]

@property
def num_frames(self) -> int:
    return self.shape[-4]
```

we could also opt to move the `with DisableTorchFunction()` context manager there and completely remove

```python
# Add properties for common attributes like shape, dtype, device, ndim etc.
# This way we return the result without passing into __torch_function__
@property
def shape(self) -> _size:  # type: ignore[override]
    with DisableTorchFunction():
        return super().shape

@property
def ndim(self) -> int:  # type: ignore[override]
    with DisableTorchFunction():
        return super().ndim

@property
def device(self, *args: Any, **kwargs: Any) -> _device:  # type: ignore[override]
    with DisableTorchFunction():
        return super().device

@property
def dtype(self) -> _dtype:  # type: ignore[override]
    with DisableTorchFunction():
        return super().dtype
```
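The alternative of moving `DisableTorchFunction` into the size properties could look roughly like this sketch (a minimal stand-in class, not the actual torchvision code; `torch._C.DisableTorchFunction` is the context manager the prototype code already uses):

```python
from typing import Tuple, cast

import torch
from torch._C import DisableTorchFunction

class Video(torch.Tensor):
    # Stand-in for the prototype Video feature.
    @property
    def spatial_size(self) -> Tuple[int, int]:
        # Read .shape under DisableTorchFunction so the plain-tensor
        # path is taken, making a generic shape override unnecessary.
        with DisableTorchFunction():
            return cast(Tuple[int, int], tuple(self.shape[-2:]))

video = torch.rand(2, 3, 4, 5).as_subclass(Video)
print(video.spatial_size)  # (4, 5)
```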

cc @vfdev-5 @datumbox @bjuncek

@vfdev-5 (Collaborator) commented Oct 21, 2022

@pmeier can you provide time measurements on a data pipeline (e.g. image classification + AA + RandomErasing, #6681 (comment))?

@pmeier (Collaborator, Author) commented Oct 21, 2022

$ python -u main.py cprofile_pil_vs_feature --n=2000
  • main

             1538052 function calls (1523449 primitive calls) in 2.556 seconds
    
       Ordered by: internal time
    
       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
          309    0.276    0.001    0.276    0.001 {built-in method torch.grid_sampler}
         2000    0.268    0.000    0.268    0.000 {built-in method torch._C._nn._upsample_bilinear2d_aa}
         9909    0.211    0.000    0.263    0.000 {method 'to' of 'torch._C._TensorBase' objects}
        16747    0.206    0.000    0.255    0.000 /home/philip/git/pytorch/torchvision/torchvision/prototype/features/_feature.py:63(__torch_function__)
         2348    0.091    0.000    0.091    0.000 {built-in method torch.round}
    1499/1199    0.082    0.000    0.083    0.000 {built-in method torch.where}
         4039    0.055    0.000    0.055    0.000 {method 'clone' of 'torch._C._TensorBase' objects}
         2000    0.048    0.000    1.008    0.001 /home/philip/git/pytorch/torchvision/torchvision/prototype/transforms/_auto_augment.py:284(forward)
         2000    0.044    0.000    0.069    0.000 /home/philip/git/pytorch/torchvision/torchvision/transforms/functional_tensor.py:938(erase)
          899    0.042    0.000    0.042    0.000 {method 'scatter_add_' of 'torch._C._TensorBase' objects}
    
  • PR

             1380923 function calls (1366923 primitive calls) in 2.432 seconds
    
     Ordered by: internal time
    
     ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        309    0.279    0.001    0.279    0.001 {built-in method torch.grid_sampler}
       2000    0.273    0.000    0.273    0.000 {built-in method torch._C._nn._upsample_bilinear2d_aa}
       9909    0.238    0.000    0.238    0.000 {method 'to' of 'torch._C._TensorBase' objects}
       2348    0.090    0.000    0.090    0.000 {built-in method torch.round}
       1199    0.084    0.000    0.084    0.000 {built-in method torch.where}
       4039    0.063    0.000    0.063    0.000 {method 'clone' of 'torch._C._TensorBase' objects}
       5798    0.047    0.000    0.047    0.000 {method 'div_' of 'torch._C._TensorBase' objects}
       2000    0.045    0.000    0.957    0.000 /home/philip/git/pytorch/torchvision/torchvision/prototype/transforms/_auto_augment.py:284(forward)
       2000    0.045    0.000    0.045    0.000 {method 'flip' of 'torch._C._TensorBase' objects}
       2000    0.043    0.000    0.072    0.000 /home/philip/git/pytorch/torchvision/torchvision/transforms/functional_tensor.py:938(erase)
    

The diff that we are interested in:

```diff
-     16747    0.206    0.000    0.255    0.000 /home/philip/git/pytorch/torchvision/torchvision/prototype/features/_feature.py:63(__torch_function__)
-     14487    0.022    0.000    0.022    0.000 {method 'as_subclass' of 'torch._C._TensorBase' objects}
+     18896    0.032    0.000    0.032    0.000 {method 'as_subclass' of 'torch._C._TensorBase' objects}
```

Meaning we are looking at a (0.206 + 0.022 - 0.032) / 2000 ≈ 100 µs improvement per sample.
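As a quick check of that arithmetic (the numbers are copied from the profile diff above):

```python
# Per-sample saving implied by the cProfile diff: time spent in
# __torch_function__ plus the old as_subclass cost, minus the new
# (larger) as_subclass cost, spread over the 2000 samples.
torch_function_s = 0.206    # tottime of __torch_function__ on main
as_subclass_main_s = 0.022  # as_subclass tottime on main
as_subclass_pr_s = 0.032    # as_subclass tottime on this PR
n_samples = 2000

saving_us = (torch_function_s + as_subclass_main_s - as_subclass_pr_s) / n_samples * 1e6
print(f"{saving_us:.0f} µs per sample")  # 98 µs per sample
```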

@vfdev-5 (Collaborator) commented Oct 21, 2022

@pmeier by time measurements I meant torch benchmarks, not cProfile (which may be less reliable).

@datumbox (Contributor) left a comment

LGTM overall. Only one comment below. Also, mypy seems to be failing due to an unused ignore.

@datumbox (Contributor) left a comment

LGTM, thanks!

@pmeier pmeier merged commit 3761855 into pytorch:main Oct 21, 2022
@pmeier pmeier deleted the unwrap branch October 21, 2022 11:34
facebook-github-bot pushed a commit that referenced this pull request Oct 21, 2022
Summary:
* unwrap features before calling the kernels

* revert double unwrapping

* cleanup

* remove debug raise

* more cleanup

Reviewed By: YosuaMichael

Differential Revision: D40588165

fbshipit-source-id: 3ce277bbe4f47124f572ca8cd185795b5917fe3e
Successfully merging this pull request may close these issues:

Remove remaining `__torch_function__` calls from regular transform pipeline