Commit 7e7031c

Wrap kernel calls with a device guard (#19)
1 parent ed71b87 commit 7e7031c

2 files changed: +15 additions, -0 deletions

csrc/all_to_all/all_to_all.h (3 additions, 0 deletions)

@@ -40,6 +40,9 @@ class AllToAll {
 
     virtual ~AllToAll();
 
+    /// @brief Returns the number of experts each token is routed to.
+    size_t getNumExpertsPerToken() const { return expertsPerToken; }
+
 protected:
     /// The maximum number of tokens per DP group.
    const size_t maxNumTokens;
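
The new accessor makes the kernel's experts-per-token setting visible to code outside the class, which is what allows the bindings below to validate the shape of the routing indices before launching the kernel. A minimal sketch of the intended use, assuming an all_to_all handle and an indices tensor are already in scope (the names are illustrative):

// Illustrative only: reject routing indices whose second dimension does not
// match the kernel's experts-per-token configuration. This mirrors the check
// added to csrc/bindings/all_to_all_ops.cpp in this commit.
TORCH_CHECK(
    indices.size(1) == all_to_all->getNumExpertsPerToken(),
    "indices.size(1) must be equal to the experts per token"
);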

csrc/bindings/all_to_all_ops.cpp (12 additions, 0 deletions)

@@ -5,6 +5,7 @@
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

@@ -152,6 +153,15 @@ void dispatch(
   }
 
   auto *all_to_all = (Kernel *)ptr;
+
+  TORCH_CHECK(indices.size(0) == dpX.size(0), "indices.size(0) must be equal to dpX.size(0)");
+  TORCH_CHECK(
+      indices.size(1) == all_to_all->getNumExpertsPerToken(),
+      "indices.size(1) must be equal to the experts per token"
+  );
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(indices));
+
   all_to_all->dispatch(
       Strided1D<int32_t>(
           outExpertNumTokens.data_ptr<int32_t>(), (size_t)outExpertNumTokens.stride(0)

@@ -237,6 +247,8 @@ void combine(
 
   auto *all_to_all = (Kernel *)ptr;
 
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(indices));
+
   switch (expertY.scalar_type()) {
   case at::kFloat: {
     switch (outTokens.scalar_type()) {
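
Why the guard matters: CUDA kernel launches and stream lookups act on the calling thread's current device, which in a multi-GPU process is not necessarily the device that holds the tensors handed over from Python. The OptionalCUDAGuard makes the device owning indices current for the remainder of the enclosing scope and restores the previous device on exit. A minimal sketch of the pattern, using the same ATen/c10 headers as the bindings file; launchOnTensorDevice and the elided kernel call are hypothetical:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

// Hypothetical wrapper showing the device-guard pattern from this commit.
// The guard makes the device that owns `indices` current for this scope, so
// the stream lookup and any kernels enqueued on it target that device rather
// than whichever device happened to be current on the calling thread.
void launchOnTensorDevice(const at::Tensor &indices) {
  at::cuda::OptionalCUDAGuard const device_guard(device_of(indices));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // ... enqueue kernels on `stream`; they now run on indices.device() ...
  (void)stream;
}

Because the guard restores the previous device when it goes out of scope, dispatch and combine can be called back-to-back on tensors living on different devices from the same host thread without leaking device state.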
