@@ -3247,12 +3247,14 @@ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_
3247
3247
template [[host_name(" kernel_flash_attn_ext_f16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
3248
3248
template [[host_name(" kernel_flash_attn_ext_f16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
3249
3249
3250
+ #if !defined(GGML_METAL_NO_BFLOAT)
3250
3251
template [[host_name(" kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 64 >;
3251
3252
template [[host_name(" kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 80 >;
3252
3253
template [[host_name(" kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 96 >;
3253
3254
template [[host_name(" kernel_flash_attn_ext_bf16_h112" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 112 >;
3254
3255
template [[host_name(" kernel_flash_attn_ext_bf16_h128" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 >;
3255
3256
template [[host_name(" kernel_flash_attn_ext_bf16_h256" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 >;
3257
+ #endif
3256
3258
3257
3259
template [[host_name(" kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 64 >;
3258
3260
template [[host_name(" kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 80 >;
@@ -3630,15 +3632,19 @@ kernel void kernel_flash_attn_ext_vec(
3630
3632
typedef decltype (kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 64 >) flash_attn_ext_vec_t;
3631
3633
3632
3634
template [[host_name(" kernel_flash_attn_ext_vec_f16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 128 >;
3635
+ #if !defined(GGML_METAL_NO_BFLOAT)
3633
3636
template [[host_name(" kernel_flash_attn_ext_vec_bf16_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 128 >;
3637
+ #endif
3634
3638
template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 128 >;
3635
3639
template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 128 >;
3636
3640
template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 128 >;
3637
3641
template [[host_name(" kernel_flash_attn_ext_vec_q5_1_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2 , dequantize_q5_1, block_q5_1, 2 , dequantize_q5_1, 128 >;
3638
3642
template [[host_name(" kernel_flash_attn_ext_vec_q8_0_h128" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2 , dequantize_q8_0, block_q8_0, 2 , dequantize_q8_0, 128 >;
3639
3643
3640
3644
template [[host_name(" kernel_flash_attn_ext_vec_f16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1 , dequantize_f16, half4x4, 1 , dequantize_f16, 256 >;
3645
+ #if !defined(GGML_METAL_NO_BFLOAT)
3641
3646
template [[host_name(" kernel_flash_attn_ext_vec_bf16_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1 , dequantize_bf16, bfloat4x4, 1 , dequantize_bf16, 256 >;
3647
+ #endif
3642
3648
template [[host_name(" kernel_flash_attn_ext_vec_q4_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2 , dequantize_q4_0, block_q4_0, 2 , dequantize_q4_0, 256 >;
3643
3649
template [[host_name(" kernel_flash_attn_ext_vec_q4_1_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2 , dequantize_q4_1, block_q4_1, 2 , dequantize_q4_1, 256 >;
3644
3650
template [[host_name(" kernel_flash_attn_ext_vec_q5_0_h256" )]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2 , dequantize_q5_0, block_q5_0, 2 , dequantize_q5_0, 256 >;
0 commit comments