[proto] Ported LinearTransformation #6458


Merged

vfdev-5 merged 3 commits into pytorch:main from proto-linear-tf on Aug 22, 2022

Conversation

Collaborator

@vfdev-5 vfdev-5 commented Aug 22, 2022

  • Ported LinearTransformation
  • Added tests

Fixes #5538
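
For context, the transform applies a whitening-style linear map: the input is flattened, the mean vector is subtracted, the flattened vector is multiplied by the transformation matrix, and the result is reshaped back to the input shape. A minimal sketch of that core operation, mirroring the stable transforms.LinearTransformation semantics (not the exact code added in this PR):

import torch

def linear_transform(
    inpt: torch.Tensor,
    transformation_matrix: torch.Tensor,  # shape (D, D) with D = C * H * W
    mean_vector: torch.Tensor,  # shape (D,)
) -> torch.Tensor:
    # Flatten C*H*W into a single row vector, center it, apply the matrix,
    # and restore the original shape.
    flat = inpt.reshape(1, -1) - mean_vector
    transformed = flat @ transformation_matrix
    return transformed.reshape(inpt.shape)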

Comment on lines +71 to +78
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
    raise ValueError(
        "Input tensor and transformation matrix have incompatible shape."
        + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
        + f"{self.transformation_matrix.shape[0]}"
    )
Collaborator

Shouldn't this happen in _get_params?

Collaborator Author

There are no params to generate in this transform. Moving this code to _get_params and returning an empty dict would look a bit strange, IMO.

Collaborator

Well, n is a parameter, no?

Collaborator Author

No, n is not a parameter; we only check whether the input is compatible with self.transformation_matrix, and that's it. n won't change for the inputs.

Collaborator

n is the product of the number of channels, height, and width. Since it

> won't change for the inputs.

we should have it in _get_params with the

image = query_image(sample)
num_channels, height, width = F.get_image_dimensions(image)
n = num_channels * height * width

...

return dict(n=n)

idiom.

Contributor

@vfdev-5 I think that's worth adding here to be consistent with other implementations. For tiny changes like this and the TODO below (#6458 (comment)), it might be worth doing them in the same PR. Otherwise we are just increasing the TODO list, which makes it harder to track and easier to miss.
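
For reference, a hedged sketch of how the suggested _get_params idiom could fit into the prototype Transform API (query_image and F.get_image_dimensions are taken from the snippet above; everything else is illustrative, not the code that was merged):

class LinearTransformation(Transform):
    def _get_params(self, sample):
        # Compute n once per call from the queried image, as suggested above.
        image = query_image(sample)
        num_channels, height, width = F.get_image_dimensions(image)
        return dict(n=num_channels * height * width)

    def _transform(self, inpt, params):
        # Validate the stored matrix against the precomputed n.
        if params["n"] != self.transformation_matrix.shape[0]:
            raise ValueError("Input tensor and transformation matrix have incompatible shape.")
        ...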

Comment on lines +80 to +84
if inpt.device.type != self.mean_vector.device.type:
    raise ValueError(
        "Input tensor should be on the same device as transformation matrix and mean vector. "
        f"Got {inpt.device} vs {self.mean_vector.device}"
    )
Collaborator

Given that this is easy to fix, should we maybe only warn and move mean and std to the same device?

Collaborator Author

@vfdev-5 vfdev-5 Aug 22, 2022

This is the original implementation. Yes, we can move the class attributes to the provided device. Let's do that in a follow-up PR.
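
The follow-up being discussed could look roughly like the sketch below (illustrative only; whether to warn or to move the buffers silently was left open):

import warnings

if inpt.device != self.mean_vector.device:
    warnings.warn(
        "Input tensor is on a different device than the transformation matrix "
        "and mean vector; moving them to the input's device."
    )
    self.transformation_matrix = self.transformation_matrix.to(inpt.device)
    self.mean_vector = self.mean_vector.to(inpt.device)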

Contributor

@datumbox datumbox left a comment

LGTM on my side, thanks!

@vfdev-5 vfdev-5 merged commit aea748b into pytorch:main Aug 22, 2022
@vfdev-5 vfdev-5 deleted the proto-linear-tf branch August 22, 2022 09:58
@github-actions

Hey @vfdev-5!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

Comment on lines +65 to +66
elif isinstance(inpt, PIL.Image.Image):
    raise TypeError("Unsupported input type")
Collaborator

Idiomatically, this should happen in forward with a has_any check.
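
For reference, the idiom being suggested would look roughly like this sketch (assuming a has_any helper from the prototype transforms utilities; not the code in this PR):

import PIL.Image

def forward(self, *inputs):
    sample = inputs if len(inputs) > 1 else inputs[0]
    # Reject PIL images up front, before _transform is ever called.
    if has_any(sample, PIL.Image.Image):
        raise TypeError("LinearTransformation does not work on PIL Images")
    return super().forward(sample)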

@vfdev-5
Collaborator Author

vfdev-5 commented Aug 22, 2022

@pmeier I do not agree with your _get_params comment. Your second comment about the idiomatic usage of has_any seems like overhead to me. I'll let you send a follow-up PR with these suggestions if these changes are important. Thanks

facebook-github-bot pushed a commit that referenced this pull request Aug 25, 2022
Summary:
* WIP

* Fixed dtype correction and tests

* Removed PIL Image support and output always Tensor

Reviewed By: datumbox

Differential Revision: D39013682

fbshipit-source-id: e98d9ba0f2ea703cfb9807dbc9abcd34e5db3331
Development

Successfully merging this pull request may close these issues.

Port transforms.LinearTransformation to prototype.transforms
4 participants