Skip to content

Commit 3b38d48

Browse files
authored
Per token attributes (#7685)
* Add per token attributes enum * Using phi-3 for testing 'rstrip' * Using jina-v2 for testing 'lstrip' * Brute force test for 'lstrip' and 'rstrip' * Implement 'rstrip' and 'lstrip' * Update phi-3 GGUF file (obsolete since 917dc8c) * Replace llama_token_type with llama_token_attribs
1 parent 6d16169 commit 3b38d48

File tree

4 files changed

+155
-62
lines changed

4 files changed

+155
-62
lines changed

llama.cpp

Lines changed: 105 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2149,12 +2149,12 @@ struct llama_control_vector {
21492149
struct llama_vocab {
21502150
using id = int32_t;
21512151
using token = std::string;
2152-
using ttype = llama_token_type;
2152+
using tattr = llama_token_attr;
21532153

21542154
struct token_data {
21552155
token text;
21562156
float score;
2157-
ttype type;
2157+
tattr attr;
21582158
};
21592159

21602160
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
@@ -4750,7 +4750,20 @@ static void llm_load_vocab(
47504750
auto & token_data = vocab.id_to_token[i];
47514751
token_data.text = std::move(word);
47524752
token_data.score = scores ? scores[i] : 0.0f;
4753-
token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
4753+
token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
4754+
4755+
if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
4756+
switch(toktypes[i]) {
4757+
case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break;
4758+
case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break;
4759+
case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break;
4760+
case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break;
4761+
case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
4762+
case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break;
4763+
case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
4764+
default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
4765+
}
4766+
}
47544767
}
47554768
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
47564769

@@ -4841,7 +4854,7 @@ static void llm_load_vocab(
48414854
// build special tokens cache
48424855
{
48434856
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
4844-
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
4857+
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
48454858
vocab.cache_special_tokens.push_back(id);
48464859
}
48474860
}
@@ -4871,6 +4884,59 @@ static void llm_load_vocab(
48714884

48724885
LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
48734886
}
4887+
4888+
// Handle per token attributes
4889+
//NOTE: Each model customizes per token attributes.
4890+
//NOTE: Per token attributes are missing from the GGUF file.
4891+
//TODO: Extract attributes from GGUF file.
4892+
{
4893+
auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
4894+
for (auto substr : substrs) {
4895+
if (str.find(substr) < std::string::npos) {
4896+
return true;
4897+
}
4898+
}
4899+
return false;
4900+
};
4901+
4902+
auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
4903+
uint32_t current = vocab.id_to_token.at(id).attr;
4904+
current = value ? (current | attr) : (current & ~attr);
4905+
vocab.id_to_token[id].attr = (llama_token_attr) current;
4906+
};
4907+
4908+
auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
4909+
_set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
4910+
};
4911+
4912+
std::string model_name;
4913+
std::string tokenizer_pre;
4914+
4915+
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
4916+
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
4917+
4918+
// model name to lowercase
4919+
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
4920+
[] (const std::string::value_type x) {
4921+
return std::tolower(x);
4922+
}
4923+
);
4924+
4925+
// set attributes by model/tokenizer name
4926+
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
4927+
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
4928+
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
4929+
for (auto id : vocab.cache_special_tokens) {
4930+
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
4931+
}
4932+
for (auto token : {"</s>"}) {
4933+
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
4934+
}
4935+
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
4936+
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
4937+
}
4938+
}
4939+
}
48744940
}
48754941

48764942
static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12620,27 +12686,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
1262012686

1262112687
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
1262212688
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
12623-
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
12689+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
1262412690
}
1262512691

1262612692
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
1262712693
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
12628-
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
12694+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
1262912695
}
1263012696

1263112697
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
1263212698
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
12633-
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
12699+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
1263412700
}
1263512701

1263612702
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
1263712703
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
12638-
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
12704+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
1263912705
}
1264012706

1264112707
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
1264212708
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
12643-
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
12709+
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
1264412710
}
1264512711

1264612712
static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@@ -13258,7 +13324,8 @@ struct fragment_buffer_variant {
1325813324
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
1325913325
// for each special token
1326013326
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
13261-
const auto & special_token = vocab.id_to_token[special_id].text;
13327+
const auto & data = vocab.id_to_token[special_id];
13328+
const auto & special_token = data.text;
1326213329

1326313330
// for each text fragment
1326413331
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@@ -13295,13 +13362,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1329513362
if (match > raw_text_base_offset) {
1329613363
// left
1329713364
const int64_t left_reminder_offset = raw_text_base_offset + 0;
13298-
const int64_t left_reminder_length = match - raw_text_base_offset;
13299-
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
13365+
int64_t left_reminder_length = match - raw_text_base_offset;
13366+
13367+
if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
13368+
while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
13369+
left_reminder_length--;
13370+
}
13371+
}
13372+
13373+
if (left_reminder_length > 0) {
13374+
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
13375+
it++;
13376+
}
1330013377

