
ggml : full ALiBi support #7192

Merged
merged 10 commits on May 11, 2024
12 changes: 12 additions & 0 deletions convert-hf-to-gguf.py
@@ -1013,6 +1013,18 @@ def set_gguf_parameters(self):
class RefactModel(Model):
    model_arch = gguf.MODEL_ARCH.REFACT

    def set_vocab(self):
        super().set_vocab()

        # TODO: how to determine special FIM tokens automatically?
        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
        special_vocab._set_special_token("prefix", 1)
        special_vocab._set_special_token("suffix", 3)
        special_vocab._set_special_token("middle", 2)
        special_vocab._set_special_token("fsep", 4) # is this correct?
        special_vocab.add_to_gguf(self.gguf_writer)

    def set_gguf_parameters(self):
        hidden_dim = self.hparams["n_embd"]
        inner_dim = 4 * hidden_dim
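The hard-coded FIM token ids above are an open question in the PR itself (see the TODO and the "is this correct?" comment). Below is a minimal, hypothetical cross-check sketch, not part of this change: the token names "<fim_prefix>", "<fim_middle>", "<fim_suffix>" and the use of transformers.AutoTokenizer are assumptions that would need to be verified against the actual Refact tokenizer; the fsep id is left out because the diff marks it as uncertain.

```python
# Hypothetical sanity check for the hard-coded FIM ids (illustration only, not part of this PR).
# The token names below are assumed; verify them against the model's tokenizer before relying on this.
from transformers import AutoTokenizer

def check_fim_ids(model_dir: str) -> None:
    tok = AutoTokenizer.from_pretrained(model_dir)
    vocab = tok.get_vocab()
    # ids taken from the set_vocab() hunk above; "fsep" (id 4) omitted since the PR marks it uncertain
    expected = {"<fim_prefix>": 1, "<fim_middle>": 2, "<fim_suffix>": 3}
    for name, want in expected.items():
        got = vocab.get(name)
        status = "OK" if got == want else "CHECK"
        print(f"{status}: {name} -> expected {want}, tokenizer has {got}")
```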
5 changes: 0 additions & 5 deletions ggml-cuda.cu
@@ -4,7 +4,6 @@

#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
#include "ggml-cuda/alibi.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
@@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ROPE:
ggml_cuda_op_rope(ctx, dst);
break;
case GGML_OP_ALIBI:
ggml_cuda_op_alibi(ctx, dst);
break;
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
@@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_ROPE:
case GGML_OP_ALIBI:
case GGML_OP_IM2COL:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
63 changes: 0 additions & 63 deletions ggml-cuda/alibi.cu

This file was deleted.

5 changes: 0 additions & 5 deletions ggml-cuda/alibi.cuh

This file was deleted.

72 changes: 62 additions & 10 deletions ggml-cuda/fattn.cu
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
const int stride_KV = nb11 / sizeof(half);
const int stride_KV2 = nb11 / sizeof(half2);

half slopeh = __float2half(1.0f);

// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slopeh = __float2half(powf(base, exph));
}

static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
for (int j = 0; j < ncols; ++j) {
sum2[j] = warp_reduce_sum(sum2[j]);
half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const int ne00,
const int ne01,
const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);

half slopeh = __float2half(1.0f);
half2 slope2 = make_half2(1.0f, 1.0f);

// ALiBi
if (max_bias > 0.0f) {
const int h = blockIdx.y;

const float base = h < n_head_log2 ? m0 : m1;
const int exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

slopeh = __float2half(powf(base, exph));
slope2 = make_half2(slopeh, slopeh);
}

frag_b Q_b[D/16][ncols/frag_n];

// A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;

KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;

KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
const dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
const int shmem = 0;

float scale;
memcpy(&scale, KQV->op_params, sizeof(float));
float scale = 1.0f;
float max_bias = 0.0f;

memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));

const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
<<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
(const char *) V->data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale,
scale, max_bias, m0, m1, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

const int32_t precision = KQV->op_params[1];
const int32_t precision = KQV->op_params[2];

if (!fp16_mma_available(cc)) {
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
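For reference, the slope applied to the mask in both kernels is the standard ALiBi per-head slope: with n_head_log2 equal to the largest power of two not exceeding the number of heads, head h uses m0^(h+1) when h < n_head_log2 and m1^(2*(h - n_head_log2) + 1) otherwise, where m0 = 2^(-max_bias / n_head_log2) and m1 = 2^(-(max_bias/2) / n_head_log2). The host code now packs scale at op_params[0] and max_bias at op_params[1], which is why the precision flag read above moves from op_params[1] to op_params[2]. A short host-side sketch of the same slope computation, for checking values only (not part of the PR):

```python
# Reference computation of the per-head ALiBi slopes used by the kernels above.
# Mirrors the m0/m1/n_head_log2 math in the launch_fattn_* host code; illustration only.
import math

def alibi_slopes(n_head: int, max_bias: float) -> list[float]:
    n_head_log2 = 1 << int(math.floor(math.log2(n_head)))  # largest power of two <= n_head
    m0 = 2.0 ** (-max_bias / n_head_log2)
    m1 = 2.0 ** (-(max_bias / 2.0) / n_head_log2)
    slopes = []
    for h in range(n_head):
        if h < n_head_log2:
            slopes.append(m0 ** (h + 1))
        else:
            slopes.append(m1 ** (2 * (h - n_head_log2) + 1))
    return slopes

# max_bias == 0 gives a slope of 1.0 for every head, matching the `max_bias > 0.0f`
# guard in the kernels (the mask is then applied unscaled).
# Example: alibi_slopes(8, 8.0) == [2**-1, 2**-2, ..., 2**-8]
```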