Issue 2350 - support of all padding modes with tensors #2368
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master    #2368      +/-   ##
==========================================
+ Coverage   68.49%   68.52%   +0.03%
==========================================
  Files          93       93
  Lines        7655     7670      +15
  Branches     1177     1182       +5
==========================================
+ Hits         5243     5256      +13
  Misses       2075     2075
- Partials      337      339       +2
Continue to review full report at Codecov.
Thanks for the PR, Victor!
I have a couple of comments; let me know what you think.
test/test_functional_tensor.py
Outdated
if pil_tensor.dtype != tensor.dtype:
    pil_tensor = pil_tensor.to(tensor.dtype)
Which case led you to perform a cast here? Is it for the extra float32 and float64 types?
I think it might be preferable to avoid doing implicit casts in this function, and instead perform the cast directly in the caller -- it's more explicit.
Right, I can cast explicitly beforehand instead of inside this function 👍
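The suggested change can be sketched as follows. This is illustrative only, not the PR's actual test code: the helper name and the stand-in tensors are hypothetical, and the point is simply that the cast happens at the call site rather than inside the comparison helper.

```python
import torch

def assert_tensors_equal(a, b):
    # hypothetical test helper: compares strictly, with no implicit casting
    assert a.dtype == b.dtype, f"dtype mismatch: {a.dtype} vs {b.dtype}"
    assert torch.equal(a, b)

tensor = torch.ones(3, 4, 4, dtype=torch.float32)
pil_tensor = torch.ones(3, 4, 4, dtype=torch.uint8)  # stand-in for a PIL-derived tensor

# cast explicitly in the caller, as the review suggests,
# instead of hiding the cast inside the helper
assert_tensors_equal(pil_tensor.to(tensor.dtype), tensor)
```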
out_dtype = img.dtype
need_cast = False
if img.dtype not in (torch.float32, torch.float64):
This could probably be optimized, because constant padding supports uint8 types as well.
I agree, will fix!
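The cast-and-restore pattern under discussion can be sketched directly with `torch.nn.functional.pad`: constant padding works on uint8 tensors as-is, while non-constant modes such as reflect have required a floating dtype, hence the `out_dtype`/`need_cast` bookkeeping. A minimal sketch of the idea, not the PR's exact code:

```python
import torch
import torch.nn.functional as F

img = torch.arange(9, dtype=torch.uint8).reshape(1, 3, 3)  # (C, H, W) image

# constant padding supports integer dtypes directly, no cast needed
padded_const = F.pad(img, (1, 1, 1, 1), mode="constant", value=0)

# reflect padding may require a floating dtype, so cast and restore afterwards
out_dtype = img.dtype
need_cast = out_dtype not in (torch.float32, torch.float64)
x = img.to(torch.float32) if need_cast else img
padded_reflect = F.pad(x, (1, 1, 1, 1), mode="reflect")
if need_cast:
    padded_reflect = padded_reflect.to(out_dtype)
```

Restoring `out_dtype` at the end keeps the function's output dtype identical to its input dtype, so callers never observe the internal cast.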
Thanks for the review! I'll update it accordingly.
Looks great, thanks a lot!
* [WIP] functional_tensor supports more padding modes
* [WIP] Support all padding modes
* Removed wip symmetric mode
* Improvements according to the review
Fixes #2350
Description:
functional_tensor.pad now supports all padding modes with tensor inputs.