Skip to content

Commit 3b0df03

Browse files
committed
ggml : fixes (hopefully)
ggml-ci
1 parent aa58468 commit 3b0df03

File tree

4 files changed

+28
-51
lines changed

4 files changed

+28
-51
lines changed

ggml-cuda/rope.cu

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ static __global__ void rope(
6161
template<typename T, bool has_pos, bool has_freq_facs>
6262
static __global__ void rope_neox(
6363
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
64-
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims, const float * freq_factors
64+
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, const float * freq_factors
6565
) {
6666
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
6767

@@ -85,15 +85,13 @@ static __global__ void rope_neox(
8585
const int i = row*ncols + ib*n_dims + ic/2;
8686
const int i2 = row/p_delta_rows;
8787

88-
float cur_rot = inv_ndims * ic - ib;
89-
9088
const int p = has_pos ? pos[i2] : 0;
9189
const float freq_factor = has_freq_facs ? freq_factors[ic/2] : 1.0f;
9290

93-
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f)/freq_factor;
91+
const float theta_base = p*powf(theta_scale, col/2.0f)/freq_factor;
9492

9593
float cos_theta, sin_theta;
96-
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
94+
rope_yarn(theta_base, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
9795

9896
const float x0 = x[i + 0];
9997
const float x1 = x[i + n_dims/2];
@@ -174,30 +172,29 @@ static void rope_neox_cuda(
174172
const dim3 block_nums(nrows, num_blocks_x, 1);
175173

176174
const float theta_scale = powf(freq_base, -2.0f/n_dims);
177-
const float inv_ndims = -1.0f / n_dims;
178175

179176
if (pos == nullptr) {
180177
if (freq_factors == nullptr) {
181178
rope_neox<T, false, false><<<block_nums, block_dims, 0, stream>>>(
182179
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
183-
theta_scale, inv_ndims, freq_factors
180+
theta_scale, freq_factors
184181
);
185182
} else {
186183
rope_neox<T, false, true><<<block_nums, block_dims, 0, stream>>>(
187184
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
188-
theta_scale, inv_ndims, freq_factors
185+
theta_scale, freq_factors
189186
);
190187
}
191188
} else {
192189
if (freq_factors == nullptr) {
193190
rope_neox<T, true, false><<<block_nums, block_dims, 0, stream>>>(
194191
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
195-
theta_scale, inv_ndims, freq_factors
192+
theta_scale, freq_factors
196193
);
197194
} else {
198195
rope_neox<T, true, true><<<block_nums, block_dims, 0, stream>>>(
199196
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
200-
theta_scale, inv_ndims, freq_factors
197+
theta_scale, freq_factors
201198
);
202199
}
203200
}

ggml-metal.metal

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,13 +1767,13 @@ kernel void kernel_rope(
17671767

17681768
const int64_t p = pos[i2];
17691769

1770-
const float theta_0 = (float)p;
1770+
const float theta_base = (float)p;
17711771
const float inv_ndims = -1.f/n_dims;
17721772

17731773
if (!is_neox) {
17741774
for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
1775+
const float theta = theta_base * pow(freq_base, inv_ndims*i0);
17751776

1776-
const float theta = theta_0 * pow(freq_base, inv_ndims*i0);
17771777
float cos_theta, sin_theta;
17781778
rope_yarn(theta, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
17791779

@@ -1789,18 +1789,14 @@ kernel void kernel_rope(
17891789
} else {
17901790
for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
17911791
if (ic < n_dims) {
1792-
const int64_t ib = 0;
1792+
const int64_t i0 = ic/2;
17931793

1794-
// simplified from `(ib * n_dims + ic) * inv_ndims`
1795-
const float cur_rot = inv_ndims*ic - ib;
1796-
const float freq_factor = src2 != src0 ? src2[ic/2] : 1.0f;
1794+
const float freq_factor = src2 != src0 ? src2[i0] : 1.0f;
17971795

1798-
const float theta = theta_0 * pow(freq_base, cur_rot) / freq_factor;
1796+
const float theta = theta_base * pow(freq_base, inv_ndims*ic);
17991797

18001798
float cos_theta, sin_theta;
1801-
rope_yarn(theta, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
1802-
1803-
const int64_t i0 = ib*n_dims + ic/2;
1799+
rope_yarn(theta/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor, &cos_theta, &sin_theta);
18041800

18051801
device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
18061802
device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

ggml.c

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14353,7 +14353,7 @@ static void ggml_compute_forward_rope_f32(
1435314353
int ir = 0;
1435414354

1435514355
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14356-
const float inv_ndims = -1.f/n_dims;
14356+
1435714357
float corr_dims[2];
1435814358
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1435914359

@@ -14402,7 +14402,7 @@ static void ggml_compute_forward_rope_f32(
1440214402
const float cos_block_theta = cosf(block_theta);
1440314403
const float sin_block_theta = sinf(block_theta) * sin_sign;
1440414404

14405-
theta_base *= theta_scale;
14405+
theta_base *= theta_scale;
1440614406
block_theta *= theta_scale;
1440714407

1440814408
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14437,29 +14437,22 @@ static void ggml_compute_forward_rope_f32(
1443714437
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
1443814438
}
1443914439
} else {
14440-
// TODO: this might be wrong for ne0 != n_dims - need double check
14441-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14442-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14443-
theta_base *= freq_scale;
14440+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1444414441
for (int64_t ic = 0; ic < ne0; ic += 2) {
1444514442
if (ic < n_dims) {
14446-
const int64_t ib = 0;
14443+
const int64_t i0 = ic/2;
1444714444

14448-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14449-
float cur_rot = inv_ndims * ic - ib;
14450-
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14445+
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
1445114446

1445214447
float cos_theta, sin_theta;
1445314448
rope_yarn(
14454-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14449+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1445514450
&cos_theta, &sin_theta
1445614451
);
14457-
sin_theta *= sin_sign;
1445814452

14453+
sin_theta *= sin_sign;
1445914454
theta_base *= theta_scale;
1446014455

14461-
const int64_t i0 = ib*n_dims + ic/2;
14462-
1446314456
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1446414457
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1446514458

@@ -14538,7 +14531,7 @@ static void ggml_compute_forward_rope_f16(
1453814531
int ir = 0;
1453914532

1454014533
const float theta_scale = powf(freq_base, -2.0f/n_dims);
14541-
const float inv_ndims = -1.f/n_dims;
14534+
1454214535
float corr_dims[2];
1454314536
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
1454414537

@@ -14587,7 +14580,7 @@ static void ggml_compute_forward_rope_f16(
1458714580
const float cos_block_theta = cosf(block_theta);
1458814581
const float sin_block_theta = sinf(block_theta) * sin_sign;
1458914582

14590-
theta_base *= theta_scale;
14583+
theta_base *= theta_scale;
1459114584
block_theta *= theta_scale;
1459214585

1459314586
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14618,29 +14611,22 @@ static void ggml_compute_forward_rope_f16(
1461814611
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
1461914612
}
1462014613
} else {
14621-
// TODO: this might be wrong for ne0 != n_dims - need double check
14622-
// it seems we have to rope just the first n_dims elements and do nothing with the rest
14623-
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
14624-
theta_base *= freq_scale;
14614+
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
1462514615
for (int64_t ic = 0; ic < ne0; ic += 2) {
1462614616
if (ic < n_dims) {
14627-
const int64_t ib = 0;
14617+
const int64_t i0 = ic/2;
1462814618

14629-
// simplified from `(ib * n_dims + ic) * inv_ndims`
14630-
float cur_rot = inv_ndims * ic - ib;
14631-
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
14619+
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
1463214620

1463314621
float cos_theta, sin_theta;
1463414622
rope_yarn(
14635-
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
14623+
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
1463614624
&cos_theta, &sin_theta
1463714625
);
14638-
sin_theta *= sin_sign;
1463914626

14627+
sin_theta *= sin_sign;
1464014628
theta_base *= theta_scale;
1464114629

14642-
const int64_t i0 = ib*n_dims + ic/2;
14643-
1464414630
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1464514631
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1464614632

ggml_vk_generate_shaders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2670,14 +2670,12 @@
26702670
const uint i = row*p.ncols + ib*p.ndims + ic/2;
26712671
const uint i2 = row/p.p_delta_rows;
26722672
2673-
const float cur_rot = p.inv_ndims * ic - ib;
2674-
26752673
const int pos = data_b[i2];
26762674
const float freq_factor = p.has_freq_facs != 0 ? data_freq_factors[ic/2] : 1.0f;
26772675
const float theta_base = pos*p.freq_scale*pow(p.theta_scale, col/2.0f) / freq_factor;
26782676
26792677
float cos_theta, sin_theta;
2680-
rope_yarn(theta_base, uint(cur_rot), cos_theta, sin_theta);
2678+
rope_yarn(theta_base, ic, cos_theta, sin_theta);
26812679
26822680
const float x0 = float(data_a[i + 0]);
26832681
const float x1 = float(data_a[i + p.ndims/2]);

0 commit comments

Comments
 (0)