
Commit 3416010

hacky KHR cooperative matrix prototype
1 parent 60e17ce commit 3416010

File tree

1 file changed: +28 -26 lines

ggml/src/vulkan-shaders/mul_mm.comp

Lines changed: 28 additions & 26 deletions
@@ -2,6 +2,9 @@
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
+#extension GL_KHR_cooperative_matrix : require
+#extension GL_KHR_memory_scope_semantics : require
+#extension GL_EXT_shader_explicit_arithmetic_types : require
 
 #ifdef FLOAT16
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
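The new extensions back the cooperative matrix path below: GL_KHR_cooperative_matrix supplies the coopmat type and the coopMatLoad / coopMatStore / coopMatMulAdd builtins, GL_KHR_memory_scope_semantics is a dependency of that extension, and GL_EXT_shader_explicit_arithmetic_types makes float16_t available unconditionally (previously it was only enabled under FLOAT16).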
@@ -152,12 +155,10 @@ void main() {
     uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B;
 #endif
 
-    float sums[WMITER * TM * WNITER * TN];
-    FLOAT_TYPE cache_a[WMITER * TM];
-    FLOAT_TYPE cache_b[WNITER * TN];
+    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> sums[WM * WN / 16 / 16];
 
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
+    [[unroll]] for (uint i = 0; i < WM * WN / 16 / 16; i++) {
+        sums[i] = coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);
     }
 
     [[unroll]] for (uint block = start_k; block < end_k; block += BK) {
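Each warp now keeps its entire WM x WN output tile in 16x16 cooperative-matrix accumulators rather than in per-thread scalar sums. As a worked example under assumed tile sizes (WM and WN are shader parameters, not fixed by this commit): with WM = 64 and WN = 64, sums[] holds 64 * 64 / (16 * 16) = 16 coopmat tiles.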
@@ -446,27 +447,14 @@ void main() {
         pos_a += BK / LOAD_VEC_A;
         pos_b += BK / LOAD_VEC_B;
 
-        for (uint i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                            sums[sums_idx] = fma(float(cache_a[wsir * TM + cr]), float(cache_b[wsic * TN + cc]), sums[sums_idx]);
-                        }
-                    }
+        [[unroll]] for (uint i = 0; i < WM; i += 16) {
+            [[unroll]] for (uint j = 0; j < WN; j += 16) {
+                [[unroll]] for (uint k = 0; k < BK; k += 16) {
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
+                    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
+                    coopMatLoad(matA, buf_a, (warp_r * WM + i) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutRowMajor);
+                    coopMatLoad(matB, buf_b, (warp_c * WN + j) * (BK+1) + k, BK+1, gl_CooperativeMatrixLayoutColumnMajor);
+                    sums[(i / 16) * (WN / 16) + (j / 16)] = coopMatMulAdd(matA, matB, sums[(i / 16) * (WN / 16) + (j / 16)]);
                 }
             }
         }
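Each coopMatMulAdd above replaces a TM x TN nest of scalar fma calls with one 16x16x16 tile multiply. For readers unfamiliar with the API, here is a minimal standalone sketch of the same GL_KHR_cooperative_matrix calls, reduced to a single tile; the binding names and the fixed 16x16x16 shape are illustrative assumptions, not part of this commit, and real code should query the shapes the device supports via VK_KHR_cooperative_matrix:

#version 450

#extension GL_KHR_cooperative_matrix : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require

// One subgroup per workgroup; 32 is a common subgroup size but is device-dependent.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;

layout(std430, binding = 0) readonly  buffer BufA { float16_t a_data[]; };
layout(std430, binding = 1) readonly  buffer BufB { float16_t b_data[]; };
layout(std430, binding = 2) writeonly buffer BufD { float d_data[]; };

void main() {
    // Zero-initialized accumulator, like sums[] in the shader above.
    coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> acc =
        coopmat<float, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator>(0.0);

    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
    coopmat<float16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;

    // Load one 16x16 tile each: element offset 0, 16 elements between rows.
    coopMatLoad(matA, a_data, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
    coopMatLoad(matB, b_data, 0, 16, gl_CooperativeMatrixLayoutRowMajor);

    // acc = A * B + acc; how tile elements map to invocations is opaque to the shader.
    acc = coopMatMulAdd(matA, matB, acc);

    coopMatStore(acc, d_data, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
}

The per-lane layout of a coopmat is deliberately opaque: the shader only picks a memory layout at load/store time, which is what lets the driver map coopMatMulAdd onto tensor-core-style hardware.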
@@ -481,6 +469,19 @@ void main() {
     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
 #endif
 
+#if 1
+#ifndef MUL_MAT_ID
+    // XXX TODO this is missing bounds checking against p.M and p.N,
+    // which probably requires spilling to shared memory and doing scalar stores.
+    // But sums[] may not all fit in shared memory...
+    [[unroll]] for (uint i = 0; i < WM; i += 16) {
+        [[unroll]] for (uint j = 0; j < WN; j += 16) {
+            coopMatStore(sums[(i / 16) * (WN / 16) + (j / 16)], data_d, offsets + (dc + j) * p.stride_d + dr + i, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor);
+        }
+    }
+#endif
+#else
+
     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
 
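The store relies on the KHR layout rule that a column-major store writes tile element (r, c) to element + c * stride + r, so indexing data_d at offsets + (dc + j) * p.stride_d + dr + i with stride p.stride_d lands each output column p.stride_d elements apart, matching what the scalar path under #else produces. The XXX TODO comment flags the one known gap: full 16x16 tiles are always stored, with no clamping against p.M / p.N.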
@@ -505,4 +506,5 @@ void main() {
             }
         }
     }
+#endif
 }
