Skip to content

Commit 41f5f3a

Browse files
fixup! CUDA: add FP32 FlashAttention vector kernel
1 parent bbeb952 commit 41f5f3a

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

ggml-cuda/fattn-vec-f16.cu

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f16(
1414
float * __restrict__ dst,
1515
float2 * __restrict__ dst_meta,
1616
const float scale,
17+
const float max_bias,
18+
const float m0,
19+
const float m1,
20+
const uint32_t n_head_log2,
1721
const int ne00,
1822
const int ne01,
1923
const int ne02,
@@ -49,6 +53,18 @@ static __global__ void flash_attn_vec_ext_f16(
4953
const int stride_KV = nb11 / sizeof(half);
5054
const int stride_KV2 = nb11 / sizeof(half2);
5155

56+
half slopeh = __float2half(1.0f);
57+
58+
// ALiBi
59+
if (max_bias > 0.0f) {
60+
const int h = blockIdx.y;
61+
62+
const float base = h < n_head_log2 ? m0 : m1;
63+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
64+
65+
slopeh = __float2half(powf(base, exph));
66+
}
67+
5268
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
5369
constexpr int nwarps = D / WARP_SIZE;
5470
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -132,7 +148,7 @@ static __global__ void flash_attn_vec_ext_f16(
132148
for (int j = 0; j < ncols; ++j) {
133149
sum2[j] = warp_reduce_sum(sum2[j]);
134150
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
135-
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
151+
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
136152

137153
if (ncols == 1) {
138154
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -244,8 +260,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
244260
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
245261
const int shmem = 0;
246262

247-
float scale;
248-
memcpy(&scale, KQV->op_params, sizeof(float));
263+
float scale = 1.0f;
264+
float max_bias = 0.0f;
265+
266+
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
267+
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
268+
269+
const uint32_t n_head = Q->ne[2];
270+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
271+
272+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
273+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
249274

250275
flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
251276
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -254,7 +279,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
254279
(const char *) V->data,
255280
mask ? ((const char *) mask->data) : nullptr,
256281
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
257-
scale,
282+
scale, max_bias, m0, m1, n_head_log2,
258283
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
259284
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
260285
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,

ggml-cuda/fattn-vec-f32.cu

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ static __global__ void flash_attn_vec_ext_f32(
1414
float * __restrict__ dst,
1515
float2 * __restrict__ dst_meta,
1616
const float scale,
17+
const float max_bias,
18+
const float m0,
19+
const float m1,
20+
const uint32_t n_head_log2,
1721
const int ne00,
1822
const int ne01,
1923
const int ne02,
@@ -48,6 +52,18 @@ static __global__ void flash_attn_vec_ext_f32(
4852
const int stride_KV = nb11 / sizeof(half);
4953
const int stride_KV2 = nb11 / sizeof(half2);
5054

55+
float slope = 1.0f;
56+
57+
// ALiBi
58+
if (max_bias > 0.0f) {
59+
const int h = blockIdx.y;
60+
61+
const float base = h < n_head_log2 ? m0 : m1;
62+
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
63+
64+
slope = powf(base, exph);
65+
}
66+
5167
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
5268
constexpr int nwarps = D / WARP_SIZE;
5369
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -127,7 +143,7 @@ static __global__ void flash_attn_vec_ext_f32(
127143
#pragma unroll
128144
for (int j = 0; j < ncols; ++j) {
129145
sum[j] = warp_reduce_sum(sum[j]);
130-
sum[j] += mask ? __half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
146+
sum[j] += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
131147

132148
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum[j]);
133149

@@ -230,8 +246,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
230246
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
231247
const int shmem = 0;
232248

233-
float scale;
234-
memcpy(&scale, KQV->op_params, sizeof(float));
249+
float scale = 1.0f;
250+
float max_bias = 0.0f;
251+
252+
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
253+
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
254+
255+
const uint32_t n_head = Q->ne[2];
256+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
257+
258+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
259+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
235260

236261
flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
237262
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -240,7 +265,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
240265
(const char *) V->data,
241266
mask ? ((const char *) mask->data) : nullptr,
242267
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
243-
scale,
268+
scale, max_bias, m0, m1, n_head_log2,
244269
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
245270
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
246271
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,

0 commit comments

Comments
 (0)