You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[SymmetricMemory] introduce a binding for cuMemset32Async (pytorch#138755)
## This Stack
This stack does the following things to support `xformers`-style, comm-aware Triton kernels:
- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemsetAsync`
These in combination aims to provide users with more flexibility to express custom signaling/synchronization patterns.
## This PR
Make `cuMemset32Async` available via `_SymmetricMemory.memset32`. We chose `cuMemset32Async` over `cudaMemsetAsync` because it allows for `uint32_t`-wise memset. This provides users with better flexibility.
To enable this, we also added the following cuda driver APIs in `c10::cuda::DriverAPI`:
- `cuDevicePrimaryCtxRetain` - for obtaining the primary context of a device in the form of `CUcontext`.
- `cuCtxGetCurrent`/`cuCtxSetCurrent` - for setting and restoring the context for cuda driver APIs such as `cuMemset32Async`.
Pull Request resolved: pytorch#138755
Approved by: https://github.com/weifengpy, https://github.com/eqy, https://github.com/lw
0 commit comments