Skip to content

Commit 5844493

Browse files
committed
ggml-cuda : support stablelm rope
1 parent 8e672ef commit 5844493

File tree

1 file changed

+22
-15
lines changed

1 file changed

+22
-15
lines changed

ggml-cuda.cu

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4609,8 +4609,8 @@ static __global__ void rope(
46094609

46104610
template<typename T, bool has_pos>
46114611
static __global__ void rope_neox(
4612-
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
4613-
float ext_factor, float attn_factor, rope_corr_dims corr_dims
4612+
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
4613+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
46144614
) {
46154615
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
46164616

@@ -4619,23 +4619,25 @@ static __global__ void rope_neox(
46194619
}
46204620

46214621
const int row = blockDim.x*blockIdx.x + threadIdx.x;
4622-
const int i = row*ncols + col/2;
4622+
const int ib = col / n_dims;
4623+
const int ic = col % n_dims;
4624+
4625+
const int i = row*ncols + ib*n_dims + ic/2;
46234626
const int i2 = row/p_delta_rows;
46244627

4625-
// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4626-
const float cur_rot = -float(col)/ncols;
4628+
float cur_rot = inv_ndims * ic - ib;
46274629

46284630
const int p = has_pos ? pos[i2] : 0;
4629-
const float theta_base = p*powf(freq_base, cur_rot);
4631+
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
46304632

46314633
float cos_theta, sin_theta;
46324634
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
46334635

46344636
const float x0 = x[i + 0];
4635-
const float x1 = x[i + ncols/2];
4637+
const float x1 = x[i + n_dims/2];
46364638

4637-
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4638-
dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
4639+
dst[i + 0] = x0*cos_theta - x1*sin_theta;
4640+
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
46394641
}
46404642

46414643
static __global__ void rope_glm_f32(
@@ -5738,20 +5740,26 @@ static void rope_cuda(
57385740

57395741
template<typename T>
57405742
static void rope_neox_cuda(
5741-
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
5743+
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
57425744
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
57435745
) {
57445746
GGML_ASSERT(ncols % 2 == 0);
57455747
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
57465748
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
57475749
const dim3 block_nums(nrows, num_blocks_x, 1);
5750+
5751+
const float theta_scale = powf(freq_base, -2.0f/n_dims);
5752+
const float inv_ndims = -1.0f / n_dims;
5753+
57485754
if (pos == nullptr) {
57495755
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
5750-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5756+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
5757+
theta_scale, inv_ndims
57515758
);
57525759
} else {
57535760
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
5754-
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
5761+
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims,
5762+
theta_scale, inv_ndims
57555763
);
57565764
}
57575765
}
@@ -6706,15 +6714,14 @@ inline void ggml_cuda_op_rope(
67066714
GGML_ASSERT(false);
67076715
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
67086716
} else if (is_neox) {
6709-
GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
67106717
if (src0->type == GGML_TYPE_F32) {
67116718
rope_neox_cuda(
6712-
(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6719+
(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67136720
attn_factor, corr_dims, main_stream
67146721
);
67156722
} else if (src0->type == GGML_TYPE_F16) {
67166723
rope_neox_cuda(
6717-
(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6724+
(const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
67186725
attn_factor, corr_dims, main_stream
67196726
);
67206727
} else {

0 commit comments

Comments
 (0)