@@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids,
   }
 }

+// TODO(simon): this is temporarily adapted from
+// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7
+// we did this to unblock Deepseek V3 but there should be a better
+// implementation to manage shared memory.
+template <typename scalar_t>
+__global__ void moe_align_block_size_global_mem_kernel(
+    scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
+    int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts,
+    int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) {
+  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
+  const size_t start_idx = threadIdx.x * tokens_per_thread;
+
+  for (int i = 0; i < num_experts; ++i) {
+    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
+  }
+
+  /**
+   * In the first step we compute tokens_cnts[thread_index + 1][expert_index],
+   * which counts how many tokens in the token shard of thread_index are
+   * assigned to expert expert_index.
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])];
+  }
+
+  __syncthreads();
+
+  // For each expert we accumulate the token counts from the different threads.
+  if (threadIdx.x < num_experts) {
+    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
+    for (int i = 1; i <= blockDim.x; ++i) {
+      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
+          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
+    }
+  }
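+  // After this prefix sum, tokens_cnts[i][e] holds the number of tokens routed
+  // to expert e by threads 0..i-1; row blockDim.x therefore holds the total
+  // per-expert counts, and row threadIdx.x is this thread's write offset
+  // inside each expert's segment (used in the final scatter below).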
+
+  __syncthreads();
+
+  // We accumulate the token counts of all experts in thread 0.
+  if (threadIdx.x == 0) {
+    cumsum[0] = 0;
+    for (int i = 1; i <= num_experts; ++i) {
+      cumsum[i] = cumsum[i - 1] +
+                  CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
+                          block_size) *
+                      block_size;
+    }
+    *total_tokens_post_pad = cumsum[num_experts];
+  }
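+  // For the example used further below, topk_ids = [0,1,2,1,2,3,0,3,4] with
+  // num_experts = 5 and block_size = 4: the per-expert counts are
+  // [2, 2, 2, 2, 1], each rounded up to 4, so cumsum = [0, 4, 8, 12, 16, 20]
+  // and *total_tokens_post_pad = 20.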
+
+  __syncthreads();
+
+  /**
+   * Each thread (one per expert) writes its expert id into expert_ids for
+   * every block in that expert's padded segment [cumsum[e], cumsum[e + 1]).
+   */
+  if (threadIdx.x < num_experts) {
+    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
+         i += block_size) {
+      expert_ids[i / block_size] = threadIdx.x;
+    }
+  }
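+  // Continuing the example above: 20 padded slots / block_size 4 = 5 blocks,
+  // so expert_ids = [0, 1, 2, 3, 4].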
+
+  /**
+   * Each thread processes a token shard, calculating the index of each token
+   * after sorting by expert number. Given the example topk_ids =
+   * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
+   * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
+   * padding value (preset in Python).
+   */
+  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+    int32_t expert_id = topk_ids[i];
+    /** cumsum[expert_id] is the starting index of the padded segment owned by
+     * expert_id, and tokens_cnts[threadIdx.x][expert_id] is the number of
+     * tokens for expert_id already placed ahead of this thread's shard; it is
+     * incremented below as this thread places its own tokens, so their sum is
+     * the token's rank within the sorted, padded layout.
+     */
+    int32_t rank_post_pad =
+        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
+        cumsum[expert_id];
+    sorted_token_ids[rank_post_pad] = i;
+    ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)];
+  }
+}
+
 template <typename scalar_t, int TOPK>
 __global__ void moe_sum_kernel(
     scalar_t* __restrict__ out,  // [..., d]
@@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_INTEGRAL_TYPES(
-      topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-        // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
-        // tensors
-        const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
-        const int32_t shared_mem =
-            ((num_thread + 1) * num_experts + (num_experts + 1)) *
-            sizeof(int32_t);
-
-        // set dynamic shared mem
-        auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
-        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
-            (void*)kernel, shared_mem));
-        kernel<<<1, num_thread, shared_mem, stream>>>(
-            topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-            experts_ids.data_ptr<int32_t>(),
-            num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
-            topk_ids.numel());
-      });
+
+  // If we have a very large number of experts, we can no longer use shared
+  // memory.
+  // TODO(simon): the right solution should be calculating the exact right
+  // amount of shared memory and use that. The num_experts >= 256 is just a
+  // temporary solution to unblock Deepseek V3.
+  if (num_experts >= 256) {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
+          // calc needed amount of global memory for the `tokens_cnts` and
+          // `cumsum` tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+
+          const int32_t mem_tokens_cnts =
+              ((num_experts + 1) * num_experts) * sizeof(int32_t);
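+          // In this branch num_experts >= 256 > WARP_SIZE, so num_thread ==
+          // num_experts and (num_experts + 1) rows cover the (num_thread + 1)
+          // rows the kernel indexes.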
+          const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t);
+          // allocate global memory
+          int32_t* tokens_cnts;
+          int32_t* cumsum;
+          cudaMalloc(&tokens_cnts, mem_tokens_cnts);
+          cudaMalloc(&cumsum, mem_cumsum);
+
+          auto kernel =
+              vllm::moe::moe_align_block_size_global_mem_kernel<scalar_t>;
+          kernel<<<1, num_thread, 0, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel(), tokens_cnts, cumsum);
+          cudaFree(tokens_cnts);
+          cudaFree(cumsum);
+        });
+  } else {
+    VLLM_DISPATCH_INTEGRAL_TYPES(
+        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
+          // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
+          // tensors
+          const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
+          const int32_t shared_mem =
+              ((num_thread + 1) * num_experts + (num_experts + 1)) *
+              sizeof(int32_t);
+
+          // set dynamic shared mem
+          auto kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
+          AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
+              (void*)kernel, shared_mem));
+          kernel<<<1, num_thread, shared_mem, stream>>>(
+              topk_ids.data_ptr<scalar_t>(),
+              sorted_token_ids.data_ptr<int32_t>(),
+              experts_ids.data_ptr<int32_t>(),
+              num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
+              topk_ids.numel());
+        });
+  }
 }

 void moe_sum(torch::Tensor& input,  // [num_tokens, topk, hidden_size]
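
The TODO in the new host-side branch says the right fix is to compute the exact shared-memory requirement and decide on that, rather than on the hard-coded num_experts >= 256 threshold. A minimal sketch of such a check follows, assuming the same size formula as the existing shared-memory path and only the standard CUDA runtime attribute query; the helper name fits_in_shared_mem is hypothetical and WARP_SIZE is normally supplied by the surrounding vLLM headers.

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

#ifndef WARP_SIZE
#define WARP_SIZE 32  // assumption for a standalone build; vLLM defines this elsewhere
#endif

// Hypothetical helper: returns true if the shared-memory kernel's dynamic
// shared memory (tokens_cnts + cumsum, same formula as in the else branch
// above) fits within the device's opt-in per-block limit.
static bool fits_in_shared_mem(int32_t num_experts, int device_id) {
  const int32_t num_thread = std::max(num_experts, (int32_t)WARP_SIZE);
  const size_t shared_mem =
      ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
  int max_optin = 0;
  cudaDeviceGetAttribute(&max_optin, cudaDevAttrMaxSharedMemoryPerBlockOptin,
                         device_id);
  return shared_mem <= static_cast<size_t>(max_optin);
}

The dispatch could then branch on fits_in_shared_mem(num_experts, device) instead of the expert-count threshold, falling back to the global-memory kernel only when the request genuinely exceeds the device limit.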