
Conversation

pmeier
Collaborator

@pmeier pmeier commented Apr 20, 2019

Edit: I only now became aware of #855 and especially a comment from @fmassa about the future of torchvision.transforms: if this PR does not fit into the grand scheme of the upcoming release, feel free to close it.

TL;DR

This PR does not add any new functionality, but streamlines the __repr__ methods. In order to do so it introduces the superclasses Transform and TransformContainer.

Details

  • The __repr__ methods of the Transform and TransformContainer objects are based on the idea of torch.nn.Module: each subclass can add information by overriding the extra_repr method (see the sketch after this list).
  • Compose, RandomTransforms, RandomApply, RandomOrder, RandomChoice are moved to container.py
  • RandomTransforms is superseded by TransformContainer, since its only tasks were to handle the __repr__ method and make sure the subclasses implement the __call__ method.
  • Some tests in test_transforms.py fail, but the same tests also fail on the master branch
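
For reference, a minimal sketch of the extra_repr pattern described above (illustrative only; Resize here is just a stand-in example, not torchvision's actual class, and the real code in this PR may differ):

```python
class Transform:
    def extra_repr(self):
        # Subclasses override this to add their own parameters to the repr.
        return ""

    def __repr__(self):
        # Mirrors the idea of torch.nn.Module.__repr__ / extra_repr.
        return "{}({})".format(type(self).__name__, self.extra_repr())


class Resize(Transform):
    def __init__(self, size, interpolation="bilinear"):
        self.size = size
        self.interpolation = interpolation

    def extra_repr(self):
        return "size={}, interpolation={}".format(self.size, self.interpolation)


print(Resize(256))  # Resize(size=256, interpolation=bilinear)
```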

@codecov-io

codecov-io commented Apr 20, 2019

Codecov Report

Merging #861 into master will decrease coverage by 0.11%.
The diff coverage is 90.9%.


@@            Coverage Diff             @@
##           master     #861      +/-   ##
==========================================
- Coverage   54.49%   54.37%   -0.12%     
==========================================
  Files          36       37       +1     
  Lines        3307     3292      -15     
  Branches      542      539       -3     
==========================================
- Hits         1802     1790      -12     
+ Misses       1372     1369       -3     
  Partials      133      133
Impacted Files Coverage Δ
torchvision/transforms/__init__.py 100% <100%> (ø) ⬆️
torchvision/transforms/transforms.py 82.35% <89.02%> (-0.34%) ⬇️
torchvision/transforms/container.py 93.87% <93.87%> (ø)
torchvision/transforms/functional.py 69.2% <0%> (-1.27%) ⬇️


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update ccbb322...bf1b6f1.

@fmassa
Member

fmassa commented Apr 23, 2019

Hi,

Thanks a lot for the PR!

While I definitely agree that adding some more structure is a good thing, I'm still not convinced that the current way the transforms are implemented is well suited to dealing with more types of data.

In particular, I'm more and more inclined to think that for more complex use-cases the user should just write their composite transform themselves using the functional interface, in a similar way that we don't provide nn.Concat / nn.Split / nn.Select layers but instead have the user write the operations in the forward of the model.
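
For example, such a user-written composite transform against torchvision.transforms.functional could look roughly like this (a sketch only; the class name and the joint image/mask handling are illustrative assumptions, not code from this PR):

```python
import random

import torchvision.transforms.functional as F


class MyDetectionTransform:
    """Hypothetical composite transform written directly against the
    functional interface, transforming an image and its mask jointly."""

    def __init__(self, flip_prob=0.5):
        self.flip_prob = flip_prob

    def __call__(self, image, mask):
        # Apply the same random horizontal flip to both inputs.
        if random.random() < self.flip_prob:
            image = F.hflip(image)
            mask = F.hflip(mask)
        return F.to_tensor(image), F.to_tensor(mask)
```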

That being said, I have no objections to merging this PR as is, but I expect that a few things will change in the future, which might make some of the changes in this PR obsolete. Also, while the new classes do add some more structure, we don't save much typing, precisely because the scope of what should be in the Transform is not very well defined.

Thoughts?

@pmeier
Collaborator Author

pmeier commented Apr 23, 2019

I most certainly don't have the whole picture, so please correct me if I make some wrong assumptions.

That being said, I have no objections to merging this PR as is, but I expect that a few things will change in the future, which might make some of the changes in this PR obsolete.

I don't mind if it gets replaced by something more fitting in the future.

[...] adding some more structure is a good thing [...]

I was wondering if I should create a separate module for all Random* transforms. Is this something I should add?

[...] I'm more and more inclined to think that for more complex use-cases the user should just write their composite transform themselves [...]

I agree with you here. In this case I think having a superclass from which the user can inherit is IMHO even more needed. By mimicking the structure of nn.Module, the learning curve stays low for users familiar with PyTorch.

[...] we don't save much typing, precisely because the scope of what should be in the Transform is not very well defined.

I'm not sure if I got your point here. Can you elaborate?

@fmassa
Member

fmassa commented Apr 25, 2019

@pmeier

I was wondering if I should create a separate module for all Random* transforms. Is this something I should add?

What would be the restrictions / requirements that the Random transforms should have?

Also, before the first release of PyTorch, we used to have an nn.Module and an nn.Container. The separation was that nn.Module would perform "single" operations, and nn.Container would fuse them together (like nn.Sequential). Because the distinction between the two didn't really add much value, we merged them and have everything be an nn.Module.
In this vein, I'm wondering what the extra RandomTransform module would provide that is not already inside Transform / object. The get_params?

I'd love to hear your thoughts on this

@pmeier
Collaborator Author

pmeier commented Apr 26, 2019

What would be the restrictions / requirements that the Random transforms should have?

None, other than providing more structure. On second thought this makes little sense, since

  1. they all have Random in their name, making their intention pretty clear, and
  2. they are in fact TransformContainers, which might lead to more confusion.

In summary: scratch that thought ;)

We could mimic the structure of nn.Module even more in the future:

  • Transform is an abstract class and handles __repr__ and makes sure __call__ is implemented.
  • TransformContainer is also abstract and inherits from Transform. It makes sure __getitem__ etc. are implemented. This intermediate class could be left out, as is done in torch.nn.
  • Compose and the Random containers inherit from TransformContainer.

Internally, the Transform could store all transforms inside an OrderedDict for easy access (a rough sketch follows below).
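
A rough sketch of that hierarchy (purely illustrative; here the OrderedDict lives in TransformContainer, and all details are open for discussion):

```python
from collections import OrderedDict


class Transform:
    """Abstract base: handles __repr__ and requires __call__."""

    def __call__(self, *args):
        raise NotImplementedError

    def extra_repr(self):
        return ""

    def __repr__(self):
        return "{}({})".format(type(self).__name__, self.extra_repr())


class TransformContainer(Transform):
    """Abstract container: stores child transforms in an OrderedDict,
    similar to how nn.Module keeps track of its submodules."""

    def __init__(self, transforms):
        self._transforms = OrderedDict(
            (str(idx), transform) for idx, transform in enumerate(transforms)
        )

    def __getitem__(self, idx):
        return self._transforms[str(idx)]

    def __len__(self):
        return len(self._transforms)


class Compose(TransformContainer):
    def __call__(self, img):
        for transform in self._transforms.values():
            img = transform(img)
        return img
```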

If that is something we want to pursue, I will open a discussion issue summarising all the ideas I can find within the issues and PRs about this topic.

@fmassa
Member

fmassa commented Apr 26, 2019

If that is something we want to pursue, I will open a discussion issue summarising all the ideas I can find within the issues and PRs about this topic.

@pmeier I think this is something we want to do.
I've been thinking for some time about actually making Transform simply be an nn.Module and using torch.Tensor everywhere. This would enable the transforms to be traceable and to be run on the GPU.

But there are some complications there. For example, in this case we might have different image sizes coming out of the dataset. So we would need a way to efficiently represent a list of tensors of different sizes, something like a RaggedTensor. Plus, we will also want to perform random transformations independently and in parallel, which is not yet that simple to do and would require custom CUDA code.

Here are my thoughts:

  • the dataset returns a torch.Tensor instead of a PIL image, potentially a uint8 tensor instead of converting it to float.
  • the collate_fn returns a RaggedTensor instead of erroring on tensors with different sizes
  • in the training loop, the RaggedTensor gets moved efficiently to the GPU in a single go
  • the first layer of the network is the transformation, including resizing / cropping / converting to float, normalizing to 0-1 and mean / std subtraction (a rough sketch of this pipeline follows below).
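
A very rough sketch of that pipeline, assuming a RaggedTensor-like container (here simply a list of tensors, since no such type exists in PyTorch yet; all names are made up):

```python
import torch
from torch import nn


def collate_fn(batch):
    # Stand-in for a RaggedTensor: keep the variably sized uint8 images in a
    # list instead of erroring while trying to stack them.
    images, targets = zip(*batch)
    return list(images), list(targets)


class InputTransform(nn.Module):
    """First 'layer' of the network: resize, convert to float, normalize."""

    def __init__(self, size, mean, std):
        super().__init__()
        self.size = size
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))

    def forward(self, images):
        # images: list of uint8 CxHxW tensors of different sizes
        batch = []
        for img in images:
            img = img.float().div_(255)  # uint8 -> float in [0, 1]
            img = nn.functional.interpolate(
                img[None], size=self.size, mode="bilinear", align_corners=False
            )[0]  # resize to a common size
            batch.append((img - self.mean) / self.std)  # mean / std subtraction
        return torch.stack(batch)
```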

We can nowadays sample random transformation parameters efficiently for each image. But performing the geometric transformations on the batch in one go is not easy to implement efficiently as of now. This is a blocker for implementing this idea. I'm not yet 100% clear on what kind of API we would need for that.
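
For instance, drawing independent crop offsets for a whole batch is already trivial (a toy illustration with made-up sizes); it is applying those per-image geometric transforms in one batched call that remains hard:

```python
import torch

batch_size, height, width = 8, 224, 224
crop_h, crop_w = 128, 128

# One independent crop offset per image, sampled in a single call each.
top = torch.randint(0, height - crop_h + 1, (batch_size,))
left = torch.randint(0, width - crop_w + 1, (batch_size,))
```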

And of course, we are not tied to using torch.Tensor. The current way of using PIL to perform data transformation will still work, but will not be traceable.

I'm ccing @cpuhrsch , as he is investigating RaggedTensor as a potential addition to PyTorch.

@pmeier
Collaborator Author

pmeier commented Apr 26, 2019

I think this is something we want to do.

In that case I will get to it. I will CC you.

I've been thinking for some time about actually making Transform simply be an nn.Module and using torch.Tensor everywhere.

This would make a lot of things easier. AFAIK we could use nn.Module without ever touching torch.Tensors. Is that what you mean by

And of course, we are not tied to using torch.Tensor. The current way of using PIL to perform data transformation will still work, but will not be traceable.

?

If that is the case, we could simply let Transform inherit from nn.Module and replace __call__ with forward as a first step.
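
Such a first step might look roughly like this (a sketch under the assumption that inputs can remain PIL images; CenterCrop is just an example, not the PR's code):

```python
from torch import nn
import torchvision.transforms.functional as F


class Transform(nn.Module):
    """Inheriting from nn.Module provides the extra_repr-based __repr__,
    submodule registration, and __call__ dispatching to forward for free."""

    def forward(self, img):
        raise NotImplementedError


class CenterCrop(Transform):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, img):
        # img may still be a PIL image; nn.Module itself never requires the
        # inputs to be torch.Tensors.
        return F.center_crop(img, self.size)

    def extra_repr(self):
        return "size={}".format(self.size)
```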

@fmassa
Member

fmassa commented Apr 26, 2019

If that is the case, we could simply let Transform inherit from nn.Module and replace __call__ with forward as a first step.

Yes, but I'd rather not rush things just yet. There is not yet any real benefit to making the transforms an nn.Module; there are a few things that should be done beforehand so that making such a change would really add value.

@pmeier pmeier closed this Jul 3, 2019
@pmeier pmeier deleted the refactor_transforms branch July 3, 2019 09:05