@@ -597,6 +597,55 @@ kernel void kernel_alibi_f32(
597597 }
598598}
599599
600+ static float rope_ntkv2_ramp (const float low, const float high, const int i0) {
601+ const float y = (i0 / 2 - low) / min (0 .001f , high - low);
602+ return 1 .0f - min (1 .0f , max (0 .0f , y));
603+ }
604+
605+ // NTKv2 algorithm based on LlamaPartNTKScaledRotaryEmbedding.py from https://github.com/jquesnelle/scaled-rope
606+ // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
607+ static float rope_ntkv2 (
608+ const float theta_base,
609+ const float theta_linear,
610+ const float theta_ntk,
611+ const float corr_factors[4 ],
612+ const int64_t i0,
613+ const float ntk_factor,
614+ const float ext_factor) {
615+ float ramp_mix;
616+ float theta;
617+
618+ ramp_mix = rope_ntkv2_ramp (corr_factors[0 ], corr_factors[1 ], i0) * ntk_factor;
619+ theta = theta_linear * (1 - ramp_mix) + theta_ntk * ramp_mix;
620+
621+ ramp_mix = rope_ntkv2_ramp (corr_factors[2 ], corr_factors[3 ], i0) * ext_factor;
622+ theta = theta * (1 - ramp_mix) + theta_base * ramp_mix;
623+ return theta;
624+ }
625+
626+ // Interpolation constants found experimentally for LLaMA (might not be totally optimal though)
627+ // Do not change unless there is a good reason for doing so!
628+ constant float BETA_0 = 1 .75f ;
629+ constant float BETA_1 = 1 .25f ;
630+ constant float GAMMA_0 = 16 .0f ;
631+ constant float GAMMA_1 = 2 .0f ;
632+
633+ constant float max_pos_emb = 2048 ;
634+
635+ // Apparently solving `n_rot = 2pi * x * base^((2 * max_pos_emb) / n_dims)` for x, we get
636+ // `corr_fac(n_rot) = n_dims * log(max_pos_emb / (n_rot * 2pi)) / (2 * log(base))`
637+ static float rope_ntkv2_corr_factor (const int n_dims, const float n_rot, const float base) {
638+ return n_dims * log (max_pos_emb / (n_rot * 2 * M_PI_F)) / (2 * log (base));
639+ }
640+
641+ static void rope_ntkv2_corr_factors (int n_dims, const float freq_base, float factors[4 ]) {
642+ // start and end correction factors
643+ factors[0 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, BETA_0, freq_base)));
644+ factors[1 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, BETA_1, freq_base)));
645+ factors[2 ] = max (0 .0f , floor (rope_ntkv2_corr_factor (n_dims, GAMMA_0, freq_base)));
646+ factors[3 ] = min (n_dims - 1 .0f , ceil (rope_ntkv2_corr_factor (n_dims, GAMMA_1, freq_base)));
647+ }
648+
600649kernel void kernel_rope (
601650 device const void * src0,
602651 device float * dst,
@@ -621,24 +670,33 @@ kernel void kernel_rope(
621670 constant int & mode,
622671 constant float & freq_base,
623672 constant float & freq_scale,
673+ constant float & ntk_factor,
674+ constant float & ext_factor,
624675 uint3 tpig[[thread_position_in_grid]]) {
625676 const int64_t i3 = tpig[2 ];
626677 const int64_t i2 = tpig[1 ];
627678 const int64_t i1 = tpig[0 ];
628679
629- const bool is_neox = mode & 2 ;
630680 const float theta_scale = pow (freq_base, -2 .0f /n_dims);
681+ const float theta_ntk_scale = pow (freq_base * pow (freq_scale, (n_dims / (n_dims - 2 .0f ))), -2 .0f /n_dims);
682+ float corr_factors[4 ];
683+ rope_ntkv2_corr_factors (n_dims, freq_base, corr_factors);
631684
632- const int64_t p = ((mode & 1 ) == 0 ? n_past + i2 : i2);
685+ float theta_base = (mode & 1 ) == 0 ? n_past + i2 : i2;
686+ float theta_ntk = theta_base;
633687
634- float theta = freq_scale * ( float )p ;
688+ const bool is_neox = mode & 2 ;
635689
636690 if (!is_neox) {
637691 for (int64_t i0 = 0 ; i0 < ne0; i0 += 2 ) {
692+ const float theta_linear = freq_scale * theta_base;
693+ const float theta = rope_ntkv2 (theta_base, theta_linear, theta_ntk, corr_factors,
694+ i0, ntk_factor, ext_factor);
638695 const float cos_theta = cos (theta);
639696 const float sin_theta = sin (theta);
640697
641- theta *= theta_scale;
698+ theta_base *= theta_scale;
699+ theta_ntk *= theta_ntk_scale;
642700
643701 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
644702 device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -650,6 +708,7 @@ kernel void kernel_rope(
650708 dst_data[1 ] = x0*sin_theta + x1*cos_theta;
651709 }
652710 } else {
711+ theta_base *= freq_scale;
653712 // TODO: implement
654713 }
655714}
0 commit comments