@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
13
13
const ggml_tensor * KQV = dst;
14
14
const ggml_tensor * Q = dst->src [0 ];
15
15
16
- const int32_t precision = KQV-> op_params [ 3 ] ;
16
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec ( KQV) ;
17
17
18
- if (precision != GGML_PREC_DEFAULT) {
18
+ if (prec != GGML_PREC_DEFAULT) {
19
19
if (Q->ne [1 ] <= 32 || Q->ne [0 ] > 128 ) {
20
20
constexpr int cols_per_block = 16 ;
21
21
switch (Q->ne [0 ]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
301
301
302
302
ggml_cuda_set_device (ctx.device );
303
303
const int cc = ggml_cuda_info ().devices [ggml_cuda_get_device ()].cc ;
304
- const int32_t precision = KQV-> op_params [ 3 ] ;
304
+ const enum ggml_prec prec = ggml_flash_attn_ext_get_prec ( KQV) ;
305
305
306
306
// On AMD the tile kernels perform poorly, use the vec kernel instead:
307
307
if (cc >= CC_OFFSET_AMD) {
308
- if (precision == GGML_PREC_DEFAULT && fast_fp16_available (cc)) {
308
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available (cc)) {
309
309
ggml_cuda_flash_attn_ext_vec_f16 (ctx, dst);
310
310
} else {
311
311
ggml_cuda_flash_attn_ext_vec_f32 (ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
332
332
}
333
333
334
334
if (Q->ne [1 ] == 1 && Q->ne [0 ] % (2 *WARP_SIZE) == 0 ) {
335
- if (precision == GGML_PREC_DEFAULT) {
335
+ if (prec == GGML_PREC_DEFAULT) {
336
336
ggml_cuda_flash_attn_ext_vec_f16 (ctx, dst);
337
337
return ;
338
338
} else if (Q->ne [0 ] <= 128 ) {
0 commit comments