 
 #include <initializer_list>
 
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i]         = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
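For reference, here is a scalar sketch of what the warp routine above computes per row. It is illustrative only and not part of the patch: the warp_reduce_max/warp_reduce_sum reductions are replaced by plain loops, and softmax_ref is a name made up for this example. It produces a numerically stable softmax over the first `limit` entries (the use_limit == true path) and zeroes the remaining slots.

```cpp
// Scalar reference for softmax_warp_inplace: stable softmax over vals[0..limit),
// zeroing everything at or past `limit` (the use_limit == true path).
#include <algorithm>
#include <cmath>
#include <vector>

static void softmax_ref(std::vector<float> & vals, int limit) {
    float max_val = -INFINITY;
    for (int i = 0; i < limit; i++) {
        max_val = std::max(max_val, vals[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < limit; i++) {
        vals[i] = std::exp(vals[i] - max_val);  // subtract the max for numerical stability
        sum += vals[i];
    }
    for (int i = 0; i < (int) vals.size(); i++) {
        vals[i] = i < limit ? vals[i] / sum : 0.0f;
    }
}
```

In the warp version, each lane holds experts_per_thread strided entries (index lane + i * WARP_SIZE), so the per-row max and sum are combined across lanes with warp_reduce_max and warp_reduce_sum.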
 /*
     This kernel does the following:
-    1. softmax over the logits per token [n_experts, n_tokens]
+    1. optionally softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
-    4. optionally normalize the weights
+    4. optionally normalize the weights or apply softmax over the selected logits
 
     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
-template <int n_experts, bool with_norm>
+template <int n_experts, bool with_norm, bool delayed_softmax = false>
 __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                   float *       weights,
                                                                   int32_t *     ids,
@@ -30,51 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
 
     constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
-    float logits_r[experts_per_thread];
+    float wt[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < n_experts; i += WARP_SIZE) {
-        const int expert        = i + threadIdx.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
+        const int expert  = i + threadIdx.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
     }
 
-    float max_val = logits_r[0];
-
-#pragma unroll
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val         = max(val, max_val);
+    if constexpr (!delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
     }
 
-    max_val = warp_reduce_max(max_val);
-
-    float wt[experts_per_thread];
-    float tmp = 0.f;
-
-#pragma unroll
-    for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i]           = expf(val - max_val);
-        tmp += wt[i];
-    }
+    // at this point, each thread holds either a portion of the softmax distribution
+    // or the raw logits. We do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude it from the next iteration
 
-    tmp = warp_reduce_sum(tmp);
+    float wt_sum = 0.f;
 
-    const float inv_sum = 1.0f / tmp;
+    float output_weights[experts_per_thread];
 
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+        output_weights[i] = 0.f;
     }
 
-    // at this point, each thread holds a portion of softmax,
-    // we do the argmax reduce over n_expert_used, each time marking
-    // the expert weight as -inf to exclude from the next iteration
-
-    float wt_sum = 0.f;
-
-    float output_weights[experts_per_thread];
-
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         int   max_expert = threadIdx.x;
@@ -121,6 +146,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
         }
     }
 
+    if constexpr (delayed_softmax) {
+        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
+    }
+
 #pragma unroll
     for (int i = 0; i < experts_per_thread; i++) {
         const int idx = i * WARP_SIZE + threadIdx.x;
@@ -130,58 +159,60 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
     }
 }
 
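To make the two modes of the kernel above easier to follow, here is a hedged host-side sketch of what a single row goes through: optional softmax over all logits, n_expert_used rounds of argmax with -INFINITY masking, then either renormalization of the selected weights (with_norm; part of that device code falls outside the hunks shown) or a softmax restricted to the selected logits (delayed_softmax). The function name topk_moe_ref and its signature are invented for this illustration; this is not the CUDA implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Host-side reference of the fused pipeline for a single row of logits.
static void topk_moe_ref(std::vector<float>     wt,        // copy of one row of logits, modified locally
                         std::vector<float> &   weights,   // out: n_expert_used selected weights
                         std::vector<int32_t> & ids,       // out: n_expert_used selected expert ids
                         int                    n_expert_used,
                         bool                   with_norm,
                         bool                   delayed_softmax) {
    auto softmax = [](std::vector<float> & v, int limit) {
        float max_val = -INFINITY;
        for (int i = 0; i < limit; i++) {
            max_val = std::max(max_val, v[i]);
        }
        float sum = 0.0f;
        for (int i = 0; i < limit; i++) {
            v[i] = std::exp(v[i] - max_val);
            sum += v[i];
        }
        for (int i = 0; i < limit; i++) {
            v[i] /= sum;
        }
    };

    if (!delayed_softmax) {
        softmax(wt, (int) wt.size());         // step 1: softmax over all logits
    }

    weights.assign(n_expert_used, 0.0f);
    ids.assign(n_expert_used, 0);

    float wt_sum = 0.0f;
    for (int k = 0; k < n_expert_used; k++) {
        // step 2: argmax over the experts that are still available
        int max_expert = 0;
        for (int e = 1; e < (int) wt.size(); e++) {
            if (wt[e] > wt[max_expert]) {
                max_expert = e;
            }
        }
        weights[k]     = wt[max_expert];      // step 3: record weight + id
        ids[k]         = max_expert;
        wt_sum        += wt[max_expert];
        wt[max_expert] = -INFINITY;           // exclude this expert from the next iteration
    }

    if (with_norm) {
        for (float & w : weights) {
            w /= wt_sum;                      // step 4a: renormalize the selected weights
        }
    } else if (delayed_softmax) {
        softmax(weights, n_expert_used);      // step 4b: softmax over the selected logits only
    }
}
```

Selecting on raw logits and delaying the softmax yields the same top-k set as selecting on softmaxed values, since softmax is monotonic; the delayed path simply avoids normalizing over experts that are never used.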
-template <bool with_norm>
+template <bool with_norm, bool delayed_softmax = false>
 static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                  const float *               logits,
                                  float *                     weights,
                                  int32_t *                   ids,
                                  const int                   n_rows,
                                  const int                   n_expert,
                                  const int                   n_expert_used) {
+    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");
+
     const int    rows_per_block = 4;
     dim3         grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
     dim3         block_dims(WARP_SIZE, rows_per_block, 1);
     cudaStream_t stream = ctx.stream();
 
     switch (n_expert) {
         case 1:
-            topk_moe_cuda<1, with_norm>
+            topk_moe_cuda<1, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 2:
-            topk_moe_cuda<2, with_norm>
+            topk_moe_cuda<2, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 4:
-            topk_moe_cuda<4, with_norm>
+            topk_moe_cuda<4, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 8:
-            topk_moe_cuda<8, with_norm>
+            topk_moe_cuda<8, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 16:
-            topk_moe_cuda<16, with_norm>
+            topk_moe_cuda<16, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 32:
-            topk_moe_cuda<32, with_norm>
+            topk_moe_cuda<32, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 64:
-            topk_moe_cuda<64, with_norm>
+            topk_moe_cuda<64, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 128:
-            topk_moe_cuda<128, with_norm>
+            topk_moe_cuda<128, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 256:
-            topk_moe_cuda<256, with_norm>
+            topk_moe_cuda<256, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         case 512:
-            topk_moe_cuda<512, with_norm>
+            topk_moe_cuda<512, with_norm, delayed_softmax>
                 <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
             break;
         default:
@@ -194,15 +225,16 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                            const ggml_tensor *         logits,
                            ggml_tensor *               weights,
                            ggml_tensor *               ids,
-                           const bool                  with_norm) {
+                           const bool                  with_norm,
+                           const bool                  delayed_softmax) {
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
     const int n_experts = logits->ne[0];
     const int n_rows    = logits->ne[1];
 
-    const float * logits_d  = (const float *) logits->src[0]->data;
+    const float * logits_d  = (const float *) logits->data;
     float *       weights_d = (float *) weights->data;
     int32_t *     ids_d     = (int32_t *) ids->data;
 
@@ -213,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
     if (with_norm) {
         launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
     } else {
-        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        if (delayed_softmax) {
+            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        } else {
+            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
+        }
     }
 }
 
@@ -246,16 +282,27 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
     return true;
 }
 
-std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
+std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
     static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                             GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
 
     static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                                GGML_OP_VIEW, GGML_OP_GET_ROWS };
 
+    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT,  GGML_OP_VIEW,
+                                                                       GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                       GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+    GGML_ASSERT(!norm || !delayed_softmax);
+
+    if (delayed_softmax) {
+        return delayed_softmax_ops;
+    }
+
     if (norm) {
         return norm_ops;
     }
+
     return no_norm_ops;
 }
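As a usage note, here is a minimal sketch of inspecting the op list returned for a given configuration. It assumes ggml.h plus the header that declares ggml_cuda_topk_moe_ops are visible; the helper name print_topk_moe_ops is made up for this example, while ggml_op_name is the existing ggml helper for op names.

```cpp
#include <cstdio>

#include "ggml.h"
// plus whichever header declares ggml_cuda_topk_moe_ops in this backend

// Print the ggml op sequence that the fused top-k MoE kernel is expected to replace.
static void print_topk_moe_ops(bool norm, bool delayed_softmax) {
    for (enum ggml_op op : ggml_cuda_topk_moe_ops(norm, delayed_softmax)) {
        std::printf("%s ", ggml_op_name(op));
    }
    std::printf("\n");
}
```

Since the GGML_ASSERT above rejects norm together with delayed_softmax, callers end up with exactly one of the three lists.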