1330113378
#ifdef PRETOKENIZERDEBUG
1330213379
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
1330313380
#endif
13304-
it++;
1330513381
}
1330613382

1330713383
// special token
@@ -13310,16 +13386,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
1331013386

1331113387
// right
1331213388
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
13313-
const int64_t right_reminder_offset = match + special_token.length();
13314-
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
13315-
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
13389+
int64_t right_reminder_offset = match + special_token.length();
13390+
int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
13391+
13392+
if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
13393+
while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
13394+
right_reminder_offset++;
13395+
right_reminder_length--;
13396+
}
13397+
}
13398+
13399+
if (right_reminder_length > 0) {
13400+
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
13401+
it++;
13402+
}
1331613403

1331713404
#ifdef PRETOKENIZERDEBUG
1331813405
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
1331913406
#endif
1332013407

13321-
it++;
13322-
1332313408
if (source == 0) {
1332413409
buffer.erase_after(buffer.before_begin());
1332513410
} else {
@@ -13365,9 +13450,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
1336513450
// tokenizer.encode('', add_special_tokens=True) returns [1]
1336613451
// tokenizer.encode('', add_special_tokens=False) returns []
1336713452

13368-
static const bool rtrim = true; //TODO: as param
1336913453
bool is_prev_special = false;
13370-
bool special_token_rtrim = false;
1337113454

1337213455
if (add_special && vocab.special_add_bos != 0) {
1337313456
GGML_ASSERT(vocab.special_bos_id != -1);
@@ -13377,25 +13460,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
1337713460

1337813461
for (const auto & fragment : fragment_buffer) {
1337913462
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
13380-
// without adding this leading whitespace, we do not get the same results as the original tokenizer
13381-
13382-
// TODO: It's likely possible to get rid of this string copy entirely
13383-
// by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
13384-
// and passing 'add space prefix' as bool argument
13385-
//
1338613463
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
1338713464

13388-
if (special_token_rtrim) {
13389-
size_t num_whitespaces = 0;
13390-
while (isspace(raw_text[num_whitespaces])) {
13391-
num_whitespaces++;
13392-
}
13393-
if (num_whitespaces == raw_text.size()) {
13394-
continue; // skip if all whitespaces
13395-
}
13396-
raw_text = raw_text.substr(num_whitespaces);
13397-
}
13398-
1339913465
if (vocab.add_space_prefix) {
1340013466
if (!output.size() || is_prev_special) { // prefix with space if first token
1340113467
raw_text = " " + raw_text;
@@ -13411,11 +13477,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
1341113477
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
1341213478
output.push_back(fragment.token);
1341313479
is_prev_special = true;
13414-
// phi-3 special tokens without rtrim, works fine for llama-spm too
13415-
special_token_rtrim = rtrim
13416-
&& fragment.token != vocab.special_bos_id
13417-
&& fragment.token != vocab.special_unk_id
13418-
&& fragment.token != vocab.special_eos_id;
1341913480
}
1342013481
}
1342113482

@@ -18221,9 +18282,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token)
1822118282
return model->vocab.id_to_token[token].score;
1822218283
}
1822318284

18224-
llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
18285+
llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
1822518286
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
18226-
return model->vocab.id_to_token[token].type;
18287+
return model->vocab.id_to_token[token].attr;
1822718288
}
1822818289

1822918290
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {

llama.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ extern "C" {
9797
LLAMA_ROPE_TYPE_GLM = 4,
9898
};
9999

100-
enum llama_token_type {
100+
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
101101
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
102102
LLAMA_TOKEN_TYPE_NORMAL = 1,
103103
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
@@ -107,6 +107,20 @@ extern "C" {
107107
LLAMA_TOKEN_TYPE_BYTE = 6,
108108
};
109109

110+
enum llama_token_attr {
111+
LLAMA_TOKEN_ATTR_UNDEFINED = 0,
112+
LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 1,
113+
LLAMA_TOKEN_ATTR_UNUSED = 1 << 2,
114+
LLAMA_TOKEN_ATTR_NORMAL = 1 << 3,
115+
LLAMA_TOKEN_ATTR_CONTROL = 1 << 4, // SPECIAL?
116+
LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 5,
117+
LLAMA_TOKEN_ATTR_BYTE = 1 << 6,
118+
LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 7,
119+
LLAMA_TOKEN_ATTR_LSTRIP = 1 << 8,
120+
LLAMA_TOKEN_ATTR_RSTRIP = 1 << 9,
121+
LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 10,
122+
};
123+
110124
// model file types
111125
enum llama_ftype {
112126
LLAMA_FTYPE_ALL_F32 = 0,
@@ -821,7 +835,7 @@ extern "C" {
821835

822836
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
823837

824-
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
838+
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
825839

826840
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
827841
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);

models/ggml-vocab-phi-3.gguf

173 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)