Commit ee42a99

Yifu Wang authored and pytorchmergebot committed
[SymmetricMemory] introduce a binding for cuMemset32Async (pytorch#138755)
## This Stack

This stack does the following things to support `xformers`-style, comm-aware Triton kernels:

- Exposes `signal_pad`s as tensors in Python
- Adds a binding for `cuMemsetAsync`

In combination, these aim to give users more flexibility to express custom signaling/synchronization patterns.

## This PR

Makes `cuMemset32Async` available via `_SymmetricMemory.memset32`. We chose `cuMemset32Async` over `cudaMemsetAsync` because it allows for `uint32_t`-wise memset, which gives users finer-grained control.

To enable this, we also added the following CUDA driver APIs to `c10::cuda::DriverAPI`:

- `cuDevicePrimaryCtxRetain` - for obtaining the primary context of a device in the form of a `CUcontext`.
- `cuCtxGetCurrent`/`cuCtxSetCurrent` - for setting and restoring the current context around CUDA driver API calls such as `cuMemset32Async`.

Pull Request resolved: pytorch#138755
Approved by: https://github.com/weifengpy, https://github.com/eqy, https://github.com/lw
1 parent 87059d4 commit ee42a99
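
Why a 32-bit memset: a byte-wise `cudaMemsetAsync` can only replicate a single byte (filling with `0x01` yields `0x01010101` per word, not `1`), whereas `cuMemsetD32Async` writes whole `uint32_t` values. Below is a minimal usage sketch of the new binding, mirroring the unit test in this commit; the import path follows the test suite, and a CUDA build with driver API support is assumed.

```python
import torch
from torch._C._distributed_c10d import _SymmetricMemory

# Allocate a flat, contiguous uint32 symmetric-memory buffer and zero it.
t = _SymmetricMemory.empty_strided_p2p(
    (64,),
    (1,),
    dtype=torch.uint32,
    device=torch.device("cuda:0"),
    group_name="0",
).fill_(0)

# uint32-wise memset on the current CUDA stream: set elements [32, 48) to 1.
_SymmetricMemory.memset32(t, offset=32, val=1, count=16)
assert t[32:48].eq(1).all()
```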

6 files changed (+121, -13 lines)

.lintrunner.toml

Lines changed: 0 additions & 2 deletions
```diff
@@ -70,8 +70,6 @@ include_patterns = [
     'aten/src/ATen/native/cudnn/*.cpp',
     'c10/**/*.h',
     'c10/**/*.cpp',
-    'distributed/c10d/*DMAConnectivity.*',
-    'distributed/c10d/*SymmetricMemory.*',
     'torch/csrc/**/*.h',
     'torch/csrc/**/*.hpp',
     'torch/csrc/**/*.cpp',
```

c10/cuda/driver_api.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -30,6 +30,7 @@
   _(cuMemGetAllocationGranularity) \
   _(cuMemExportToShareableHandle) \
   _(cuMemImportFromShareableHandle) \
+  _(cuMemsetD32Async) \
   _(cuStreamWriteValue32) \
   _(cuGetErrorString)

```

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -562,6 +562,7 @@ if(USE_CUDA)
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu
+      ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu
       ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp
       PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1"
   )
```

test/distributed/test_symmetric_memory.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -24,9 +24,11 @@
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
     parametrize,
+    requires_cuda,
     run_tests,
     skip_but_pass_in_sandcastle_if,
     skipIfRocm,
+    TestCase,
 )


@@ -849,5 +851,32 @@ def func_3(x):
         self.assertNotIn("return (buf0", code_3)


+class SymmMemSingleProcTest(TestCase):
+    @skipIfRocm
+    @requires_cuda
+    def test_memset32(self):
+        t = _SymmetricMemory.empty_strided_p2p(
+            (64,),
+            (1,),
+            dtype=torch.uint32,
+            device=torch.device("cuda:0"),
+            group_name="0",
+        ).fill_(0)
+
+        _SymmetricMemory.memset32(t, offset=32, val=1, count=16)
+        self.assertTrue(t[:32].eq(0).all())
+        self.assertTrue(t[32:48].eq(1).all())
+        self.assertTrue(t[48:].eq(0).all())
+
+        with self.assertRaises(RuntimeError):
+            _SymmetricMemory.memset32(t, offset=-1, val=1, count=16)
+
+        with self.assertRaises(RuntimeError):
+            _SymmetricMemory.memset32(t, offset=32, val=4294967296, count=16)
+
+        with self.assertRaises(RuntimeError):
+            _SymmetricMemory.memset32(t, offset=32, val=1, count=-1)
+
+
 if __name__ == "__main__":
     run_tests()
```
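
Since `SymmMemSingleProcTest` needs no multi-process rendezvous, the new test can be run directly on a single CUDA device, e.g. `python test/distributed/test_symmetric_memory.py -k test_memset32` (assuming a typical PyTorch development checkout).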

torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu

Lines changed: 73 additions & 10 deletions
```diff
@@ -1,8 +1,14 @@
-#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
-
 #include <ATen/ATen.h>
 #include <ATen/ceil_div.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <torch/library.h>
+
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+#include <c10/cuda/driver_api.h>
+#endif
+
+#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030

 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -11,8 +17,6 @@
 #include <ATen/ops/empty_like.h>
 #endif

-#include <torch/library.h>
-
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>
 #include <torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp>

@@ -491,7 +495,61 @@ at::Tensor two_shot_all_reduce_(
   return input;
 }

+} // namespace
+#endif // #if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
+
+namespace {
+
+at::Tensor memset32_(
+    at::Tensor& input,
+    int64_t offset,
+    int64_t val,
+    int64_t count) {
+#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
+  TORCH_CHECK(
+      input.dim() == 1 && input.is_contiguous() &&
+          input.scalar_type() == c10::ScalarType::UInt32,
+      "symm_mem::memset32_: input must be a flat, contiguous uint32 tensor.");
+
+  TORCH_CHECK(
+      offset > 0 && count > 0,
+      "symm_mem::memset32_: offset and count must be positive integers.");
+
+  TORCH_CHECK(
+      val >= 0 &&
+          static_cast<size_t>(val) <= std::numeric_limits<uint32_t>::max(),
+      "symm_mem::memset32_: val must be in the range of "
+      "[0, 4294967295] (uint32_t).");
+
+  auto element_size = c10::elementSize(input.scalar_type());
+  TORCH_CHECK(
+      offset + count < input.numel(),
+      "symm_mem::memset32_: offset + count (",
+      offset + count,
+      ") exceeded the numel of the input (",
+      input.numel(),
+      ")");
+
+  auto addr = reinterpret_cast<uint32_t*>(input.data_ptr()) + offset;
+
+  c10::cuda::CUDAGuard guard(input.device());
+  auto driver_api = c10::cuda::DriverAPI::get();
+  C10_CUDA_DRIVER_CHECK(driver_api->cuMemsetD32Async_(
+      reinterpret_cast<CUdeviceptr>(addr),
+      val,
+      count,
+      at::cuda::getCurrentCUDAStream()));
+#else
+  TORCH_CHECK(
+      false, "CUDASymmetricMemory requires PYTORCH_C10_DRIVER_API_SUPPORTED");
+#endif
+  return input;
+}
+
+} // namespace
+
 TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
+#if defined(CUDART_VERSION) && CUDART_VERSION >= 12030
   m.def(
       "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
       torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_),
@@ -519,8 +577,12 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor",
       {at::Tag::pt2_compliant_tag});

-  m.impl("one_shot_all_reduce", torch::dispatch(c10::DispatchKey::Meta, ::one_shot_all_reduce_meta));
-  m.impl("one_shot_all_reduce", torch::dispatch(c10::DispatchKey::CUDA, ::one_shot_all_reduce));
+  m.impl(
+      "one_shot_all_reduce",
+      torch::dispatch(c10::DispatchKey::Meta, ::one_shot_all_reduce_meta));
+  m.impl(
+      "one_shot_all_reduce",
+      torch::dispatch(c10::DispatchKey::CUDA, ::one_shot_all_reduce));

   m.def(
       "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)",
@@ -531,8 +593,9 @@ TORCH_LIBRARY_FRAGMENT(symm_mem, m) {
       "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)",
       torch::dispatch(c10::DispatchKey::CUDA, ::two_shot_all_reduce_),
       {at::Tag::pt2_compliant_tag});
-}
-
-} // namespace
-
 #endif
+  m.def(
+      "memset32_(Tensor(a!) input, int offset, int val, int count) -> Tensor(a!)",
+      torch::dispatch(c10::DispatchKey::CUDA, ::memset32_),
+      {at::Tag::pt2_compliant_tag});
+}
```
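
Because `memset32_` is registered under the `symm_mem` library fragment, it is also reachable at the dispatcher level, independent of the pybind wrapper added in `init.cpp` below. A sketch, assuming (per the binding's own comment) that any flat, contiguous CUDA uint32 tensor is accepted, not only symmetric-memory buffers:

```python
import torch

# Dispatcher-level call into the schema registered above; the
# _SymmetricMemory.memset32 static method wraps this same op.
t = torch.zeros(64, dtype=torch.uint32, device="cuda")
torch.ops.symm_mem.memset32_(t, offset=32, val=1, count=16)
```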

torch/csrc/distributed/c10d/init.cpp

Lines changed: 17 additions & 1 deletion
```diff
@@ -1122,7 +1122,23 @@ This class does not support ``__members__`` property.)");
           "stream_write_value32",
           &SymmetricMemory::stream_write_value32,
           py::arg("addr"),
-          py::arg("val"));
+          py::arg("val"))
+      // Util functions that are often used together with symmetric memory but
+      // not necessarily directly on symmetric memory.
+      .def_static(
+          "memset32",
+          [](at::Tensor& input, int64_t offset, int64_t val, int64_t count) {
+            // The range of `val` is checked inside the op
+            auto op = c10::Dispatcher::singleton()
+                          .findSchemaOrThrow("symm_mem::memset32_", "")
+                          .typed<at::Tensor(
+                              at::Tensor&, int64_t, int64_t, int64_t)>();
+            return op.call(input, offset, val, count);
+          },
+          py::arg("input"),
+          py::arg("offset"),
+          py::arg("val"),
+          py::arg("count") = 1);

   auto store =
       py::class_<::c10d::Store, c10::intrusive_ptr<::c10d::Store>, PythonStore>(
```
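
Note that the binding defaults `count` to 1, so setting a single 32-bit word needs no explicit count; all range checks for `offset`, `val`, and `count` happen inside the op rather than in the lambda. A minimal sketch, reusing `t` from the earlier example:

```python
# count defaults to 1, so this writes exactly one uint32 word at index 1.
_SymmetricMemory.memset32(t, offset=1, val=42)
```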
