Skip to content

Commit 1d76f48

Browse files
authored
Fix potential bugs in internode dispatch kernel (#25)
When I review the code, I think there are two issues in current internode dispatch kernel: 1. The size of shared memory allocation `sharedMemoryRecv` during kernel launch is set to `sizeof(uint32_t) * expertsPerBlock`. However during recv in the kernel, both `sharedExpert` and `sharedToken` uses shared memory, with each an `uint32_t` array of `expertsPerBlock`. So we need to double the allocation size to `sizeof(uint32_t) * expertsPerBlock * 2` 2. When synchronizing the number of tokens received, the for loop should increment with `blockDim.x`, instead of `gridDim.x * expertsPerBlock`
1 parent c336faf commit 1d76f48

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

csrc/all_to_all/internode_dispatch.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,8 +183,7 @@ __global__ __launch_bounds__(NUM_WARPS * 32, 1) void dispatchKernel(
183183
unsigned firstGroup = blockIdx.x * expertsPerBlock;
184184
unsigned lastGroup = std::min(firstGroup + expertsPerBlock, numExpertsAndGroups);
185185

186-
for (unsigned group = firstGroup + threadIdx.x; group < lastGroup;
187-
group += gridDim.x * expertsPerBlock) {
186+
for (unsigned group = firstGroup + threadIdx.x; group < lastGroup; group += blockDim.x) {
188187
const uint32_t expert = group / numDPGroups;
189188

190189
// Fetch the token count per DP, which is non-zero to indicate receipt.
@@ -276,7 +275,7 @@ void AllToAllInterNode::dispatch(
276275

277276
const size_t expertsPerBlock = ceil_div<size_t>(numLocalExperts * numDPGroups, numBlocks);
278277
const size_t sharedMemorySend = sizeof(uint32_t) * numExperts;
279-
const size_t sharedMemoryRecv = sizeof(uint32_t) * expertsPerBlock;
278+
const size_t sharedMemoryRecv = sizeof(uint32_t) * expertsPerBlock * 2;
280279

281280
void *args[] = {
282281
const_cast<int32_t **>(&outNumTokensPerExpert.data),

0 commit comments

Comments
 (0)