[ROCm] group_index_select_or_add_2d_kernel forward pass optimization #5078
Conversation
…ices for group_index_select_or_add_2d_kernel
✅ Deploy Preview for pytorch-fbgemm-docs ready!
To edit notification comments on pull requests, go to your Netlify project configuration.
```cpp
// The wave size is forced to be 32 on ROCm devices to reduce
// granularity losses.
constexpr int EMULATED_WARP_SIZE = 32;
```
Can we ensure that EMULATED_WARP_SIZE = kWarpSize for CUDA?
updated this in the internal diff.
Done in 799dad0
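A minimal sketch of what such a guard might look like (hypothetical, not the committed change in 799dad0; it assumes the codebase's `kWarpSize` constant and the `USE_ROCM` macro):

```cpp
// Hypothetical sketch: pin the logical wave size to 32 on ROCm, where
// the hardware wavefront is 64 lanes, and tie it to the native warp
// size on CUDA so the two values cannot diverge silently.
#ifdef USE_ROCM
constexpr int EMULATED_WARP_SIZE = 32; // 32-lane logical waves reduce granularity losses
#else
constexpr int EMULATED_WARP_SIZE = kWarpSize; // native 32-lane warps on CUDA
#endif
```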
merged.
…h#5078) Summary: This PR introduces an optimization for the `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel, with a primary focus on the `float` type and relatively small embedding dimensions. Two changes are implemented: 1) common variables are extracted out of the loop to omit unnecessary synchronization on memory loads (the compiler won't do that automatically); 2) switch to 32-thread logical wave sizes to reduce granularity losses. Differential Revision: D86135611 Pulled By: q10
…h#5080) Summary: Pull Request resolved: pytorch#5080 X-link: https://github.com/facebookresearch/FBGEMM/pull/2087 This PR introduces an optimization for the `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel, with a primary focus on the `float` type and relatively small embedding dimensions. Two changes are implemented: 1) common variables are extracted out of the loop to omit unnecessary synchronization on memory loads (the compiler won't do that automatically); 2) switch to 32-thread logical wave sizes to reduce granularity losses. Pull Request resolved: pytorch#5078 Reviewed By: spcyppt, haoyuz Differential Revision: D86135611 Pulled By: q10 fbshipit-source-id: f4fb9966f5f5180c4dde2aed92ca726c260b7743
This PR introduces an optimization for the `group_index_select_or_add_2d_kernel` (`USE_INDEX_SELECT==true`) kernel, with a primary focus on the `float` type and relatively small embedding dimensions. Two changes are implemented:

1) Common variables are extracted out of the loop to omit unnecessary synchronization on memory loads, since the compiler won't do that automatically (see the sketch after this list).
2) Switch to 32-thread logical wave sizes to reduce granularity losses.
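As a rough illustration of change 1), here is a minimal sketch (illustrative names, not the actual FBGEMM kernel) of hoisting loop-invariant loads out of the per-element loop so they are issued once instead of on every iteration:

```cpp
#include <cstdint>

// Illustrative sketch of a 2D index-select kernel. The row index and
// base pointers are loop-invariant, so they are loaded/computed once
// before the loop instead of being re-read on every iteration, which
// would force a dependent memory load (and its latency) per element.
__global__ void index_select_2d_sketch(
    const float* input,      // [num_input_rows, embedding_dim]
    const int64_t* indices,  // [num_output_rows]
    float* output,           // [num_output_rows, embedding_dim]
    int64_t num_output_rows,
    int64_t embedding_dim) {
  const int64_t row = blockIdx.x;
  if (row >= num_output_rows) {
    return;
  }
  // Hoisted out of the loop: one index load and two address computations.
  const int64_t src_row = indices[row];
  const float* src = input + src_row * embedding_dim;
  float* dst = output + row * embedding_dim;
  // Each thread strides over the embedding dimension.
  for (int64_t d = threadIdx.x; d < embedding_dim; d += blockDim.x) {
    dst[d] = src[d];
  }
}
```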