Commit e38592d
add unit tests
1 parent a4ad764

6 files changed: +284 −0 lines changed

CMakeLists.txt

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,8 @@ option(GGML_FMA "ggml: enable FMA"
 option(GGML_CUBLAS "ggml: use cuBLAS" OFF)
 option(GGML_METAL  "ggml: use Metal"  OFF)
 
+option(BERT_BUILD_TESTS "bert: Build tests" ON)
+
 #
 # Compile flags
 #
@@ -93,3 +95,8 @@ add_subdirectory(src)
 
 # for shared library
 set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
+
+if (BERT_BUILD_TESTS)
+    include(CTest)
+    add_subdirectory(tests)
+endif ()
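
With the option left ON, a typical out-of-tree build picks the tests up automatically. A minimal sketch (the build directory name is an assumption; ctest --test-dir needs CMake 3.20+):

    cmake -B build -DBERT_BUILD_TESTS=ON
    cmake --build build
    ctest --test-dir build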

tests/CMakeLists.txt

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+include_directories(${CMAKE_SOURCE_DIR}/src)
+
+# add_executable(test_tokenizer test_tokenizer.cpp)
+# target_link_libraries(test_tokenizer PRIVATE bert ggml)
+
+set(TEST_MODEL_NAME "bge-large-zh-v1.5")
+
+function(bert_build_executable source)
+    get_filename_component(TEST_TARGET ${source} NAME_WE)
+    add_executable(${TEST_TARGET} ${source})
+    install(TARGETS ${TEST_TARGET} RUNTIME)
+    target_link_libraries(${TEST_TARGET} PRIVATE bert ggml)
+endfunction()
+
+function(bert_test_executable name source)
+    get_filename_component(TEST_TARGET ${source} NAME_WE)
+    add_test(NAME "Generate_HF_tokens" COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_hf_tokenizer.py ${TEST_MODEL_NAME})
+    add_test(NAME ${name} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN})
+    set_property(TEST ${name} PROPERTY LABELS "main")
+endfunction()
+
+
+bert_build_executable(test_tokenizer.cpp)
+bert_test_executable(test_tokenizer test_tokenizer.cpp -m ${CMAKE_CURRENT_SOURCE_DIR}/../models/${TEST_MODEL_NAME}/bge-large-zh-v1.5-q4_1.gguf)
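
Note that Generate_HF_tokens runs test_hf_tokenizer.py at test time (so the transformers package must be importable by python3), while only the comparison test carries the "main" label. A plain run executes both in registration order; a sketch, assuming the build directory from above:

    # run everything, reference generation first
    ctest --test-dir build --output-on-failure
    # re-run only the labeled comparison once the reference ids exist
    ctest --test-dir build -L main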

tests/test.sh

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+
+set -e
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+MODEL_NAME=${1:-bge-large-zh-v1.5}
+MODEL_DIR=$(realpath "$SCRIPT_DIR/../models/$MODEL_NAME")
+
+if [ ! -d "$MODEL_DIR" ]; then
+    python3 "$SCRIPT_DIR/../models/download-repo.py" "$MODEL_NAME"
+fi
+
+if [ ! -f "$MODEL_DIR/ggml-model-q4_1.gguf" ]; then
+    "$SCRIPT_DIR/../models/run_conversions.sh" "$MODEL_NAME" q4_1
+fi
+
+python3 "$SCRIPT_DIR/test_hf_tokenizer.py" "$MODEL_DIR"
+
+"$SCRIPT_DIR/../build/bin/test_tokenizer" -m "$MODEL_DIR/ggml-model-q4_1.gguf"
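
The script is self-bootstrapping: it downloads the repo and converts the model if needed, regenerates the HF reference ids, then runs the comparison binary. Typical invocations from the repo root (model names follow the mapping in test_hf_tokenizer.py):

    ./tests/test.sh                     # defaults to bge-large-zh-v1.5
    ./tests/test.sh all-MiniLM-L6-v2    # any other supported model name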

tests/test_hf_tokenizer.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import argparse
+import os
+
+from transformers import AutoTokenizer
+
+SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
+
+def main(args):
+    # tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+    if "/" in args.model_name:
+        tokenizer_name = args.model_name
+    elif "MiniLM" in args.model_name:
+        tokenizer_name = f"sentence-transformers/{args.model_name}"
+    elif "bge-" in args.model_name:
+        tokenizer_name = f"BAAI/{args.model_name}"
+    else:
+        raise ValueError(f"Unknown model name: {args.model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+    with open(SCRIPT_PATH + "/test_prompts.txt", "r", encoding="utf-8") as f:
+        inps = f.readlines()
+    inps = [x.strip() for x in inps]
+
+    print("Using tokenizer:", tokenizer_name)
+    output = []
+    for inp in inps:
+        oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
+        output.append(",".join(str(x) for x in oup))
+        for token in oup:
+            print(f"{token} <--> {tokenizer.decode([token])}")
+
+    with open(SCRIPT_PATH + "/hf_tokenized_ids.txt", "w", encoding="utf-8") as f:
+        f.write("\n".join(output))
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Tokenize the test prompts with a HF tokenizer and dump the ids")
+    parser.add_argument("model_name", type=str, help="Model name or HF repo id")
+    args = parser.parse_args()
+    main(args)
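
Run standalone, the script resolves short names to HF repo ids ("bge-*" maps to BAAI/, "*MiniLM*" to sentence-transformers/) and writes the reference file next to itself:

    python3 tests/test_hf_tokenizer.py bge-large-zh-v1.5
    # => tests/hf_tokenized_ids.txt, one comma-separated id line per prompt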

tests/test_prompts.txt

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+hello world
+i'm going to the store to buy 3 apples and a banana! you're welcome to come along if you'd like. the time is 2:30 p.m. and it's partly cloudy outside. i'll be back soon, so don't go anywhere.
+"5 2 + 3 * 4 -"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == \'+\' ? a + b : operator == \'-\' ? a - b : operator == \'*\' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatepostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - \'0\'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatepostfix(input);
+你好,世界!
+こんにちは、世界!
+1231 2431431
+你好我是gpt
+然而,分音符号(diaeresis)和变音符号(umlaut)在一些情况下也可以被泛称为 "accent",这是因为它们都是附加在字母上的符号,用于改变字母的原始发音。
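
Each prompt line pairs positionally with the same line of hf_tokenized_ids.txt, where ids are stored comma-separated. For illustration, the ids for "1231 2431431" from the commented-out cases in test_tokenizer.cpp below would be stored as:

    101,13138,2487,22884,16932,21486,102

(those ids come from a MiniLM-style vocabulary; the bge-large-zh-v1.5 tokenizer yields different ids).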

tests/test_tokenizer.cpp

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+#include "bert.h"
+#include "ggml.h"
+
+#include <unistd.h>
+#include <map>
+#include <algorithm>
+#include <stdio.h>
+#include <string>
+#include <vector>
+#include <fstream>
+#include <sstream>
+
+#define ANSI_COLOR_RED   "\x1b[31m"
+#define ANSI_COLOR_RESET "\x1b[0m"
+#define ANSI_COLOR_GREEN "\x1b[32m"
+
+std::vector<std::string> txt2list(const std::string& filename) {
+    std::ifstream file(filename);
+    std::vector<std::string> all_lines;
+
+    if (!file.is_open()) {
+        printf("can not open file: %s\n", filename.c_str());
+        return all_lines;
+    }
+
+    std::string line;
+    while (std::getline(file, line)) {
+        all_lines.push_back(line);
+    }
+
+    file.close();
+    return all_lines;
+}
+
+std::vector<std::vector<int>> read_expected_tokenids(const std::string& filename) {
+    std::ifstream file(filename);
+    std::vector<std::vector<int>> all_numbers;
+
+    if (!file.is_open()) {
+        printf("can not open file: %s\n", filename.c_str());
+        return all_numbers;
+    }
+
+    std::string line;
+    while (std::getline(file, line)) {
+        std::vector<int> line_numbers;
+        std::istringstream iss(line);
+        std::string number_str;
+
+        while (std::getline(iss, number_str, ',')) {
+            line_numbers.push_back(std::stoi(number_str));
+        }
+
+        all_numbers.push_back(line_numbers);
+    }
+
+    file.close();
+    return all_numbers;
+}
+
+// returns true if our tokenization of `input` matches the HF reference ids
+bool tokenizer_test(bert_ctx * ctx, const std::string& input, const bert_tokens& expected) {
+    int N = bert_n_max_tokens(ctx);
+    bert_tokens result = bert_tokenize(ctx, input, N);
+
+    if (result != expected) {
+        printf("tokenizer test failed: '%s'\n", input.c_str());
+
+        printf("[");
+        for (auto& tok : result) {
+            printf("%d, ", tok);
+        }
+        printf("]\n");
+
+        for (size_t i = 0; i < result.size(); i++) {
+            bert_token a = expected[std::min(i, expected.size()-1)];
+            bert_token b = result[i];
+            const char *color_start = (a == b) ? ANSI_COLOR_GREEN : ANSI_COLOR_RED;
+            const char *color_end = ANSI_COLOR_RESET;
+
+            printf("%s%d -> %s : %d -> %s%s\n", color_start, a, bert_vocab_id_to_token(ctx, a), b, bert_vocab_id_to_token(ctx, b), color_end);
+        }
+        return false;
+    }
+
+    printf("Success '%.*s...'\n", 16, input.data());
+    return true;
+}
+
+struct bert_params
+{
+    int32_t n_threads = 6;
+    const char* model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";
+    const char* prompt = "test prompt";
+    int32_t batch_size = 32;
+    bool use_cpu = false;
+};
+
+void bert_print_usage(char **argv, const bert_params &params) {
+    fprintf(stderr, "usage: %s [options]\n", argv[0]);
+    fprintf(stderr, "\n");
+    fprintf(stderr, "options:\n");
+    fprintf(stderr, "  -h, --help            show this help message and exit\n");
+    fprintf(stderr, "  -m FNAME, --model FNAME\n");
+    fprintf(stderr, "                        model path (default: %s)\n", params.model);
+    fprintf(stderr, "  -c, --cpu             use CPU backend (default: use CUDA if available)\n");
+    fprintf(stderr, "\n");
+}
+
+bool bert_params_parse(int argc, char **argv, bert_params &params) {
+    for (int i = 1; i < argc; i++)
+    {
+        std::string arg = argv[i];
+
+        if (arg == "-m" || arg == "--model") {
+            params.model = argv[++i];
+        } else if (arg == "-c" || arg == "--cpu") {
+            params.use_cpu = true;
+        } else if (arg == "-h" || arg == "--help") {
+            bert_print_usage(argv, params);
+            exit(0);
+        } else {
+            fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
+            bert_print_usage(argv, params);
+            exit(1);
+        }
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    bert_params params;
+    params.model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";
+
+    if (bert_params_parse(argc, argv, params) == false) {
+        return 1;
+    }
+
+    bert_ctx * bctx;
+
+    // load the model
+    if ((bctx = bert_load_from_file(params.model, params.use_cpu)) == nullptr) {
+        fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model);
+        return 1;
+    }
+
+    // locate the repo root from the model path so the test data can be found
+    std::string dir = params.model;
+    std::size_t i = dir.rfind("/models/");
+    if (i != std::string::npos) {
+        dir.resize(i);
+    } else {
+        dir = ".";
+    }
+
+    auto expected = read_expected_tokenids(dir + "/tests/hf_tokenized_ids.txt");
+    auto prompts = txt2list(dir + "/tests/test_prompts.txt");
+
+    if (expected.size() == 0 || prompts.size() == 0) {
+        printf("failed to read test data\n");
+        return 1;
+    }
+
+    if (expected.size() != prompts.size()) {
+        printf("test data size mismatch\n");
+        return 1;
+    }
+
+    // tokenizer tests: compare each prompt against the HF reference ids
+    int n_failed = 0;
+    for (size_t i = 0; i < prompts.size(); i++) {
+        if (!tokenizer_test(bctx, prompts[i], expected[i])) {
+            n_failed++;
+        }
+    }
+
+    // tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
+    // tokenizer_test(bctx, "Québec", {101, 5447, 102});
+    // tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
+    // tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
+    // tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});
+
+    // tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
+    // tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
+    // tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
+
+    // propagate failures to CTest via the exit code
+    return n_failed == 0 ? 0 : 1;
+}
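
The binary can also be run by hand against any converted model; it locates tests/ relative to the -m path (falling back to the current directory), so run it from the repo root. A sketch, reusing the paths from test.sh:

    python3 tests/test_hf_tokenizer.py bge-large-zh-v1.5
    ./build/bin/test_tokenizer -m models/bge-large-zh-v1.5/ggml-model-q4_1.gguf

Matching prompts print a short Success line; mismatches dump both token streams with per-token green/red coloring and make the process exit non-zero, which is what CTest keys on.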
