@@ -4610,8 +4610,8 @@ static __global__ void rope(
4610
4610
4611
4611
template <typename T, bool has_pos>
4612
4612
static __global__ void rope_neox (
4613
- const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base ,
4614
- float ext_factor, float attn_factor, rope_corr_dims corr_dims
4613
+ const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
4614
+ float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
4615
4615
) {
4616
4616
const int col = 2 *(blockDim .y *blockIdx .y + threadIdx .y );
4617
4617
@@ -4620,23 +4620,25 @@ static __global__ void rope_neox(
4620
4620
}
4621
4621
4622
4622
const int row = blockDim .x *blockIdx .x + threadIdx .x ;
4623
- const int i = row*ncols + col/2 ;
4623
+ const int ib = col / n_dims;
4624
+ const int ic = col % n_dims;
4625
+
4626
+ const int i = row*ncols + ib*n_dims + ic/2 ;
4624
4627
const int i2 = row/p_delta_rows;
4625
4628
4626
- // simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
4627
- const float cur_rot = -float (col)/ncols;
4629
+ float cur_rot = inv_ndims * ic - ib;
4628
4630
4629
4631
const int p = has_pos ? pos[i2] : 0 ;
4630
- const float theta_base = p*powf (freq_base, cur_rot );
4632
+ const float theta_base = p*freq_scale* powf (theta_scale, col/ 2 . 0f );
4631
4633
4632
4634
float cos_theta, sin_theta;
4633
4635
rope_yarn (theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
4634
4636
4635
4637
const float x0 = x[i + 0 ];
4636
- const float x1 = x[i + ncols /2 ];
4638
+ const float x1 = x[i + n_dims /2 ];
4637
4639
4638
- dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
4639
- dst[i + ncols /2 ] = x0*sin_theta + x1*cos_theta;
4640
+ dst[i + 0 ] = x0*cos_theta - x1*sin_theta;
4641
+ dst[i + n_dims /2 ] = x0*sin_theta + x1*cos_theta;
4640
4642
}
4641
4643
4642
4644
static __global__ void rope_glm_f32 (
@@ -5739,20 +5741,26 @@ static void rope_cuda(
5739
5741
5740
5742
// Host-side launcher for the rope_neox kernel: enqueues one launch on `stream`.
//
// Parameters:
//   x, dst        input and output buffers, forwarded to the kernel unchanged
//   ncols         number of elements per row; must be even
//   n_dims        group size used by the kernel (divisor in `col / n_dims` and distance `n_dims/2`
//                 between the two elements of a pair); must be positive, even and <= ncols
//   nrows         number of rows; the grid has one block in x per row
//   pos           optional positions; when nullptr the kernel is instantiated with has_pos = false
//                 and uses position 0 for every row
//   freq_base     only used here, to derive theta_scale = freq_base^(-2/n_dims)
//   freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims
//                 forwarded to the kernel unchanged
//   stream        CUDA stream the kernel is launched on
//
// Launch layout: blocks of (1, CUDA_ROPE_BLOCK_SIZE, 1) threads on a grid of
// (nrows, ceil(ncols / (2*CUDA_ROPE_BLOCK_SIZE)), 1); every thread handles one pair of elements.
// No dynamic shared memory is requested.
template <typename T>
static void rope_neox_cuda(
    const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
    float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
    GGML_ASSERT(ncols % 2 == 0);

    // Empty input: nothing to rotate, and a grid dimension of 0 is not a valid launch configuration.
    if (ncols <= 0 || nrows <= 0) {
        return;
    }

    // n_dims is a divisor both below and inside the kernel, and the kernel reads/writes the pair
    // (i, i + n_dims/2), so it has to be a positive even value that fits inside one row.
    GGML_ASSERT(n_dims > 0 && n_dims % 2 == 0 && n_dims <= ncols);

    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    // Each thread covers two columns, hence the 2*CUDA_ROPE_BLOCK_SIZE columns per block (ceil-div).
    const int n_blocks_y = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    const dim3 block_nums(nrows, n_blocks_y, 1);

    // Computed once on the host so the kernel does not have to redo them per thread.
    const float theta_scale = powf(freq_base, -2.0f/n_dims);
    const float inv_ndims   = -1.0f / n_dims;

    if (pos == nullptr) {
        rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, inv_ndims
        );
    } else {
        rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
            theta_scale, inv_ndims
        );
    }
}
@@ -6707,15 +6715,14 @@ inline void ggml_cuda_op_rope(
6707
6715
GGML_ASSERT (false );
6708
6716
rope_glm_f32_cuda (src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
6709
6717
} else if (is_neox) {
6710
- GGML_ASSERT (ne00 == n_dims && " ne00 != n_dims is not implemented for CUDA yet" );
6711
6718
if (src0->type == GGML_TYPE_F32) {
6712
6719
rope_neox_cuda (
6713
- (const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6720
+ (const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6714
6721
attn_factor, corr_dims, main_stream
6715
6722
);
6716
6723
} else if (src0->type == GGML_TYPE_F16) {
6717
6724
rope_neox_cuda (
6718
- (const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6725
+ (const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
6719
6726
attn_factor, corr_dims, main_stream
6720
6727
);
6721
6728
} else {
0 commit comments