Skip to content

Commit 20d70c6

Browse files
committed
wip
1 parent aa58468 commit 20d70c6

File tree

4 files changed

+17
-36
lines changed

4 files changed

+17
-36
lines changed

ggml-cuda/rope.cu

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static __global__ void rope(
6161
template<typename T, bool has_pos, bool has_freq_facs>
6262
static __global__ void rope_neox(
6363
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
64+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
6565
) {
6666
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
6767

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
8585
const int i = row*ncols + ib*n_dims + ic/2;
8686
const int i2 = row/p_delta_rows;
8787

88-
float cur_rot = inv_ndims * ic - ib;
89-
9088
const int p = has_pos ? pos[i2] : 0;
9189
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
9290

93-
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
91+
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
9492

9593
float cos_theta, sin_theta;
96-
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
94+
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
9795

9896
const float x0 = x[i + 0];
9997
const float x1 = x[i + n_dims/2];
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
174172
const dim3 block_nums(nrows, num_blocks_x, 1);
175173

176174
const float theta_scale = powf(freq_base, -2.0f/n_dims);
177-
const float inv_ndims = -1.0f / n_dims;
178175

179176
if (pos == nullptr) {
180177
if (freq_factors == nullptr) {
181178
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
182179
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
183-
theta_scale, inv_ndims, freq_factors
180+
theta_scale, freq_factors
184181
);
185182
} else {
186183
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
187184
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
188-
theta_scale, inv_ndims, freq_factors
185+
theta_scale, freq_factors
189186
);
190187
}
191188
} else {
192189
if (freq_factors == nullptr) {
193190
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
194191
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
195-
theta_scale, inv_ndims, freq_factors
192+
theta_scale, freq_factors
196193
);
197194
} else {
198195
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
199196
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
200-
theta_scale, inv_ndims, freq_factors
197+
theta_scale, freq_factors
201198
);
202199
}
203200
}

ggml-metal.metal

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,14 +1791,10 @@ kernel void kernel_rope(
17911791
if (ic < n_dims) {
17921792
const int64_t ib = 0;
17931793

1794-
// simplified from `(ib * n_dims + ic) * inv_ndims`
1795-
const float cur_rot = inv_ndims*ic - ib;
17961794
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
17971795

1798-
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1799-
18001796
float cos_theta, sin_theta;
1801-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1797+
rope_yarn(theta_0/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
18021798

18031799
const int64_t i0 = ib*n_dims + ic/2;
18041800

ggml.c

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14353,7 +14353,7 @@ static void ggml_compute_forward_rope_f32(
1435314353
int ir = 0;
1435414354

1435514355
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14356-
const float inv_ndims = -1.f/n_dims;
14356+
1435714357
float corr_dims[2];
1435814358
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1435914359

@@ -14437,25 +14437,20 @@ static void ggml_compute_forward_rope_f32(
1443714437
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
1443814438
}
1443914439
} else {
14440-
// TODO: this might be wrong for ne0 != n_dims - need double check
14441-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14442-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14443-
theta_base *= freq_scale;
14440+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1444414441
for (int64_t ic = 0; ic < ne0; ic += 2) {
1444514442
if (ic < n_dims) {
1444614443
const int64_t ib = 0;
1444714444

14448-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14449-
float cur_rot = inv_ndims * ic - ib;
1445014445
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
1445114446

1445214447
float cos_theta, sin_theta;
1445314448
rope_yarn(
14454-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14449+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1445514450
&cos_theta, &sin_theta
1445614451
);
14457-
sin_theta *= sin_sign;
1445814452

14453+
sin_theta *= sin_sign;
1445914454
theta_base *= theta_scale;
1446014455

1446114456
const int64_t i0 = ib*n_dims + ic/2;
@@ -14538,7 +14533,7 @@ static void ggml_compute_forward_rope_f16(
1453814533
int ir = 0;
1453914534

1454014535
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14541-
const float inv_ndims = -1.f/n_dims;
14536+
1454214537
float corr_dims[2];
1454314538
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1454414539

@@ -14618,25 +14613,20 @@ static void ggml_compute_forward_rope_f16(
1461814613
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1461914614
}
1462014615
} else {
14621-
// TODO: this might be wrong for ne0 != n_dims - need double check
14622-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14623-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14624-
theta_base *= freq_scale;
14616+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1462514617
for (int64_t ic = 0; ic < ne0; ic += 2) {
1462614618
if (ic < n_dims) {
1462714619
const int64_t ib = 0;
1462814620

14629-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14630-
float cur_rot = inv_ndims * ic - ib;
1463114621
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
1463214622

1463314623
float cos_theta, sin_theta;
1463414624
rope_yarn(
14635-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14625+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1463614626
&cos_theta, &sin_theta
1463714627
);
14638-
sin_theta *= sin_sign;
1463914628

14629+
sin_theta *= sin_sign;
1464014630
theta_base *= theta_scale;
1464114631

1464214632
const int64_t i0 = ib*n_dims + ic/2;

ggml_vk_generate_shaders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,14 +2670,12 @@
26702670
const uint i = row*p.ncols + ib*p.ndims + ic/2;
26712671
const uint i2 = row/p.p_delta_rows;
26722672
2673-
const float cur_rot = p.inv_ndims * ic - ib;
2674-
26752673
const int pos = data_b[i2];
26762674
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
26772675
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
26782676
26792677
float cos_theta, sin_theta;
2680-
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
2678+
rope_yarn(theta_base, ic, cos_theta, sin_theta);
26812679
26822680
const float x0 = float(data_a[i + 0]);
26832681
const float x1 = float(data_a[i + p.ndims/2]);

0 commit comments

Comments
 (0)