@@ -605,7 +605,7 @@ static float rope_ntkv2(
605
605
const float theta_base,
606
606
const float theta_linear,
607
607
const float theta_ntk,
608
- device const float corr_factors[4 ],
608
+ const float corr_factors[4 ],
609
609
const int64_t i0,
610
610
const float ntk_factor,
611
611
const float ext_factor) {
@@ -620,6 +620,29 @@ static float rope_ntkv2(
620
620
return theta;
621
621
}
622
622
623
+ // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
624
+ // Do not change unless there is a good reason for doing so!
625
+ constant float BETA_0 = 1 .75f ; // n_rot passed to rope_ntkv2_corr_factor for factors[0] (floored, clamped to >= 0)
626
+ constant float BETA_1 = 1 .25f ; // n_rot passed to rope_ntkv2_corr_factor for factors[1] (ceiled, clamped to <= n_dims - 1)
627
+ constant float GAMMA_0 = 16 .0f ; // n_rot passed to rope_ntkv2_corr_factor for factors[2] (floored, clamped to >= 0)
628
+ constant float GAMMA_1 = 2 .0f ; // n_rot passed to rope_ntkv2_corr_factor for factors[3] (ceiled, clamped to <= n_dims - 1)
629
+
630
+ constant float max_pos_emb = 2048 ; // numerator inside the log() in rope_ntkv2_corr_factor; presumably the model's original max position count - TODO confirm
631
+
632
+ // Solving `n_rot = max_pos_emb / (2pi * base^((2 * x) / n_dims))` for x, we get
633
+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
634
+ static float rope_ntkv2_corr_factor (const int n_dims, const float n_rot, const float base) { // evaluates the closed form quoted in the comment above for one n_rot value
635
+ return n_dims * log (max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log (base)); // result is unclamped and may be fractional; the caller floors/ceils and clamps it
636
+ }
637
+
638
+ static void rope_ntkv2_corr_factors (int n_dims, const float freq_base, float factors[4 ]) { // writes all four entries of factors; returns nothing
639
+ // start and end correction factors: [0]/[1] come from the BETA pair, [2]/[3] from the GAMMA pair
640
+ factors[0 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, BETA_0, freq_base))); // start of BETA range: rounded down, never below 0
641
+ factors[1 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, BETA_1, freq_base))); // end of BETA range: rounded up, never above n_dims - 1
642
+ factors[2 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, GAMMA_0, freq_base))); // start of GAMMA range: rounded down, never below 0
643
+ factors[3 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, GAMMA_1, freq_base))); // end of GAMMA range: rounded up, never above n_dims - 1
644
+ }
645
+
623
646
kernel void kernel_rope (
624
647
device const void * src0,
625
648
device float * dst,
@@ -651,10 +674,10 @@ kernel void kernel_rope(
651
674
const int64_t i2 = tpig[1 ];
652
675
const int64_t i1 = tpig[0 ];
653
676
654
- const float theta_scale = powf (freq_base, -2 .0f /n_dims);
655
- const float theta_ntk_scale = powf (freq_base * powf (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
656
- device float corr_factors[4 ];
657
- ggml_rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
677
+ const float theta_scale = pow (freq_base, -2 .0f /n_dims);
678
+ const float theta_ntk_scale = pow (freq_base * pow (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
679
+ float corr_factors[4 ];
680
+ rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
658
681
659
682
float theta_base = (mode & 1 ) == 0 ? n_past + i2 : i2;
660
683
float theta_ntk = theta_base;
@@ -666,8 +689,8 @@ kernel void kernel_rope(
666
689
const float theta_linear = freq_scale * theta_base;
667
690
const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors,
668
691
i0, ntk_factor, ext_factor);
669
- const float cos_theta = cosf (theta);
670
- const float sin_theta = sinf (theta);
692
+ const float cos_theta = cos (theta);
693
+ const float sin_theta = sin (theta);
671
694
672
695
theta_base *= theta_scale;
673
696
theta_ntk *= theta_ntk_scale;
0 commit comments