@@ -4872,9 +4872,29 @@ static void llm_load_vocab(
4872
4872
//NOTE: Each model customizes per token attributes.
4873
4873
//NOTE: Per token attributes are missing from the GGUF file.
4874
4874
//TODO: Merge llama_token_type and llama_token_attrib.
4875
+ //TODO: Extract attribs from GGUF file.
4875
4876
{
4877
// Returns true when `str` contains at least one of the strings in `substrs`.
// An empty `substrs` list yields false.
auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
    // iterate by const reference: a by-value loop variable would copy every std::string
    for (const auto & substr : substrs) {
        if (str.find(substr) != std::string::npos) {
            return true;
        }
    }
    return false;
};
4885
+
4886
+ auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attrib attrib, bool value) {
4887
+ uint32_t attribs = vocab.id_to_token.at(id).attribs;
4888
+ attribs = value ? (attribs | attrib) : (attribs & ~attrib);
4889
+ vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
4890
+ };
4891
+
4892
+ auto _set_token_attrib = [&] (const std::string & token, llama_token_attrib attrib, bool value) {
4893
+ _set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value);
4894
+ };
4895
+
4876
4896
// convert token type as an attribute
4877
- for (auto data : vocab.id_to_token) {
4897
+ for (auto & data : vocab.id_to_token) {
4878
4898
uint32_t attrib = LLAMA_TOKEN_ATTRIB_UNDEFINED;
4879
4899
attrib |= LLAMA_TOKEN_ATTRIB_UNKNOWN * (data.type == LLAMA_TOKEN_TYPE_UNKNOWN);
4880
4900
attrib |= LLAMA_TOKEN_ATTRIB_UNUSED * (data.type == LLAMA_TOKEN_TYPE_UNUSED);
@@ -4885,44 +4905,31 @@ static void llm_load_vocab(
4885
4905
data.attribs = (llama_token_attrib) attrib;
4886
4906
}
4887
4907
4888
- // set attributes by model name
4889
4908
std::string model_name;
4890
- if (ml.get_key(LLM_KV_GENERAL_NAME, model_name, false)) {
4891
- std::transform(model_name.begin(), model_name.end(), model_name.begin(),
4892
- [] (const std::string::value_type x) {
4893
- return std::tolower(x);
4894
- }
4895
- );
4896
-
4897
- auto _contains_any = [&model_name] (const std::vector<std::string> &substrs) -> bool {
4898
- for (auto substr : substrs) {
4899
- if (model_name.find(substr) < std::string::npos) {
4900
- return true;
4901
- }
4902
- }
4903
- return false;
4904
- };
4909
+ std::string tokenizer_pre;
4905
4910
4906
- auto _set_tokenid_attrib = [&] (const llama_vocab::id id, llama_token_attrib attrib, bool value) {
4907
- uint32_t attribs = vocab.id_to_token[id].attribs;
4908
- attribs = value ? (attribs | attrib) : (attribs & ~attrib);
4909
- vocab.id_to_token[id].attribs = (llama_token_attrib) attribs;
4910
- };
4911
+ ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
4912
+ ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
4911
4913
4912
- auto _set_token_attrib = [&] (const std::string & token, llama_token_attrib attrib, bool value) {
4913
- _set_tokenid_attrib(vocab.token_to_id.at(token), attrib, value);
4914
- };
4914
+ // model name to lowercase
4915
+ std::transform(model_name.begin(), model_name.end(), model_name.begin(),
4916
+ [] (const std::string::value_type x) {
4917
+ return std::tolower(x);
4918
+ }
4919
+ );
4915
4920
4916
- if (_contains_any({"phi-3", "phi3"})) {
4917
- for (auto id : vocab.cache_special_tokens) {
4918
- _set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
4919
- }
4920
- for (auto token : {"</s>"}) {
4921
- _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
4922
- }
4923
- for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
4924
- _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false);
4925
- }
4921
+ // set attributes by model/tokenizer name
4922
+ if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
4923
+ _set_token_attrib("<mask>", LLAMA_TOKEN_ATTRIB_LSTRIP, true);
4924
+ } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
4925
+ for (auto id : vocab.cache_special_tokens) {
4926
+ _set_tokenid_attrib(id, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
4927
+ }
4928
+ for (auto token : {"</s>"}) {
4929
+ _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, true);
4930
+ }
4931
+ for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
4932
+ _set_token_attrib(token, LLAMA_TOKEN_ATTRIB_RSTRIP, false);
4926
4933
}
4927
4934
}
4928
4935
}
0 commit comments