8 changes: 8 additions & 0 deletions common/common.cpp
@@ -220,6 +220,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
         } else if (arg == "--yarn-ext-factor") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -737,6 +743,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf(" --rope-scale N RoPE context scaling factor, expands context by a factor of N\n");
     printf(" --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
     printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf(" --yarn-orig-ctx N YaRN: original context size of model (default: 0 = model training context size)\n");
     printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
     printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
@@ -861,6 +868,7 @@ struct llama_context_params llama_context_params_from_gpt_param
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast = params.yarn_beta_fast;
     cparams.yarn_beta_slow = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx = params.yarn_orig_ctx;
 
     return cparams;
 }
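Note on usage: the new --yarn-orig-ctx flag tells YaRN what the model's original (training) context length is; 0 keeps the default of reading it from the model. Below is a minimal sketch of the same path through the common helpers, assuming the YaRN rope-scaling enum value from llama.h of this era; all concrete numbers are purely illustrative.

    #include "common.h"
    #include "llama.h"

    int main() {
        gpt_params params;                                  // defaults from common.h above
        params.n_ctx             = 16384;                   // run at an extended 16k context
        params.rope_freq_scale   = 0.25f;                   // 4x interpolation
        params.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; // assumed enum value
        params.yarn_orig_ctx     = 4096;                    // new field: the base model's training context

        // Copies yarn_orig_ctx (and the other YaRN knobs) into llama_context_params,
        // as shown in llama_context_params_from_gpt_params above.
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        (void) cparams;                                     // ... create the model/context as usual ...
        return 0;
    }

On the command line this corresponds to passing --yarn-orig-ctx 4096 alongside the existing RoPE/YaRN flags printed by gpt_print_usage.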
5 changes: 3 additions & 2 deletions common/common.h
@@ -57,8 +57,9 @@ struct gpt_params {
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
     float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
     float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    float yarn_beta_fast = 32.0f;// YaRN low correction dim
+    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    int32_t yarn_orig_ctx = 0; // YaRN original context length
     int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters
7 changes: 3 additions & 4 deletions ggml-cuda.cu
@@ -4406,7 +4406,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -4426,11 +4426,10 @@ static __device__ void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
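The min → max swap in rope_yarn_ramp is the actual bug fix: the 0.001f clamp is only meant to keep the divisor away from zero when high <= low. With min(0.001f, high - low) the divisor could never exceed 0.001, so the ramp collapsed into a hard step at `low` instead of blending smoothly between the two correction dims. A small host-side sketch of the corrected formula (plain C++, not the CUDA/Metal kernels; the low/high values are made up):

    #include <algorithm>
    #include <cstdio>

    // Mirrors rope_yarn_ramp above: 1 below `low`, 0 above `high`, linear in between.
    static float rope_yarn_ramp_fixed(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / std::max(0.001f, high - low); // max() only guards high <= low
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }

    int main() {
        // Prints a smooth 1 -> 0 ramp across the rotary dims; with the old min()
        // clamp the divisor was at most 0.001 and the same loop printed a step function.
        for (int i0 = 0; i0 < 32; i0 += 4) {
            std::printf("i0=%2d ramp=%.3f\n", i0, rope_yarn_ramp_fixed(4.0f, 12.0f, i0));
        }
        return 0;
    }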
7 changes: 3 additions & 4 deletions ggml-metal.metal
@@ -880,7 +880,7 @@ kernel void kernel_alibi_f32(
 }
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -896,11 +896,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
7 changes: 3 additions & 4 deletions ggml.c
@@ -13345,7 +13345,7 @@ static void ggml_compute_forward_clamp(
 // ggml_compute_forward_rope
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
     return 1 - MIN(1, MAX(0, y));
 }
 
@@ -13361,11 +13361,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
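The second change in rope_yarn (repeated identically in the CUDA, Metal and CPU versions) moves the attention-magnitude correction inside the `ext_factor != 0` branch: the mscale *= 1 + 0.1*ln(1/freq_scale) term is part of YaRN, so it should no longer fire for plain linear RoPE scaling where ext_factor is 0. A minimal sketch of that decision with assumed example values (the real function also folds attn_factor into mscale, omitted here):

    #include <cmath>
    #include <cstdio>

    int main() {
        const float freq_scale = 0.25f; // e.g. 4x context extension
        const float ext_factor = 0.0f;  // 0 = YaRN mixing disabled (pure interpolation)

        float mscale = 1.0f;
        if (ext_factor != 0.0f) {
            // YaRN path only: ~1.1386 for freq_scale = 0.25 (1 + 0.1*ln(4))
            mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
        }
        // Before this patch the correction was applied whenever freq_scale < 1.0f,
        // so a plain linear-RoPE run would also have been rescaled by ~1.1386.
        std::printf("mscale = %f\n", mscale);
        return 0;
    }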
10 changes: 6 additions & 4 deletions llama.cpp
@@ -1113,6 +1113,7 @@ struct llama_cparams {
     float rope_freq_base;
     float rope_freq_scale;
 
+    uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
     // existing YaRN models use the same values for them.
     float yarn_ext_factor;
@@ -3028,7 +3029,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -3430,7 +3431,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4194,7 +4195,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int32_t n_embd = hparams.n_embd;
     const int32_t n_layer = hparams.n_layer;
     const int32_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int32_t n_head = hparams.n_head;
     const int32_t n_head_kv = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4818,7 +4819,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     const int64_t n_embd = hparams.n_embd;
     const int64_t n_layer = hparams.n_layer;
     const int64_t n_ctx = cparams.n_ctx;
-    const int32_t n_orig_ctx = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
     const int64_t n_head_kv = hparams.n_head_kv;
     const int64_t n_head = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -8676,6 +8677,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mul_mat_q = params.mul_mat_q;
 
     cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+    cparams.n_yarn_orig_ctx = params.yarn_orig_ctx == 0 ? hparams.n_ctx_train : params.yarn_orig_ctx;
     cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
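In llama.cpp the original-context value now lives in llama_cparams rather than in the hparams loaded from GGUF, so a user override always wins and 0 falls back to the model's training context. The resolution is just the ternary added in llama_new_context_with_model; a tiny illustration with made-up numbers:

    #include <cstdint>

    // Mirrors: cparams.n_yarn_orig_ctx = params.yarn_orig_ctx == 0 ? hparams.n_ctx_train : params.yarn_orig_ctx;
    static uint32_t resolve_yarn_orig_ctx(uint32_t requested, uint32_t n_ctx_train) {
        return requested == 0 ? n_ctx_train : requested;
    }

    // resolve_yarn_orig_ctx(0,    4096) == 4096  (default: use the training context)
    // resolve_yarn_orig_ctx(8192, 4096) == 8192  (explicit --yarn-orig-ctx wins)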
13 changes: 7 additions & 6 deletions llama.h
@@ -182,12 +182,13 @@ extern "C" {
         int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base;   // RoPE base frequency, 0 = from model
-        float rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
-        float yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
-        float yarn_attn_factor; // YaRN magnitude scaling factor
-        float yarn_beta_fast;   // YaRN low correction dim
-        float yarn_beta_slow;   // YaRN high correction dim
+        float    rope_freq_base;   // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast;   // YaRN low correction dim
+        float    yarn_beta_slow;   // YaRN high correction dim
+        uint32_t yarn_orig_ctx;    // YaRN original context size
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q; // if true, use experimental mul_mat_q kernels
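For API users, the new uint32_t yarn_orig_ctx field slots into llama_context_params next to the other YaRN knobs. A hedged end-to-end sketch against the llama.h of this era (model path, context sizes and the exact enum spelling are assumptions, not taken from this diff):

    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("models/7b.gguf", mparams);
        if (!model) return 1;

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx             = 16384;                   // extended context
        cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; // assumed enum value
        cparams.rope_freq_scale   = 0.25f;                   // 4x interpolation
        cparams.yarn_orig_ctx     = 4096;                    // new: training context of the base model
        // yarn_ext_factor / yarn_attn_factor / yarn_beta_* keep their defaults.

        llama_context * ctx = llama_new_context_with_model(model, cparams);
        if (!ctx) { llama_free_model(model); return 1; }

        /* ... tokenize, decode, sample ... */

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }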