
YaRN : store rope scaling type as int32_t in memory #5285


Merged · 2 commits · Feb 3, 2024
3 changes: 1 addition & 2 deletions common/common.h
@@ -75,8 +75,7 @@ struct gpt_params {
     float   yarn_beta_fast    = 32.0f; // YaRN low correction dim
     float   yarn_beta_slow    = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx     = 0;     // YaRN original context length
-    int8_t  rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
-                                                                // pinging @cebtenzzre
+    int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
 
     // // sampling parameters
     struct llama_sampling_params sparams;
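The TODO being resolved here is about layout: a lone int8_t among 4-byte members usually saves nothing, because the compiler pads the struct out to its alignment anyway. A minimal sketch of the effect, using hypothetical reduced struct names and assuming a typical 64-bit ABI:

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical reductions of gpt_params, for illustration only.
struct params_i8 {
    float   yarn_beta_slow;    // 4 bytes
    int32_t yarn_orig_ctx;     // 4 bytes
    int8_t  rope_scaling_type; // 1 byte + 3 bytes of tail padding
};

struct params_i32 {
    float   yarn_beta_slow;
    int32_t yarn_orig_ctx;
    int32_t rope_scaling_type; // 4 bytes, no padding needed
};

int main() {
    // Both typically print 12: the int8_t bought no space, only an odd layout.
    printf("sizeof(params_i8)  = %zu\n", sizeof(params_i8));
    printf("sizeof(params_i32) = %zu\n", sizeof(params_i32));
    return 0;
}
```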
24 changes: 12 additions & 12 deletions llama.cpp
@@ -208,7 +208,7 @@ enum llm_arch {
     LLM_ARCH_UNKNOWN,
 };
 
-static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
+static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLAMA,  "llama"  },
     { LLM_ARCH_FALCON, "falcon" },
     { LLM_ARCH_GPT2,   "gpt2"   },
@@ -285,7 +285,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_RWKV,
 };
 
-static std::map<llm_kv, std::string> LLM_KV_NAMES = {
+static std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_GENERAL_ARCHITECTURE,         "general.architecture"         },
     { LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
     { LLM_KV_GENERAL_ALIGNMENT,            "general.alignment"            },
@@ -346,7 +346,7 @@ struct LLM_KV {
     llm_arch arch;
 
     std::string operator()(llm_kv kv) const {
-        return ::format(LLM_KV_NAMES[kv].c_str(), LLM_ARCH_NAMES[arch].c_str());
+        return ::format(LLM_KV_NAMES[kv], LLM_ARCH_NAMES[arch]);
     }
 };
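Switching the name tables from std::map<..., std::string> to std::map<..., const char *> keeps the values as the string literals they already are, and lets call sites drop the .c_str() noise. The ::format call above is llama.cpp's printf-style helper; a minimal sketch of that pattern, assuming a vsnprintf-based implementation (the real helper may differ in details):

```cpp
#include <cstdarg>
#include <cstdio>
#include <string>
#include <vector>

// Printf-style formatting into a std::string (sketch of the ::format pattern).
static std::string format(const char * fmt, ...) {
    va_list ap;
    va_start(ap, fmt);
    va_list ap2;
    va_copy(ap2, ap);                              // vsnprintf consumes the list
    const int n = vsnprintf(nullptr, 0, fmt, ap);  // first pass: measure
    std::vector<char> buf(n + 1);
    vsnprintf(buf.data(), buf.size(), fmt, ap2);   // second pass: write
    va_end(ap2);
    va_end(ap);
    return std::string(buf.data(), n);
}

// With const char * map values, operator() above reduces to e.g.
//   format("%s.rope.scaling.type", "llama")  ->  "llama.rope.scaling.type"
// (hypothetical key pattern, shown only to illustrate the call shape)
```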

@@ -747,13 +747,13 @@ struct LLM_TN {
 // gguf helpers
 //
 
-static std::map<int8_t, std::string> LLAMA_ROPE_SCALING_TYPES = {
+static std::map<int32_t, const char *> LLAMA_ROPE_SCALING_TYPES = {
     { LLAMA_ROPE_SCALING_NONE,   "none"   },
     { LLAMA_ROPE_SCALING_LINEAR, "linear" },
     { LLAMA_ROPE_SCALING_YARN,   "yarn"   },
 };
 
-static int8_t llama_rope_scaling_type_from_string(const std::string & name) {
+static int32_t llama_rope_scaling_type_from_string(const std::string & name) {
     for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
         if (kv.second == name) {
             return kv.first;
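Call sites are unchanged apart from the wider return type; note that kv.second == name still works with const char * on the left because std::string supplies the mixed operator==. A small usage sketch (the fallback for unknown names sits below this hunk and is not shown):

```cpp
// Hypothetical call site:
int32_t t = llama_rope_scaling_type_from_string("yarn"); // LLAMA_ROPE_SCALING_YARN
```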
@@ -1415,6 +1415,7 @@ static const size_t GiB = 1024*MiB;
 
 struct llama_hparams {
     bool vocab_only;
+    bool rope_finetuned;
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
@@ -1434,8 +1435,7 @@ struct llama_hparams {
     float rope_freq_base_train;
     float rope_freq_scale_train;
     uint32_t n_yarn_orig_ctx;
-    int8_t rope_scaling_type_train : 3;
-    bool rope_finetuned : 1;
+    int32_t rope_scaling_type_train;
 
     float f_clamp_kqv;
     float f_max_alibi_bias;
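The old hparams packed the trained scaling type into a 3-bit signed bitfield beside a 1-bit bool. A field that narrow holds only -4..3, can't have its address taken, and silently mangles out-of-range values; a short sketch of the pitfall (illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

struct packed_hparams {
    int8_t rope_scaling_type_train : 3; // only -4..3 fit
    bool   rope_finetuned          : 1;
};

int main() {
    packed_hparams h = {};
    h.rope_scaling_type_train = 5;             // does not fit in 3 signed bits
    printf("%d\n", h.rope_scaling_type_train); // prints -3 on common compilers
    // &h.rope_scaling_type_train;             // would not compile: address of bit-field
    return 0;
}
```

The plain int32_t replacement trades a few bytes for a field that behaves like any other integer.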
@@ -2701,7 +2701,7 @@ struct llama_model_loader {
 // load LLaMA models
 //
 
-static std::string llama_model_arch_name(llm_arch arch) {
+static const char * llama_model_arch_name(llm_arch arch) {
     auto it = LLM_ARCH_NAMES.find(arch);
     if (it == LLM_ARCH_NAMES.end()) {
         return "unknown";
@@ -3310,11 +3310,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     const auto & hparams = model.hparams;
     const auto & vocab   = model.vocab;
 
-    const auto rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
+    const char * rope_scaling_type = LLAMA_ROPE_SCALING_TYPES.at(hparams.rope_scaling_type_train);
 
     // hparams
     LLAMA_LOG_INFO("%s: format           = %s\n",   __func__, llama_file_version_name(ml.fver));
-    LLAMA_LOG_INFO("%s: arch             = %s\n",   __func__, LLM_ARCH_NAMES.at(model.arch).c_str());
+    LLAMA_LOG_INFO("%s: arch             = %s\n",   __func__, LLM_ARCH_NAMES.at(model.arch));
     LLAMA_LOG_INFO("%s: vocab type       = %s\n",   __func__, llama_model_vocab_type_name(vocab.type));
     LLAMA_LOG_INFO("%s: n_vocab          = %u\n",   __func__, hparams.n_vocab);
     LLAMA_LOG_INFO("%s: n_merges         = %u\n",   __func__, (int) vocab.bpe_ranks.size());
@@ -3336,7 +3336,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: n_ff             = %u\n",   __func__, hparams.n_ff);
     LLAMA_LOG_INFO("%s: n_expert         = %u\n",   __func__, hparams.n_expert);
     LLAMA_LOG_INFO("%s: n_expert_used    = %u\n",   __func__, hparams.n_expert_used);
-    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",   __func__, rope_scaling_type.c_str());
+    LLAMA_LOG_INFO("%s: rope scaling     = %s\n",   __func__, rope_scaling_type);
     LLAMA_LOG_INFO("%s: freq_base_train  = %.1f\n", __func__, hparams.rope_freq_base_train);
     LLAMA_LOG_INFO("%s: freq_scale_train = %g\n",   __func__, hparams.rope_freq_scale_train);
     LLAMA_LOG_INFO("%s: n_yarn_orig_ctx  = %u\n",   __func__, hparams.n_yarn_orig_ctx);
@@ -10735,7 +10735,7 @@ int32_t llama_model_meta_val_str_by_index(const struct llama_model * model, int3
 
 int32_t llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
     return snprintf(buf, buf_size, "%s %s %s",
-                    llama_model_arch_name(model->arch).c_str(),
+                    llama_model_arch_name(model->arch),
                     llama_model_type_name(model->type),
                     llama_model_ftype_name(model->ftype).c_str());
 }
2 changes: 1 addition & 1 deletion llama.h
@@ -213,7 +213,7 @@ extern "C" {
         uint32_t n_batch;         // prompt processing maximum batch size
         uint32_t n_threads;       // number of threads to use for generation
         uint32_t n_threads_batch; // number of threads to use for batch processing
-        int8_t   rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+        int32_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency, 0 = from model
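On the public API side the field simply widens, so assigning an enum llama_rope_scaling_type value no longer narrows to int8_t. A hedged usage sketch, assuming the context-creation API of this era of llama.cpp:

```cpp
#include "llama.h"

int main() {
    llama_context_params cparams = llama_context_default_params();
    cparams.rope_scaling_type = LLAMA_ROPE_SCALING_YARN; // now a plain int32_t
    cparams.rope_freq_base    = 0.0f;                    // 0 = take the base from the model
    // ... pass cparams to llama_new_context_with_model(model, cparams)
    return 0;
}
```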