 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
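+// GL_KHR_cooperative_matrix provides the coopmat type and the
+// coopMatLoad / coopMatMulAdd / coopMatStore builtins used below;
+// gl_ScopeSubgroup comes from GL_KHR_memory_scope_semantics, and float16_t
+// from GL_EXT_shader_explicit_arithmetic_types.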
+#extension GL_KHR_cooperative_matrix : require
+#extension GL_KHR_memory_scope_semantics : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
 
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
@@ -152,12 +155,10 @@ void main() {
     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
 #endif
 
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
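+    // One subgroup-scope 16x16 accumulator per 16x16 tile of this warp's
+    // WM x WN output block, replacing the per-invocation scalar sums and
+    // cache_a/cache_b arrays. This assumes WM and WN are multiples of 16.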
+    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sums[WM * WN / 16 / 16];
 
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
+    [[unroll]] for (uint i = 0; i < WM * WN / 16 / 16; i++) {
+        sums[i] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
     }
 
     [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
@@ -446,27 +447,14 @@ void main() {
         pos_a += BK / LOAD_VEC_A;
         pos_b += BK / LOAD_VEC_B;
 
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
-                        }
-                    }
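+        // Replace the scalar FMA loops with 16x16x16 cooperative matrix
+        // multiplies: one accumulator tile per 16x16 block of the warp's
+        // output, stepping the shared K dimension in chunks of 16 (so BK is
+        // assumed to be a multiple of 16).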
+        [[unroll]] for (uint i = 0; i < WM; i += 16) {
+            [[unroll]] for (uint j = 0; j < WN; j += 16) {
+                [[unroll]] for (uint k = 0; k < BK; k += 16) {
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
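+                    // coopMatLoad offsets/strides are in elements; both shared
+                    // tiles use a row stride of BK+1 (the +1 is the usual
+                    // padding to avoid bank conflicts). buf_b is laid out
+                    // [n][k], so the column-major load produces the K x N
+                    // operand that coopMatMulAdd expects.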
+                    coopMatLoad(matA, buf_a, (warp_r * WM + i) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutRowMajor);
+                    coopMatLoad(matB, buf_b, (warp_c * WN + j) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutColumnMajor);
+                    sums[(i / 16) * (WN / 16) + (j / 16)] = coopMatMulAdd(matA, matB, sums[(i / 16) * (WN / 16) + (j / 16)]);
                 }
             }
         }
@@ -481,6 +469,19 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
+#if 1
+#ifndef MUL_MAT_ID
+    // XXX TODO this is missing bounds checking against p.M and p.N,
+    // which probably requires spilling to shared memory and doing scalar stores.
+    // But sums[] may not all fit in shared memory...
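+    // Store each accumulator tile straight to global memory; the column-major
+    // layout matches the destination indexing data_d[col * stride_d + row].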
+    [[unroll]] for (uint i = 0; i < WM; i += 16) {
+        [[unroll]] for (uint j = 0; j < WN; j += 16) {
+            coopMatStore(sums[(i / 16) * (WN / 16) + (j / 16)], data_d, offsets + (dc + j) * p.stride_d + dr + i, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+        }
+    }
+#endif
+#else
+
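+    // The original scalar store path follows, kept (disabled) under the #else.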
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
@@ -505,4 +506,5 @@ void main() {
             }
         }
     }
+#endif
 }