@@ -315,7 +315,7 @@ void main() {
315315#else
316316 ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
317317 FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
318- FLOAT_TYPE_VEC2 cache_b[TN] ;
318+ FLOAT_TYPE_VEC2 cache_b;
319319
320320 [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
321321 sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
@@ -360,21 +360,40 @@ void main() {
360360 cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
361361 }
362362 }
363- [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
364- [[unroll]] for (uint j = 0; j < TN; j++) {
365- cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
366- }
363+ // [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
364+ // [[unroll]] for (uint j = 0; j < TN; j++) {
365+ // cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
366+ // }
367+
368+ // [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369+ // [[unroll]] for (uint cc = 0; cc < TN; cc++) {
370+ // [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
371+ // const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * WMITER + wsir;
372+ // sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].x));
373+ // sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].y));
374+ // }
375+ // }
376+ // }
377+ // }
367378
368- [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
369- [[unroll]] for (uint cc = 0; cc < TN; cc++) {
379+ [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
380+ [[unroll]] for (uint cc = 0; cc < TN; cc++) {
381+ cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
382+
383+ [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
370384 [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
371- const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * WMITER + wsir;
372- sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].x));
373- sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx].y));
385+ // [TM / 2][WNITER][TN][WMITER]
386+ // const uint sums_idx = (cr * WNITER + wsic) * (WMITER * TN) + cc * WMITER + wsir;
387+
388+ // sums layout: [WNITER][TN][WMITER][TM / 2], indexed as [wsic][cc][wsir][cr]
389+ const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
390+ sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
391+ sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
374392 }
375393 }
376394 }
377395 }
396+
378397 }
379398#endif
380399
@@ -466,7 +485,7 @@ void main() {
466485 const u16vec2 row_idx = row_ids[row_i - ic * BN];
467486#endif // MUL_MAT_ID
468487 [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
469- const uint sums_idx = (cr * WNITER + wsic ) * ( WMITER * TN ) + cc * WMITER + wsir ;
488+ const uint sums_idx = (wsic * TN + cc ) * WMITER * (TM / 2 ) + wsir * (TM / 2) + cr ;
470489#ifdef MUL_MAT_ID
471490 if (dr_warp + 2 * cr < p.M) {
472491 data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
0 commit comments