@@ -291,7 +291,7 @@ static __global__ void mul_mat_f_ids(
291291 extern __shared__ char data_mmv[];
292292 char * compute_base = data_mmv;
293293
294- const float2 * y2 = (const float2 *) y;
294+ // const float2 * y2 = (const float2 *) y;
295295
296296 tile_C C[ntA][ntB];
297297
@@ -311,13 +311,12 @@ static __global__ void mul_mat_f_ids(
311311 }
312312 }
313313
314- # pragma unroll
315- for ( int itB = 0 ; itB < ntB; ++itB) {
316- if constexpr (std::is_same_v<T , float > ) {
314+ if constexpr (std::is_same_v<T, float >) {
315+ float vals_buf[ 2 ][tile_B::I];
316+ auto gather_tile = [&]( int tile_idx_local , float *vals ) {
317317#pragma unroll
318318 for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
319- const int j = j0 + itB*tile_B::I;
320-
319+ const int j = j0 + tile_idx_local*tile_B::I;
321320 const int global_j = col_base + j;
322321 float val = 0 .0f ;
323322 if (j < cols_per_block && global_j < ncols_expert) {
@@ -329,13 +328,48 @@ static __global__ void mul_mat_f_ids(
329328 val = y[channel*stride_channel_y + token*stride_col_y + col];
330329 }
331330 }
332- tile_xy [j0*tile_k_padded + threadIdx . x ] = val;
331+ vals [j0] = val;
333332 }
334- } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
333+ };
334+
335+ if (ntB > 0 ) {
336+ gather_tile (0 , vals_buf[0 ]);
337+ }
338+
339+ int curr_buf = 0 ;
340+ int next_buf = 1 ;
341+ #pragma unroll
342+ for (int itB = 0 ; itB < ntB; ++itB) {
335343#pragma unroll
336344 for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
337- const int j = j0 + itB*tile_B::I;
345+ tile_xy[j0*tile_k_padded + threadIdx .x ] = vals_buf[curr_buf][j0];
346+ }
347+
348+ if (itB + 1 < ntB) {
349+ gather_tile (itB + 1 , vals_buf[next_buf]);
350+ }
338351
352+ #pragma unroll
353+ for (int k0 = 0 ; k0 < warp_size; k0 += tile_B::J) {
354+ tile_B B;
355+ load_ldmatrix (B, tile_xy + k0, tile_k_padded);
356+ #pragma unroll
357+ for (int itA = 0 ; itA < ntA; ++itA) {
358+ mma (C[itA][itB], A[itA][k0/tile_B::J], B);
359+ }
360+ }
361+
362+ if (itB + 1 < ntB) {
363+ curr_buf ^= 1 ;
364+ next_buf ^= 1 ;
365+ }
366+ }
367+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
368+ float2 vals_buf[2 ][tile_B::I];
369+ auto gather_tile = [&](int tile_idx_local, float2 *vals) {
370+ #pragma unroll
371+ for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
372+ const int j = j0 + tile_idx_local*tile_B::I;
339373 const int global_j = col_base + j;
340374 float2 tmp = make_float2 (0 .0f , 0 .0f );
341375 if (j < cols_per_block && global_j < ncols_expert) {
@@ -344,23 +378,48 @@ static __global__ void mul_mat_f_ids(
344378 const int token = (int ) qrm.x ;
345379 const int channel = (int ) qrm.y ;
346380 if (token < ncols_dst_total) {
347- tmp = y2 [channel*stride_channel_y/ 2 + token*stride_col_y + col];
381+ tmp = *( const float2 *) &y [channel*stride_channel_y + 2 *( token*stride_col_y + col) ];
348382 }
349383 }
350- tile_xy [j0*tile_k_padded + threadIdx . x ] = { tmp. x , tmp. y } ;
384+ vals [j0] = tmp;
351385 }
352- } else {
353- static_assert (std::is_same_v<T, void >, " unsupported type" );
386+ };
387+
388+ if (ntB > 0 ) {
389+ gather_tile (0 , vals_buf[0 ]);
354390 }
391+
392+ int curr_buf = 0 ;
393+ int next_buf = 1 ;
355394#pragma unroll
356- for (int k0 = 0 ; k0 < warp_size; k0 += tile_B::J) {
357- tile_B B;
358- load_ldmatrix (B, tile_xy + k0, tile_k_padded);
395+ for (int itB = 0 ; itB < ntB; ++itB) {
359396#pragma unroll
360- for (int itA = 0 ; itA < ntA; ++itA) {
361- mma (C[itA][itB], A[itA][k0/tile_B::J], B);
397+ for (int j0 = 0 ; j0 < tile_B::I; ++j0) {
398+ const float2 tmp = vals_buf[curr_buf][j0];
399+ tile_xy[j0*tile_k_padded + threadIdx .x ] = {tmp.x , tmp.y };
400+ }
401+
402+ if (itB + 1 < ntB) {
403+ gather_tile (itB + 1 , vals_buf[next_buf]);
404+ }
405+
406+ #pragma unroll
407+ for (int k0 = 0 ; k0 < warp_size; k0 += tile_B::J) {
408+ tile_B B;
409+ load_ldmatrix (B, tile_xy + k0, tile_k_padded);
410+ #pragma unroll
411+ for (int itA = 0 ; itA < ntA; ++itA) {
412+ mma (C[itA][itB], A[itA][k0/tile_B::J], B);
413+ }
414+ }
415+
416+ if (itB + 1 < ntB) {
417+ curr_buf ^= 1 ;
418+ next_buf ^= 1 ;
362419 }
363420 }
421+ } else {
422+ static_assert (std::is_same_v<T, void >, " unsupported type" );
364423 }
365424 }
366425
0 commit comments