Skip to content

Commit d0cff71

Browse files
committed
metal : use F16 precision in FA kernel
1 parent ed8de1e commit d0cff71

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

ggml/src/ggml-metal.m

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#define MIN(a, b) ((a) < (b) ? (a) : (b))
1313
#define MAX(a, b) ((a) > (b) ? (a) : (b))
1414

15+
// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
16+
#define GGML_METAL_FORCE_FATTN_PREC_F32
17+
1518
// max memory buffers that can be mapped to the device
1619
#define GGML_METAL_MAX_BUFFERS 64
1720

@@ -480,6 +483,11 @@ @implementation GGMLMetalClass
480483
// dictionary of preprocessor macros
481484
NSMutableDictionary * prep = [NSMutableDictionary dictionary];
482485

486+
// add GGML_METAL_FORCE_FATTN_PREC_F32
487+
#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
488+
[prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
489+
#endif
490+
483491
MTLCompileOptions* options = [MTLCompileOptions new];
484492
options.preprocessorMacros = prep;
485493

@@ -3145,11 +3153,19 @@ static void ggml_metal_encode_node(
31453153
GGML_ASSERT(nqptg % 8 == 0);
31463154
GGML_ASSERT(ncpsg % 32 == 0);
31473155

3156+
#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
3157+
const enum ggml_prec prec = GGML_PREC_DEFAULT;
3158+
#else
3159+
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
3160+
#endif
3161+
3162+
const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
3163+
31483164
// 16*32*(nsg)
31493165
// the shared memory needed for the simdgroups to load the KV cache
31503166
// each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
31513167
//
3152-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
3168+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
31533169

31543170
int64_t nsgmax = 2;
31553171

ggml/src/ggml-metal.metal

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2774,13 +2774,13 @@ kernel void kernel_flash_attn_ext(
27742774
const short NW = N_SIMDWIDTH;
27752775
const short SH = (C + Q); // shared memory per simdgroup in (half)
27762776

2777-
const short T = D + 2*nsg*SH; // shared memory size per query in (half)
2778-
const short TF = T/2; // shared memory size per query in (float)
2777+
const short T = D + nsg*SH; // shared memory size per query in (half)
2778+
const short TF = T; // row stride per query for the ss scratch buffer, in (half) now that ss is stored as half
27792779
const short T4 = T/4; // shared memory size per query in (half4)
27802780

2781-
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2782-
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2783-
threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
2781+
threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
2782+
threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
2783+
threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
27842784

27852785
threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
27862786
threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2809,7 +2809,7 @@ kernel void kernel_flash_attn_ext(
28092809
// zero out shared memory SH
28102810
for (short j = 0; j < Q; ++j) {
28112811
for (short i = tiisg; i < SH; i += NW) {
2812-
ss[j*TF + i] = 0.0f;
2812+
ss[j*TF + i] = 0.0h;
28132813
}
28142814
}
28152815

@@ -2874,7 +2874,7 @@ kernel void kernel_flash_attn_ext(
28742874
// Q*K^T
28752875
{
28762876
for (short cc = 0; cc < C/8; ++cc) {
2877-
simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
2877+
simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
28782878

28792879
// this is compile-time check, so it does not have runtime overhead
28802880
if (is_same<block_q, half4x4>::value) {
@@ -2946,7 +2946,7 @@ kernel void kernel_flash_attn_ext(
29462946
const float m = M[j];
29472947

29482948
// scale and apply the logitcap / mask
2949-
float s = ss[j*TF + tiisg]*scale;
2949+
float s = ((float)(ss[j*TF + tiisg]))*scale;
29502950

29512951
if (logit_softcap != 0.0f) {
29522952
s = logit_softcap*precise::tanh(s);
@@ -2982,7 +2982,7 @@ kernel void kernel_flash_attn_ext(
29822982

29832983
// O = diag(ms)*O
29842984
{
2985-
simdgroup_float8x8 mm;
2985+
simdgroup_half8x8 mm;
29862986
simdgroup_load(mm, ss + C, TF, 0, false);
29872987

29882988
for (short i = 0; i < D8; ++i) {
@@ -2993,7 +2993,7 @@ kernel void kernel_flash_attn_ext(
29932993
// O = O + (Q*K^T)*V
29942994
{
29952995
for (short cc = 0; cc < C/8; ++cc) {
2996-
simdgroup_float8x8 ms;
2996+
simdgroup_half8x8 ms;
29972997
simdgroup_load(ms, ss + 8*cc, TF, 0, false);
29982998

29992999
if (is_same<block_q, half4x4>::value) {
@@ -3106,8 +3106,8 @@ kernel void kernel_flash_attn_ext(
31063106
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
31073107
{
31083108
simdgroup_half8x8 t;
3109-
simdgroup_float8x8 ms0;
3110-
simdgroup_float8x8 ms1;
3109+
simdgroup_half8x8 ms0;
3110+
simdgroup_half8x8 ms1;
31113111

31123112
simdgroup_load(ms0, ss + C, TF, 0, false);
31133113
simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
@@ -3188,6 +3188,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_
31883188
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
31893189

31903190
// NOTE: can use half instead of float precision for some extra perf
3191+
// however, by default use F32 since the op should be mostly memory bandwidth bound
31913192
// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
31923193
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
31933194
kernel void kernel_flash_attn_ext_vec(

0 commit comments

Comments
 (0)