68
68
#include <cstdio>
69
69
#include <cstring>
70
70
#include <ctime>
71
+ #include <cwctype>
71
72
#include <forward_list>
72
73
#include <fstream>
73
74
#include <functional>
74
75
#include <initializer_list>
76
+ #include <locale>
75
77
#include <map>
76
78
#include <memory>
77
79
#include <mutex>
@@ -8941,37 +8943,46 @@ struct llm_tokenizer_wpm {
8941
8943
}
8942
8944
8943
8945
std::vector<std::string> preprocess(const std::string & text) {
8944
- std::string ori_str = normalize(text);
8945
- uint64_t ori_size = ori_str.size();
8946
+ // normalalization form D
8947
+ std::vector<uint32_t> codepoints = codepoints_from_utf8(text);
8948
+ std::vector<uint32_t> nfd_codepoints;
8949
+ for (uint32_t code : codepoints) {
8950
+ auto it = nfd_map.find(code);
8951
+ if (it != nfd_map.end()) {
8952
+ for (uint32_t c : it->second) {
8953
+ nfd_codepoints.push_back(c);
8954
+ }
8955
+ } else {
8956
+ nfd_codepoints.push_back(code);
8957
+ }
8958
+ }
8946
8959
8947
- // single punct / single symbol / single digit
8948
- // baseline: add whitespace on the left and right of punct and chinese characters
8949
- std::vector<std::string> words;
8960
+ // strip accents, strip control, uniformize whitespace,
8961
+ // to lowercase, pad chinese characters, pad punctuation
8950
8962
std::string new_str = "";
8951
- uint64_t i = 0;
8952
- while (i < ori_size) {
8953
- int utf_char_len = utf8_len(ori_str[i]);
8954
- if ((utf_char_len == 1) && ispunct(ori_str[i])) {
8955
- new_str += " ";
8956
- new_str += ori_str[i];
8957
- new_str += " ";
8958
- i += 1;
8963
+ for (uint32_t code : nfd_codepoints) {
8964
+ int type = codepoint_type(code);
8965
+ if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
8966
+ continue;
8959
8967
}
8960
- else if ((utf_char_len == 3) && is_chinese_char(ori_str.substr(i, 3))) {
8968
+ code = to_lower(code);
8969
+ if (type == CODEPOINT_TYPE_WHITESPACE) {
8970
+ code = ' ';
8971
+ }
8972
+ std::string s = codepoint_to_utf8(code);
8973
+ if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
8961
8974
new_str += " ";
8962
- new_str += ori_str.substr(i, 3) ;
8975
+ new_str += s ;
8963
8976
new_str += " ";
8964
- i += 3;
8965
- }
8966
- else {
8967
- new_str += ori_str[i];
8968
- i += 1;
8977
+ } else {
8978
+ new_str += s;
8969
8979
}
8970
8980
}
8971
8981
8972
8982
// split by whitespace
8973
8983
uint64_t l = 0;
8974
8984
uint64_t r = 0;
8985
+ std::vector<std::string> words;
8975
8986
while (r < new_str.size()) {
8976
8987
// if is whitespace
8977
8988
if (isspace(new_str[r])) {
@@ -8989,47 +9000,20 @@ struct llm_tokenizer_wpm {
8989
9000
return words;
8990
9001
}
8991
9002
8992
- std::string normalize(const std::string & text) {
8993
- // TODO: handle chinese characters? https://github.com/huggingface/tokenizers/blob/ef5f50605ddf9f8caef1598c0e4853862b9707a7/tokenizers/src/normalizers/bert.rs#L98
8994
- std::string text2 = strip_accents(text);
8995
- for (size_t i = 0; i < text2.size(); i += utf8_len(text2[i])) {
8996
- char c = text2[i];
8997
- if (c >= 'A' && c <= 'Z') {
8998
- text2[i] = c - 'A' + 'a';
8999
- }
9003
+ uint32_t to_lower(uint32_t code) {
9004
+ #if defined(_WIN32)
9005
+ if (code > 0xFFFF) {
9006
+ return code;
9000
9007
}
9001
- return text2;
9008
+ #endif
9009
+ return std::tolower(wchar_t(code), std::locale("en_US.UTF-8"));
9002
9010
}
9003
9011
9004
- bool is_chinese_char(const std::string & str) {
9005
- int len = str.length();
9006
- unsigned int codepoint = 0;
9007
- int num_bytes = 0;
9008
- int i = 0;
9009
- unsigned char ch = static_cast<unsigned char>(str[i]);
9010
- if (ch <= 0x7f) {
9011
- codepoint = ch;
9012
- num_bytes = 1;
9013
- } else if ((ch >> 5) == 0x06) {
9014
- codepoint = ch & 0x1f;
9015
- num_bytes = 2;
9016
- } else if ((ch >> 4) == 0x0e) {
9017
- codepoint = ch & 0x0f;
9018
- num_bytes = 3;
9019
- } else if ((ch >> 3) == 0x1e) {
9020
- codepoint = ch & 0x07;
9021
- num_bytes = 4;
9022
- }
9023
- for (int j = 1; j < num_bytes; ++j) {
9024
- if (i + j >= len) {
9025
- return false; // incomplete UTF-8 character
9026
- }
9027
- unsigned char next_ch = static_cast<unsigned char>(str[i + j]);
9028
- if ((next_ch >> 6) != 0x02) {
9029
- return false; // invalid trailing byte
9030
- }
9031
- codepoint = (codepoint << 6) | (next_ch & 0x3f);
9032
- }
9012
+ bool is_ascii_punct(uint32_t code) {
9013
+ return code < 256 && ispunct(code);
9014
+ }
9015
+
9016
+ bool is_chinese_char(uint32_t codepoint) {
9033
9017
if ((codepoint >= 0x4E00 && codepoint <= 0x9FFF) ||
9034
9018
(codepoint >= 0x3400 && codepoint <= 0x4DBF) ||
9035
9019
(codepoint >= 0x20000 && codepoint <= 0x2A6DF) ||
@@ -9045,41 +9029,6 @@ struct llm_tokenizer_wpm {
9045
9029
return false;
9046
9030
}
9047
9031
9048
- std::string strip_accents(const std::string & input_string) {
9049
- std::string resultString;
9050
- std::map<std::string, char> accent_map = {
9051
- {"À", 'A'}, {"Á", 'A'}, {"Â", 'A'}, {"Ã", 'A'}, {"Ä", 'A'}, {"Å", 'A'},
9052
- {"à", 'a'}, {"á", 'a'}, {"â", 'a'}, {"ã", 'a'}, {"ä", 'a'}, {"å", 'a'},
9053
- {"È", 'E'}, {"É", 'E'}, {"Ê", 'E'}, {"Ë", 'E'}, {"è", 'e'}, {"é", 'e'},
9054
- {"ê", 'e'}, {"ë", 'e'}, {"Ì", 'I'}, {"Í", 'I'}, {"Î", 'I'}, {"Ï", 'I'},
9055
- {"ì", 'i'}, {"í", 'i'}, {"î", 'i'}, {"ï", 'i'}, {"Ò", 'O'}, {"Ó", 'O'},
9056
- {"Ô", 'O'}, {"Õ", 'O'}, {"Ö", 'O'}, {"ò", 'o'}, {"ó", 'o'}, {"ô", 'o'},
9057
- {"õ", 'o'}, {"ö", 'o'}, {"Ù", 'U'}, {"Ú", 'U'}, {"Û", 'U'}, {"Ü", 'U'},
9058
- {"ù", 'u'}, {"ú", 'u'}, {"û", 'u'}, {"ü", 'u'}, {"Ý", 'Y'}, {"ý", 'y'},
9059
- {"Ç", 'C'}, {"ç", 'c'}, {"Ñ", 'N'}, {"ñ", 'n'},
9060
- };
9061
-
9062
- for (size_t i = 0; i < input_string.length();) {
9063
- int len = utf8_len(input_string[i]);
9064
- std::string curChar = input_string.substr(i, len);
9065
- auto iter = accent_map.find(curChar);
9066
- if (iter != accent_map.end()) {
9067
- resultString += iter->second;
9068
- } else {
9069
- resultString += curChar;
9070
- }
9071
- i += len;
9072
- }
9073
-
9074
- return resultString;
9075
- }
9076
-
9077
- static size_t utf8_len(char src) {
9078
- const size_t lookup[] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4};
9079
- uint8_t highbits = static_cast<uint8_t>(src) >> 4;
9080
- return lookup[highbits];
9081
- }
9082
-
9083
9032
const llama_vocab & vocab;
9084
9033
};
9085
9034
0 commit comments