
Commit 22a9311

ggml : add ggml_flash_attn_ext_get_prec

1 parent 5c333e0

3 files changed: 17 additions & 5 deletions

ggml/include/ggml.h

Lines changed: 3 additions & 0 deletions
@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
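
Together with the existing ggml_flash_attn_ext_set_prec, the new declaration gives callers a symmetric read path for the requested accumulation precision. A minimal usage sketch (assuming kqv is a tensor previously produced by ggml_flash_attn_ext; this example is not part of the commit):

    // request f32 accumulation on the flash-attention node
    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);

    // later, e.g. inside a backend, read the requested precision back
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(kqv);
    if (prec == GGML_PREC_DEFAULT) {
        // fast f16 accumulation is acceptable
    } else {
        // the caller asked for f32 accumulation
    }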

ggml/src/ggml-cuda/fattn.cu

Lines changed: 5 additions & 5 deletions
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
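
The change repeated across these hunks replaces a raw, untyped read of the op_params array with the new typed accessor. Side by side (a sketch restating the diff above, not additional code from the commit):

    // before: magic index, implicit int32_t use at every call site
    const int32_t precision = KQV->op_params[3];

    // after: one accessor owns the index and asserts the op is GGML_OP_FLASH_ATTN_EXT
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

This keeps the slot number (3) in a single place, so kernel-dispatch code no longer depends on the op_params layout.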

ggml/src/ggml.c

Lines changed: 9 additions & 0 deletions
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
     ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
+enum ggml_prec ggml_flash_attn_ext_get_prec(
+        const struct ggml_tensor * a) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+    return (enum ggml_prec) prec_i32;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
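
The getter reads slot 3 of op_params, matching the slot the setter writes and the layout noted in set_prec's comment (scale in slot 0, max_bias in slot 1). ggml's op_params helpers are thin reads and writes into the tensor's op_params storage; a sketch of what the i32 read amounts to (an assumption about the helper's body, shown for illustration only):

    // sketch: read op_params slot i, reinterpreted as int32_t values
    static int32_t get_op_params_i32_sketch(const struct ggml_tensor * t, uint32_t i) {
        return ((const int32_t *) t->op_params)[i];
    }

Since unset op_params slots are zero-initialized, a node whose precision was never set reports GGML_PREC_DEFAULT (0).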
