2 files changed, +15 −0 lines changed

@@ -40,6 +40,9 @@ class AllToAll {
 
   virtual ~AllToAll();
 
+  /// @brief Returns the number of experts each token is routed to.
+  size_t getNumExpertsPerToken() const { return expertsPerToken; }
+
 protected:
   /// The maximum number of tokens per DP group.
   const size_t maxNumTokens;
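The new accessor exposes the routing width so code outside the class can validate tensor shapes against it, as the binding changes below do. A minimal sketch of that call pattern, assuming the AllToAll header is included (the validateIndices helper is hypothetical, not part of this change):

    #include <ATen/ATen.h>
    #include <c10/util/Exception.h>
    // #include "all_to_all.h"  // wherever AllToAll is declared in this repo

    // Hypothetical helper: reject routing tensors whose trailing dimension
    // does not match the kernel's configured experts-per-token.
    void validateIndices(const AllToAll &allToAll, const at::Tensor &indices) {
      TORCH_CHECK(indices.dim() == 2, "indices must be [numTokens, expertsPerToken]");
      TORCH_CHECK(
          indices.size(1) == (int64_t)allToAll.getNumExpertsPerToken(),
          "indices.size(1) must equal the number of experts per token"
      );
    }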
The second changed file (the dispatch/combine bindings):

@@ -5,6 +5,7 @@
 #include <ATen/ATen.h>
 #include <ATen/core/Tensor.h>
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
 #include <c10/util/Exception.h>
 #include <torch/csrc/distributed/c10d/GroupRegistry.hpp>
 #include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
@@ -152,6 +153,15 @@ void dispatch(
   }
 
   auto *all_to_all = (Kernel *)ptr;
+
+  TORCH_CHECK(indices.size(0) == dpX.size(0), "indices.size(0) must be equal to dpX.size(0)");
+  TORCH_CHECK(
+      indices.size(1) == all_to_all->getNumExpertsPerToken(),
+      "indices.size(1) must be equal to the experts per token"
+  );
+
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(indices));
+
   all_to_all->dispatch(
       Strided1D<int32_t>(
           outExpertNumTokens.data_ptr<int32_t>(), (size_t)outExpertNumTokens.stride(0)
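The OptionalCUDAGuard follows the standard PyTorch extension pattern: it switches the current CUDA device to the device that owns indices for the rest of the scope, so the launches inside all_to_all->dispatch(...) target the GPU that actually holds the data even when the caller's current device is a different one. A minimal standalone sketch of the same pattern (function name is illustrative):

    #include <ATen/ATen.h>
    #include <c10/cuda/CUDAGuard.h>

    // Sets the current CUDA device to t's device on entry and restores the
    // previous device when the guard leaves scope. device_of returns an
    // empty optional for an undefined tensor, in which case the guard is
    // a no-op.
    void launchOnOwningDevice(const at::Tensor &t) {
      at::cuda::OptionalCUDAGuard const deviceGuard(device_of(t));
      // ... enqueue CUDA work here; it runs on t's device ...
    }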
@@ -237,6 +247,8 @@ void combine(
 
   auto *all_to_all = (Kernel *)ptr;
 
+  at::cuda::OptionalCUDAGuard const device_guard(device_of(indices));
+
   switch (expertY.scalar_type()) {
     case at::kFloat: {
       switch (outTokens.scalar_type()) {
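combine takes its own guard rather than reusing the one in dispatch because the guard is a scope-local RAII object: the prior device is restored as soon as dispatch returns, so each binding entry point that may launch CUDA work has to re-establish the device context itself. A small sketch of that behavior:

    #include <ATen/ATen.h>
    #include <c10/cuda/CUDAGuard.h>

    void twoEntryPoints(const at::Tensor &indices) {
      {
        at::cuda::OptionalCUDAGuard const guard(device_of(indices));
        // current device == indices.device() within this scope
      }
      // guard destroyed: the previous device is restored, which is why
      // combine() cannot rely on a guard taken earlier in dispatch().
    }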