[proto] Ported LinearTransformation #6458
Conversation
```python
shape = inpt.shape
n = shape[-3] * shape[-2] * shape[-1]
if n != self.transformation_matrix.shape[0]:
    raise ValueError(
        "Input tensor and transformation matrix have incompatible shape."
        + f"[{shape[-3]} x {shape[-2]} x {shape[-1]}] != "
        + f"{self.transformation_matrix.shape[0]}"
    )
```
Shouldn't this happen in `_get_params`?
There are no params to generate in this transform. I think moving this code to `_get_params` and returning an empty dict would look a bit strange, IMO.
Well, `n` is a parameter, no?
No, `n` is not a parameter; we only check whether the input is compatible with `self.transformation_matrix`, and that's it. `n` won't change across inputs.
`n` is the product of the number of channels, height, and width. Since it "won't change across inputs", we should have it in `_get_params` with the

```python
image = query_image(sample)
num_channels, height, width = F.get_image_dimensions(image)
n = num_channels * height * width
....
return dict(n=n)
```

idiom.
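For illustration, here is a self-contained sketch of that idiom. The real prototype transforms use `query_image(sample)` and `F.get_image_dimensions`; in this hypothetical version, a nested list shaped `(C, H, W)` stands in for an image so the snippet runs on its own:

```python
# Hypothetical sketch of the _get_params idiom discussed above; the helper
# names mirror the prototype transforms API but are plain-Python stand-ins.

def get_image_dimensions(image):
    # stand-in for F.get_image_dimensions: infer (C, H, W) from nesting depth
    num_channels = len(image)
    height = len(image[0])
    width = len(image[0][0])
    return num_channels, height, width

def get_params(image, transformation_matrix_rows):
    # compute n once per sample and validate it against the matrix
    num_channels, height, width = get_image_dimensions(image)
    n = num_channels * height * width
    if n != transformation_matrix_rows:
        raise ValueError(
            "Input tensor and transformation matrix have incompatible shape. "
            f"[{num_channels} x {height} x {width}] != {transformation_matrix_rows}"
        )
    return dict(n=n)

image = [[[0.0] * 4 for _ in range(4)] for _ in range(3)]  # shaped (3, 4, 4)
print(get_params(image, 48))  # -> {'n': 48}
```

This keeps the shape validation in the same place as the other prototype transforms, with `_transform` consuming `params["n"]`.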
@vfdev-5 I think that's worth adding here to be consistent with other implementations. For tiny changes like this and the TODO below (#6458 (comment)), it might be worth doing them in the same PR. Otherwise we are just increasing the TODO list, which makes it harder to track and easier to miss.
```python
if inpt.device.type != self.mean_vector.device.type:
    raise ValueError(
        "Input tensor should be on the same device as transformation matrix and mean vector. "
        f"Got {inpt.device} vs {self.mean_vector.device}"
    )
```
Given that this is easy to fix, should we maybe only warn and move `mean` and `std` to the same device?
This is the original implementation. Yes, we can move the class attributes to the provided device. Let's do that in a follow-up PR.
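A minimal sketch of what that follow-up's warn-and-move behavior could look like. Everything here is an assumption: `align_device` is a hypothetical helper, and device strings stand in for `torch.device` so the snippet runs without torch (real code would call `tensor.to(input_device)` on the matrix and mean vector):

```python
import warnings

def align_device(param_device, input_device):
    # hypothetical helper: instead of raising on a device mismatch, warn
    # and report the device the class attributes should be moved to
    if param_device != input_device:
        warnings.warn(
            f"Input tensor is on {input_device} but the transformation "
            f"matrix and mean vector are on {param_device}; moving them "
            f"to {input_device}."
        )
        return input_device
    return param_device
```

With this behavior, a CPU-parameter / GPU-input mismatch degrades to a one-time warning instead of a hard failure.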
LGTM on my side, thanks!
Hey @vfdev-5! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py
```python
elif isinstance(inpt, PIL.Image.Image):
    raise TypeError("Unsupported input type")
```
Idiomatically, this should happen in `forward` with a `has_any` check.
@pmeier I do not agree with your
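For reference, a stand-alone sketch of the pattern @pmeier describes. Both names here are stand-ins: `has_any` is a simplified pure-Python version of the prototype transforms helper, and `FakePILImage` is a hypothetical placeholder for `PIL.Image.Image`:

```python
def has_any(sample, *types):
    # simplified stand-in for the prototype transforms has_any helper:
    # recurse through list/tuple containers and type-check the leaves
    if isinstance(sample, (list, tuple)):
        return any(has_any(item, *types) for item in sample)
    return isinstance(sample, types)

class FakePILImage:
    # hypothetical placeholder for PIL.Image.Image
    pass

def forward(sample):
    # reject unsupported inputs up front, before per-item dispatch
    if has_any(sample, FakePILImage):
        raise TypeError("Unsupported input type")
    return sample
```

The advantage of checking in `forward` is that the whole sample is rejected once, rather than failing midway through transforming its items.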
Summary:
* WIP
* Fixed dtype correction and tests
* Removed PIL Image support and output always Tensor

Reviewed By: datumbox
Differential Revision: D39013682
fbshipit-source-id: e98d9ba0f2ea703cfb9807dbc9abcd34e5db3331
Fixes #5538