@@ -57,31 +57,37 @@ static __global__ void mul_mat_f(
5757 T * tile_xy = (T *) compute_base + threadIdx .y *(tile_A::I * tile_k_padded);
5858
5959 if constexpr (has_ids) {
60- __shared__ int has_any;
61- if (threadIdx .y == 0 ) {
62- int local_has_any = 0 ;
63- for (int j = threadIdx .x ; j < cols_per_block; j += warp_size) {
64- int slot = -1 ;
65- for (int k = 0 ; k < nchannels_dst; ++k) {
66- const int idv = ids[j*stride_row_id + k*stride_col_id];
67- if (idv == expert_idx) {
68- slot = k;
69- break ;
60+ int found = 0 ;
61+
62+ for (int j = threadIdx .y ; j < cols_per_block; j += nwarps) {
63+ const int32_t * __restrict__ id_row = ids + j*stride_row_id;
64+
65+ if (threadIdx .x == 0 ) {
66+ slot_map[j] = -1 ;
67+ }
68+
69+ for (int k_base = 0 ; k_base < nchannels_dst; k_base += warp_size) {
70+ int k = k_base + threadIdx .x ;
71+ int match = (k < nchannels_dst) && (id_row[k*stride_col_id] == expert_idx);
72+
73+ unsigned mask = __ballot_sync (0xffffffff , match);
74+ if (mask) {
75+ int leader = __ffs (mask) - 1 ;
76+ if (threadIdx .x == leader) {
77+ slot_map[j] = k_base + leader;
7078 }
71- }
72- if (j < cols_per_block) {
73- local_has_any |= (slot >= 0 );
74- slot_map[j] = slot;
79+ found = 1 ;
80+ break ;
7581 }
7682 }
77- has_any = warp_reduce_any (local_has_any);
7883 }
79- __syncthreads ();
80- if (has_any == 0 ) {
84+
85+ if (! __syncthreads_or (found) ) {
8186 return ;
8287 }
8388 }
8489
90+
8591 for (int col = threadIdx .y *warp_size + threadIdx .x ; col < ncols; col += nwarps*warp_size) {
8692 tile_A A[ntA][warp_size / tile_A::J];
8793#pragma unroll
@@ -106,14 +112,7 @@ static __global__ void mul_mat_f(
106112 if constexpr (!has_ids) {
107113 tile_xy[j0*tile_k_padded + threadIdx .x ] = j < cols_per_block ? y[j*stride_col_y + col] : 0 .0f ;
108114 } else {
109- float val = 0 .0f ;
110- if (j < cols_per_block) {
111- const int slot = slot_map[j];
112- if (slot >= 0 ) {
113- val = y[slot*stride_channel_y + j*stride_col_y + col];
114- }
115- }
116- tile_xy[j0*tile_k_padded + threadIdx .x ] = val;
115+ tile_xy[j0*tile_k_padded + threadIdx .x ] = j < cols_per_block ? y[slot_map[j]*stride_channel_y + j*stride_col_y + col] : 0 .0f ;
117116 }
118117 }
119118 } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
@@ -125,14 +124,7 @@ static __global__ void mul_mat_f(
125124 const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2 (0 .0f , 0 .0f );
126125 tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
127126 } else {
128- float2 tmp = make_float2 (0 .0f , 0 .0f );
129- if (j < cols_per_block) {
130- const int slot = slot_map[j];
131- if (slot >= 0 ) {
132- const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
133- tmp = y2_slot[j*stride_col_y + col];
134- }
135- }
127+ float2 tmp = j < cols_per_block && slot_map[j] >= 0 ? *(const float2 *) &y[slot_map[j]*stride_channel_y + 2 *(j*stride_col_y + col)] : make_float2 (0 .0f , 0 .0f );
136128 tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
137129 }
138130 }
0 commit comments