Commit 9cb317f
ggml : full ALiBi support (#7192)

* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

  ggml-ci

* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi

  ggml-ci

* convert : fix convert for refact models
1 parent e849648 commit 9cb317f

16 files changed: +350 −825 lines
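
Note on the ALiBi slope math used throughout this commit: each attention head h gets a slope built from two geometric series, m0 = 2^(-max_bias / n_head_log2) for the first n_head_log2 heads and m1 = 2^(-(max_bias/2) / n_head_log2) for the rest, where n_head_log2 is the largest power of two not exceeding n_head. The standalone sketch below mirrors the arithmetic that the ggml-cuda/fattn.cu hunks inline; the helper name alibi_slope() and the example values in main() are illustrative and not part of the patch.

```cpp
// Standalone sketch of the per-head ALiBi slope computed in this commit.
// alibi_slope() is an illustrative helper; the kernels below inline the
// same arithmetic via the precomputed constants m0, m1 and n_head_log2.
#include <math.h>
#include <stdint.h>
#include <stdio.h>

static float alibi_slope(uint32_t h, uint32_t n_head, float max_bias) {
    // largest power of two not exceeding n_head
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // heads below n_head_log2 use base m0; the rest use base m1 with odd exponents
    return h < n_head_log2 ? powf(m0, h + 1)
                           : powf(m1, 2*(h - n_head_log2) + 1);
}

int main(void) {
    const uint32_t n_head   = 12;   // example values, not taken from the patch
    const float    max_bias = 8.0f;
    for (uint32_t h = 0; h < n_head; ++h) {
        printf("head %2u: slope %.6f\n", (unsigned) h, alibi_slope(h, n_head, max_bias));
    }
    return 0;
}
```

For head counts that are not a power of two, the second series interleaves extra slopes between the first ones, which is the scheme from the original ALiBi paper.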

convert-hf-to-gguf.py

Lines changed: 12 additions & 0 deletions

@@ -1013,6 +1013,18 @@ def set_gguf_parameters(self):
 class RefactModel(Model):
     model_arch = gguf.MODEL_ARCH.REFACT
 
+    def set_vocab(self):
+        super().set_vocab()
+
+        # TODO: how to determine special FIM tokens automatically?
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
+        special_vocab._set_special_token("prefix", 1)
+        special_vocab._set_special_token("suffix", 3)
+        special_vocab._set_special_token("middle", 2)
+        special_vocab._set_special_token("fsep",   4)  # is this correct?
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         hidden_dim = self.hparams["n_embd"]
         inner_dim = 4 * hidden_dim

ggml-cuda.cu

Lines changed: 0 additions & 5 deletions

@@ -4,7 +4,6 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
@@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
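
With GGML_OP_ALIBI removed from the CUDA dispatch above, the bias no longer has a dedicated op: it travels through the mask, and each head's scores become scale*KQ + slope(h)*mask before the softmax (see the ggml-cuda/fattn.cu hunks below and the "require mask when using ALiBi" item in the commit message). The minimal CPU sketch of that per-row semantics below assumes the mask already holds the position-dependent values that the slope scales; soft_max_row_ref() is an illustrative helper, not ggml API.

```cpp
// CPU reference for the masked, slope-scaled softmax the backends now implement:
// y = softmax(scale*x + slope*mask) over one row of KQ scores.
#include <math.h>
#include <stdio.h>

static void soft_max_row_ref(const float * x, const float * mask, float * y,
                             int n, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < n; ++i) {
        y[i] = scale*x[i] + (mask ? slope*mask[i] : 0.0f);
        max_val = fmaxf(max_val, y[i]);
    }
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        y[i] = expf(y[i] - max_val);
        sum += y[i];
    }
    for (int i = 0; i < n; ++i) {
        y[i] /= sum;
    }
}

int main(void) {
    // toy row: 4 KQ scores, mask with position-dependent terms and -inf on a masked slot
    const float x[4]    = {0.1f, 0.4f, 0.2f, 0.9f};
    const float mask[4] = {0.0f, -1.0f, -2.0f, -INFINITY};
    float y[4];
    soft_max_row_ref(x, mask, y, 4, 1.0f, 0.5f);
    for (int i = 0; i < 4; ++i) {
        printf("%f\n", y[i]);
    }
    return 0;
}
```

When max_bias is 0 the slope stays at 1 and this degrades to the ordinary masked softmax, which is why the kernels below initialize slopeh to 1 and only enter the ALiBi branch for max_bias > 0.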

ggml-cuda/alibi.cu

Lines changed: 0 additions & 63 deletions
This file was deleted.

ggml-cuda/alibi.cuh

Lines changed: 0 additions & 5 deletions
This file was deleted.

ggml-cuda/fattn.cu

Lines changed: 62 additions & 10 deletions

@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
         for (int j = 0; j < ncols; ++j) {
             sum2[j] = warp_reduce_sum(sum2[j]);
             half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-            sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+            sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
             if (ncols == 1) {
                 kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float * __restrict__ dst,
         float2 * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                 const int k = k0 + threadIdx.x;
 
-                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                 KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                const int k = k0 + threadIdx.x;
 
-                KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                 KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
             }
             KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
             (const char *) V->data,
             mask ? ((const char *) mask->data) : nullptr,
             parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
            (const char *) V->data,
            mask ? ((const char *) mask->data) : nullptr,
            (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-            scale,
+            scale, max_bias, m0, m1, n_head_log2,
             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
         GGML_ASSERT(precision == GGML_PREC_DEFAULT);
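
The last hunk reflects the op_params layout implied by this change: slot 0 holds scale and slot 1 holds max_bias (both stored as float bit patterns and read back with memcpy), so the precision flag moves from op_params[1] to op_params[2]. Below is a small sketch of that packing convention over a plain int32 array; set_f32()/get_f32() and the array size are illustrative helpers, not ggml API.

```cpp
// Sketch of the flash-attention op_params layout suggested by this commit:
// params[0] = scale (float bits), params[1] = max_bias (float bits),
// params[2] = precision (int32).
#include <stdint.h>
#include <string.h>
#include <stdio.h>

static void set_f32(int32_t * params, int i, float v) {
    memcpy(&params[i], &v, sizeof(float)); // store the raw bit pattern, no int conversion
}

static float get_f32(const int32_t * params, int i) {
    float v;
    memcpy(&v, &params[i], sizeof(float));
    return v;
}

int main(void) {
    int32_t op_params[16] = {0};          // scratch array; size chosen for illustration

    set_f32(op_params, 0, 0.125f);        // scale
    set_f32(op_params, 1, 8.0f);          // max_bias (0.0f disables ALiBi)
    op_params[2] = 0;                     // precision flag, e.g. default precision

    printf("scale=%g max_bias=%g precision=%d\n",
           get_f32(op_params, 0), get_f32(op_params, 1), (int) op_params[2]);
    return 0;
}
```

This mirrors the two memcpy calls in the launchers above and explains why the precision read moves to index 2.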
