Commit fd878f7

cuda: mask as fp16

1 parent 3df0b8d  commit fd878f7

File tree

1 file changed: ggml-cuda.cu (+9, -14)

ggml-cuda.cu

Lines changed: 9 additions & 14 deletions
@@ -6529,7 +6529,7 @@ static __global__ void flash_attn_ext_f16(
     const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1;
 
     // pointer to the mask
-    const float * mp = mask ? (const float *) (mask + (ir%ne31)*nb31) : nullptr;
+    const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
 
     // loop over the KV cache
     // each simdgroup handles blocks of Q rows and C columns
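
The hunk above reads the mask as fp16 and offsets it by the query row iq1 through the byte stride nb31 (previously it was indexed as fp32 via (ir%ne31)). A minimal sketch of the same addressing pattern as a standalone helper; the helper name and the assumption that mask is a byte pointer are illustrative, not part of the commit:

    #include <stddef.h>
    #include <stdint.h>
    #include <cuda_fp16.h>

    // Hedged sketch: fp16 mask pointer for query row `iq1`, given the mask as a
    // byte pointer and its row stride `nb31` in bytes, mirroring the hunk above.
    __device__ static const half * mask_row_ptr(const char * mask, size_t nb31, int64_t iq1) {
        return mask ? (const half *) (mask + iq1*nb31) : nullptr;
    }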
@@ -6555,12 +6555,9 @@ static __global__ void flash_attn_ext_f16(
 
             // mqk = mqk*scale + mask
             for (int64_t j = 0; j < Q16; ++j) {
-                // const float* msk_p = mp + 16*j*(nb31/sizeof(float)) + ic + 16*cc;
-                // int64_t msk_ne_row = nb31/sizeof(float);
                 for (uint32_t i = 0; i < mqk[j].num_elements; i++) {
-                    // int msk_col = i % 16;
-                    // int msk_row = i / 16;
-                    mqk[j].x[i] = __float2half(scale) * mqk[j].x[i]; // __half2float() + msk_p[msk_col + msk_row*msk_ne_row]);
+                    // TODO: process mask
+                    mqk[j].x[i] = __float2half(scale) * mqk[j].x[i];
                 }
                 nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
             }
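
The TODO above stands for the step from the comment, mqk = mqk*scale + mask; what is still missing is the mapping from WMMA fragment elements to mask rows and columns. For reference, a scalar sketch of that step on one 16x16 tile, written outside the WMMA fragments; all names (s, mask_row, mask_stride, kv_col0) are assumptions for illustration, not the commit's code:

    #include <cuda_fp16.h>

    // Hedged scalar reference for "mqk = mqk*scale + mask" on a row-major 16x16
    // score tile `s`.  `mask_row` points at the fp16 mask values for the first
    // query row of the tile, `mask_stride` is the mask row stride in elements and
    // `kv_col0` is the first KV column covered by the tile.
    __host__ __device__ static void scale_and_mask_tile(half s[16][16],
                                                        const half * mask_row,
                                                        int mask_stride,
                                                        int kv_col0,
                                                        float scale) {
        for (int r = 0; r < 16; ++r) {
            for (int c = 0; c < 16; ++c) {
                const float m = mask_row ? __half2float(mask_row[r*mask_stride + kv_col0 + c]) : 0.0f;
                s[r][c] = __float2half(scale*__half2float(s[r][c]) + m);
            }
        }
    }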
@@ -9216,7 +9213,7 @@ static void ggml_cuda_op_dequantize_mul_mat_vec(
         src1_dfloat = src1_dfloat_a.alloc(ne00);
         ggml_cpy_f32_f16_cuda((const char *) src1_ddf_i, (char *) src1_dfloat, ne00,
                               ne00, 1, sizeof(float), 0, 0,
-                              ne00, 1, sizeof(half), 0, 0, stream);
+                              ne00, 1, sizeof(half), 0, 0, 0, 0, 0, 0, stream);
     }
 #else
     const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
@@ -10891,19 +10888,18 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(K->type == GGML_TYPE_F16);
     GGML_ASSERT(V->type == GGML_TYPE_F16);
-    if(mask) {
-        GGML_ASSERT(mask->type == GGML_TYPE_F32);
-    }
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
     GGML_ASSERT(Q->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(K->backend == GGML_BACKEND_GPU);
     GGML_ASSERT(V->backend == GGML_BACKEND_GPU);
-    if(mask) {
-        GGML_ASSERT(mask->backend == GGML_BACKEND_GPU);
-    }
     GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
 
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 8) &&
+                "the Flash-Attention CUDA kernel requires the mask to be padded to 8 and at least n_queries big");
+
     ggml_cuda_set_device(g_main_device);
     const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
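
The three new asserts replace the old if(mask) blocks: the mask must be fp16, live on the GPU, and have at least GGML_PAD(Q->ne[1], 8) rows, i.e. the number of query rows rounded up to a multiple of 8. A small worked example of that requirement, assuming GGML_PAD rounds up to the next multiple of n (the pad_up helper is a stand-in for illustration, not the real macro):

    #include <assert.h>

    static int pad_up(int x, int n) {   // stand-in for GGML_PAD, assumed to round up
        return (x + n - 1) / n * n;
    }

    int main(void) {
        const int n_queries     = 13;                    // example Q->ne[1]
        const int min_mask_rows = pad_up(n_queries, 8);
        assert(min_mask_rows == 16);                     // 13 rounded up to a multiple of 8
        // the kernel then requires mask->ne[1] >= min_mask_rows
        return 0;
    }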

@@ -10925,7 +10921,6 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     dim3 block_dim(32, nwarps, 1);
 
     int shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
-    printf("shared memory: %d bytes [%i, %i, %i] scale = %f\n\n", shmem, Q->ne[0], Q->ne[1], Q->ne[2], scale);
     switch (Q->ne[0])
     {
         case 16:
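
The debug printf is dropped but the sizing formula stays: shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2), where sizeof(float)/2 = 2 bytes reflects fp16 storage. A worked example with assumed launch parameters (the nqpb, ncpw and nwarps values below are illustrative, not taken from the commit):

    #include <stdio.h>

    int main(void) {
        const int D      = 128;  // example head size, Q->ne[0]
        const int nqpb   = 16;   // queries per block (assumed)
        const int ncpw   = 32;   // KV cache values per warp (assumed)
        const int nwarps = 4;    // warps per block (assumed)

        // same formula as the hunk above; sizeof(float)/2 == 2 bytes per fp16 value
        const int shmem = nqpb*(D + nwarps*(ncpw + nqpb))*(int)(sizeof(float)/2);
        printf("shared memory: %d bytes\n", shmem);  // 16*(128 + 4*48)*2 = 10240
        return 0;
    }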
