Skip to content

Commit 8a052c1

Browse files
slaren and ggerganov authored
ggml-cuda : support stablelm rope (#4156)
* ggml-cuda : support stablelm rope

* remove unused freq_base kernel parameter

* add n_dims parameter to llm_build_k_shift, default to n_rot via overload

* llama : fix llm_build_k_shift args

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 189d684 commit 8a052c1

File tree

2 files changed

+24
-17
lines changed

2 files changed

+24
-17
lines changed

ggml-cuda.cu

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4610,8 +4610,8 @@ static __global__ void rope(
46104610

46114611
template<typename T, bool has_pos>
46124612
static __global__ void rope_neox(
4613-
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
4614-
float ext_factor, float attn_factor, rope_corr_dims corr_dims
4613+
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
4614+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
46154615
) {
46164616
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
46174617

@@ -4620,23 +4620,25 @@ static __global__ void rope_neox(
46204620
}
46214621

46224622
const int row = blockDim.x*blockIdx.x + threadIdx.x;
4623-
const int i = row*ncols + col/2;
4623+
const int ib = col / n_dims;
4624+
const int ic = col % n_dims;
4625+
4626+
const int i = row*ncols + ib*n_dims + ic/2;
46244627
const int i2 = row/p_delta_rows;
46254628

4626-
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4627-
const float cur_rot = -float(col)/ncols;
4629+
float cur_rot = inv_ndims * ic - ib;
46284630

46294631
const int p = has_pos ? pos[i2] : 0;
4630-
const float theta_base = p*powf(freq_base, cur_rot);
4632+
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
46314633

46324634
float cos_theta, sin_theta;
46334635
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
46344636

46354637
const float x0 = x[i + 0];
4636-
const float x1 = x[i + ncols/2];
4638+
const float x1 = x[i + n_dims/2];
46374639

4638-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4639-
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
4640+
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4641+
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
46404642
}
46414643

46424644
static __global__ void rope_glm_f32(
@@ -5739,20 +5741,26 @@ static void rope_cuda(
57395741

57405742
template<typename T>
57415743
static void rope_neox_cuda(
5742-
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
5744+
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
57435745
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
57445746
) {
57455747
GGML_ASSERT(ncols % 2 == 0);
57465748
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
57475749
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
57485750
const dim3 block_nums(nrows, num_blocks_x, 1);
5751+
5752+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
5753+
const float inv_ndims = -1.0f / n_dims;
5754+
57495755
if (pos == nullptr) {
57505756
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
5751-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5757+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
5758+
theta_scale, inv_ndims
57525759
);
57535760
} else {
57545761
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
5755-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5762+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
5763+
theta_scale, inv_ndims
57565764
);
57575765
}
57585766
}
@@ -6707,15 +6715,14 @@ inline void ggml_cuda_op_rope(
67076715
GGML_ASSERT(false);
67086716
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
67096717
} else if (is_neox) {
6710-
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
67116718
if (src0->type == GGML_TYPE_F32) {
67126719
rope_neox_cuda(
6713-
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6720+
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67146721
attn_factor, corr_dims, main_stream
67156722
);
67166723
} else if (src0->type == GGML_TYPE_F16) {
67176724
rope_neox_cuda(
6718-
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6725+
(const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67196726
attn_factor, corr_dims, main_stream
67206727
);
67216728
} else {

llama.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3469,7 +3469,7 @@ static void llm_build_k_shift(
34693469
struct ggml_cgraph * graph,
34703470
llm_rope_type type,
34713471
int64_t n_ctx,
3472-
int64_t n_rot,
3472+
int n_rot,
34733473
float freq_base,
34743474
float freq_scale,
34753475
const llm_build_cb & cb) {
@@ -3501,7 +3501,7 @@ static void llm_build_k_shift(
35013501
// we rotate only the first n_rot dimensions
35023502
ggml_rope_custom_inplace(ctx,
35033503
ggml_view_3d(ctx, kv.k,
3504-
n_rot, n_head_kv, n_ctx,
3504+
n_embd_head, n_head_kv, n_ctx,
35053505
ggml_element_size(kv.k)*n_embd_head,
35063506
ggml_element_size(kv.k)*n_embd_gqa,
35073507
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),

0 commit comments

Comments (0)