Commit ba4d12a

ggml : ggml_flash_attn_ext() support ALiBi (Metal)

1 parent: 166e60b

4 files changed: +105 -47 lines
ggml-metal.m

Lines changed: 44 additions & 32 deletions
@@ -1390,8 +1390,8 @@ static enum ggml_status ggml_metal_graph_compute(
                 const int64_t nrows_x = ggml_nrows(src0);
                 const int64_t nrows_y = src0->ne[1];

-                const uint32_t n_head_kv = nrows_x/nrows_y;
-                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+                const uint32_t n_head = nrows_x/nrows_y;
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
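
Note: m0, m1 and n_head_log2 implement the standard ALiBi head slopes: head h gets slope m0^(h+1) when h < n_head_log2, and m1^(2*(h - n_head_log2) + 1) otherwise. A minimal standalone sketch of the same scheme (the helper name alibi_slope is illustrative, not part of ggml):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // per-head ALiBi slope, mirroring the m0/m1/n_head_log2 scheme above
    static float alibi_slope(int h, int n_head, float max_bias) {
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

        return h < (int) n_head_log2 ? powf(m0, h + 1)
                                     : powf(m1, 2*(h - (int) n_head_log2) + 1);
    }

    int main() {
        // with 8 heads and max_bias = 8.0f this prints the classic
        // ALiBi slopes: 1/2, 1/4, ..., 1/256
        for (int h = 0; h < 8; ++h) {
            printf("head %d: slope %f\n", h, alibi_slope(h, 8, 8.0f));
        }
    }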
@@ -2513,7 +2513,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");

                 const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                const int64_t ne31 = src3 ? src3->ne[1] : 0;
+              //const int64_t ne31 = src3 ? src3->ne[1] : 0;
                 const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                 const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);

@@ -2525,7 +2525,16 @@ static enum ggml_status ggml_metal_graph_compute(
                 const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);

                 float scale;
-                memcpy(&scale, dst->op_params, sizeof(float));
+                float max_bias;
+
+                memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                const uint32_t n_head      = src0->ne[2];
+                const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

                 id<MTLComputePipelineState> pipeline = nil;

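Note: the two memcpy calls above read scale and max_bias back out of dst->op_params; on the graph-building side ggml packs them the same way. A sketch of that producer side (not verbatim ggml.c, shown only to document the layout):

    // scale and max_bias are stored back-to-back at the start of op_params
    float params[2] = { scale, max_bias };
    memcpy(result->op_params, params, sizeof(params));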
@@ -2562,34 +2571,37 @@ static enum ggml_status ggml_metal_graph_compute(
             }

             [encoder setComputePipelineState:pipeline];
-            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-            [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-            [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
-            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
-            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
-            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
-            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
-            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
-            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
-            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
-            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
-            [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
-            [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
-            [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
-            [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
-            [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
-            [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
-            [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
-            [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
-            [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
-            [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
-            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23];
-            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24];
-            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25];
-            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
-            [encoder setBytes:&scale length:sizeof( float) atIndex:27];
+            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+            [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+            [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+            [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+            [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+            [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+            [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+            [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+            [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+            [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+            [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+            [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+            [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+            [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+            [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+            [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+            [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+            [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+            [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:21];
+            [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:22];
+            [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:23];
+            [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:24];
+            [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:25];
+            [encoder setBytes:&scale length:sizeof( float) atIndex:26];
+            [encoder setBytes:&max_bias length:sizeof( float) atIndex:27];
+            [encoder setBytes:&m0 length:sizeof(m0) atIndex:28];
+            [encoder setBytes:&m1 length:sizeof(m1) atIndex:29];
+            [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];

             if (!use_vec_kernel) {
                 // half8x8 kernel

ggml-metal.metal

Lines changed: 43 additions & 7 deletions
@@ -2058,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
         constant uint64_t & nb11,
         constant uint64_t & nb12,
         constant uint64_t & nb13,
-        constant int64_t & ne31,
         constant uint64_t & nb31,
         constant int64_t & ne0,
         constant int64_t & ne1,
         constant int64_t & ne2,
         constant int64_t & ne3,
         constant float & scale,
+        constant float & max_bias,
+        constant float & m0,
+        constant float & m1,
+        constant uint32_t & n_head_log2,
         threadgroup half * shared,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2096,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
         constant uint64_t & nb11,
         constant uint64_t & nb12,
         constant uint64_t & nb13,
-        constant int64_t & ne31,
         constant uint64_t & nb31,
         constant int64_t & ne0,
         constant int64_t & ne1,
         constant int64_t & ne2,
         constant int64_t & ne3,
         constant float & scale,
+        constant float & max_bias,
+        constant float & m0,
+        constant float & m1,
+        constant uint32_t & n_head_log2,
         threadgroup half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2199,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
         // prepare diagonal scale matrix
         simdgroup_float8x8 mscale(scale);

+        // prepare diagonal slope matrix
+        simdgroup_float8x8 mslope(1.0f);
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const int64_t h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            mslope = simdgroup_float8x8(pow(base, exph));
+        }
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
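
Note: both mscale and mslope are diagonal matrices (scale and slope on the diagonal), so together with the mask multiply added in the next hunk the update amounts to mqk = mqk*scale + mask*slope, element-wise. A scalar sketch of that 8x8 tile update, for reference only (not code from this commit):

    // reference-only sketch of what the simdgroup ops compute on one tile:
    // mqk = mqk*mscale + mslope*mask, with mscale = scale*I, mslope = slope*I
    static void apply_scale_and_slope(float mqk[8][8], const float mask[8][8],
                                      float scale, float slope) {
        for (int r = 0; r < 8; ++r) {
            for (int c = 0; c < 8; ++c) {
                mqk[r][c] = mqk[r][c]*scale + mask[r][c]*slope;
            }
        }
    }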
@@ -2221,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
                     simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                 }

-                // mqk = mqk*scale + mask
+                // mqk = mqk*scale + mask*slope
                 simdgroup_half8x8 mm;
                 simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                simdgroup_multiply(mm, mslope, mm);
                 simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);

                 simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@@ -2414,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
         constant uint64_t & nb11,
         constant uint64_t & nb12,
         constant uint64_t & nb13,
-        constant int64_t & ne31,
         constant uint64_t & nb31,
         constant int64_t & ne0,
         constant int64_t & ne1,
         constant int64_t & ne2,
         constant int64_t & ne3,
         constant float & scale,
+        constant float & max_bias,
+        constant float & m0,
+        constant float & m1,
+        constant uint32_t & n_head_log2,
         threadgroup half * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2439,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(

     const short T = D + 2*nsg*SH; // shared memory size per query in (half)

+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int64_t h = iq2;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
     //threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data
     threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4
     threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2545,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
             mqk += simd_shuffle_down(mqk, 2);
             mqk += simd_shuffle_down(mqk, 1);

-            // mqk = mqk*scale + mask
+            // mqk = mqk*scale + mask*slope
             if (tiisg == 0) {
                 float4 mm = (float4) mp4[ic/4 + cc];
-                mqk = mqk*scale + mm;
+                mqk = mqk*scale + mm*slope;

                 ss4[cc] = mqk;
             }
@@ -2782,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);

-        dst_data[i00] = src[0];
+        // TODO: is there a better way to handle -INFINITY?
+        dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
     }
 }
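
Note: this clamp is what makes the -INFINITY mask padding added in llama.cpp below safe. Once the mask is multiplied by the diagonal slope matrix, off-diagonal products of the form 0*(-inf) would evaluate to NaN, whereas -MAXHALF (65504, the largest finite half) stays finite while still driving masked logits to effectively zero after softmax. A host-side sketch of the same mapping, with HALF_MAX standing in for Metal's MAXHALF:

    #include <cmath>

    // sketch of the clamp in kernel_cpy_f32_f16 above; HALF_MAX stands in
    // for Metal's MAXHALF (65504.0f, the largest finite half value)
    constexpr float HALF_MAX = 65504.0f;

    inline float sanitize_mask_value(float x) {
        // a finite "very negative" value behaves like -inf in the softmax,
        // but 0*x stays 0 instead of becoming NaN
        return x == -INFINITY ? -HALF_MAX : x;
    }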

llama.cpp

Lines changed: 6 additions & 0 deletions
@@ -10985,6 +10985,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
                 }
+
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    }
+                }
             }
         } else {
             // when using kv cache, the mask needs to match the kv cache size
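
Note: GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) rounds n_tokens up to the next multiple of the pad size, so the loop fills only the extra padding rows with -INFINITY. A sketch of the rounding semantics (the pad value 32 below is illustrative; see ggml.h and llama.cpp for the actual macro and constant):

    #include <cstdint>

    // rounding-up semantics of GGML_PAD(x, n) (sketch; ggml.h defines a macro)
    static int64_t pad_to_multiple(int64_t x, int64_t n) {
        return ((x + n - 1) / n) * n;
    }

    // e.g. assuming a pad of 32: pad_to_multiple(5, 32) == 32,
    // pad_to_multiple(33, 32) == 64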

tests/test-backend-ops.cpp

Lines changed: 12 additions & 8 deletions
@@ -1486,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size

+    const float max_bias; // ALiBi
+
     std::string vars() override {
-        return VARS_TO_STR4(hs, nh, kv, nb);
+        return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
     }

     double max_nmse_err() override {
         return 5e-4;
     }

-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}

     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -2176,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #else
     for (int hs : { 64, 80, 128, 256, }) {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        for (int nh : { 32, }) {
-            for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, }) {
-                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+        for (float max_bias : {0.0f, 8.0f}) {
+            for (int nh : { 32, }) {
+                for (int kv : { 512, 1024, }) {
+                    for (int nb : { 1, 2, 4, 8, }) {
+                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
+                    }
                 }
             }
         }
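
With the extra constructor parameter, a single hand-picked ALiBi case can also be added directly (values here chosen for illustration, not from this commit):

    // exercise the ALiBi path on its own:
    // hs=128, nh=32, kv=512, nb=8, max_bias=8.0f
    test_cases.emplace_back(new test_flash_attn_ext(128, 32, 512, 8, 8.0f));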
