@@ -6529,7 +6529,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
     // pointer to the mask
-    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
+    const half  * mp = mask ? (const half  *) (mask + iq1*nb31)       : nullptr;
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
@@ -6555,12 +6555,9 @@ static __global__ void flash_attn_ext_f16(
 
                 // mqk = mqk*scale + mask
                 for (int64_t j = 0; j < Q16; ++j) {
-                    // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                    // int64_t msk_ne_row = nb31/sizeof(float);
                     for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                        // int msk_col = i % 16;
-                        // int msk_row = i / 16;
-                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
+                        // TODO: process mask
+                        mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
                     }
                     nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
                 }
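
The two hunks above switch the mask pointer to F16 and index it by the query row iq1, but the kernel still only scales the scores; adding the mask values is deferred ("TODO: process mask"). A minimal sketch of what that deferred step could look like, as a standalone device helper; the names apply_mask_row, s and n_kv are illustrative and not part of the patch:

    #include <cuda_fp16.h>

    // Hypothetical helper: fold one F16 mask row into already-computed scores,
    // i.e. s[ic] = s[ic]*scale + mask[ic]. mp points at the mask row for query
    // iq1 (nb31 bytes per row, as in the kernel) and may be nullptr when no
    // mask is used.
    __device__ void apply_mask_row(half * s, const half * mp, float scale, int n_kv) {
        const half hscale = __float2half(scale);
        for (int ic = 0; ic < n_kv; ++ic) {
            const half m = mp ? mp[ic] : __float2half(0.0f);
            s[ic] = __hadd(__hmul(hscale, s[ic]), m);
        }
    }

In the WMMA path above, the equivalent add would happen per fragment element (mqk[j].x[i]) before the store_matrix_sync call.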
@@ -9216,7 +9213,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
         src1_dfloat = src1_dfloat_a.alloc(ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                               ne00, 1, sizeof(float), 0, 0,
-                              ne00, 1, sizeof(half),  0, 0, stream);
+                              ne00, 1, sizeof(half),  0, 0, 0, 0, 0, 0, stream);
     }
 #else
     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
@@ -10891,19 +10888,18 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    if (mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F32);
-    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    if (mask) {
-        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
-    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
 
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
+            "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
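
The replacement assertions above drop the if (mask) blocks and tighten the requirements: when a mask is given it must be F16, live on the GPU, and its second dimension must cover the query count rounded up to a multiple of 8. GGML_PAD(x, n) rounds x up to the next multiple of n; the padding requirement in isolation, with an illustrative stand-in helper (pad_up is not ggml's macro):

    #include <stdint.h>

    // Illustrative stand-in for GGML_PAD(x, n): round x up to a multiple of n.
    static inline int64_t pad_up(int64_t x, int64_t n) {
        return ((x + n - 1) / n) * n;
    }

    // e.g. with 13 query rows, pad_up(13, 8) == 16, so mask->ne[1] must be >= 16.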
@@ -10925,7 +10921,6 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     dim3 block_dim(32, nwarps, 1);
 
     int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
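
The removed printf was a debug aid; the shared-memory sizing itself is unchanged. As a rough worked example (the concrete values of nqpb, ncpw and nwarps below are assumed, not taken from the patch): with head size Q->ne[0] = 128, nqpb = 8, ncpw = 32 and nwarps = 4, shmem = 8*(128 + 4*(32 + 8))*(sizeof(float)/2) = 8*288*2 = 4608 bytes, comfortably below the 48 KB static shared-memory limit per block.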
0 commit comments