diff --git a/examples/test_hf_tokenizer.py b/examples/test_hf_tokenizer.py
index 3192758..09081cc 100644
--- a/examples/test_hf_tokenizer.py
+++ b/examples/test_hf_tokenizer.py
@@ -1,34 +1,33 @@
 from transformers import AutoTokenizer, AutoModel
-import tiktoken
+import argparse
 
-tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
-# tokenizer_name = "jinaai/jina-embeddings-v2-base-en"
-# tokenizer_name = "mymusise/CPM-GPT2"
-# tokenizer_name = "gpt2"
-# tokenizer_name = "bert-base-chinese"
-# tokenizer_name = "BAAI/llm-embedder"
-tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+def main(args):
+    # tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
+    if "MiniLM" in args.model_name:
+        tokenizer_name = f"sentence-transformers/{args.model_name}"
+    elif "bge-" in args.model_name:
+        tokenizer_name = f"BAAI/{args.model_name}"
+    else:
+        raise ValueError(f"Unknown model name: {args.model_name}")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
 
-inps = [
-    "Hell12o,, world! complexness123",
-    ",大12家好gpt,我是GPT。你好。中國龍",
-    "你好,世界!",
-    "こんにちは、世界!",
-    "syömme \t täällä tänään",
-    "🙂🙂🙂😒😒😒🍍🍍🍑😗⚜️🕕⛄☃️",
-    "1231 2431431",
-    ]
+    with open("examples/test_prompts.txt", "r", encoding="utf-8") as f:
+        inps = f.readlines()
+    inps = list(map(lambda x: x.strip(), inps))
 
-print("Using tokenizer:", tokenizer_name)
-for inp in inps:
-    oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
-    print(f"{oup} is {tokenizer.decode(oup)}")
-    for token in oup:
-        print(f"{token} <--> {tokenizer.decode([token])}")
-    print("\n\n")
-    # print(f"{oup} is {tokenizer.decode(oup)}")
+    print("Using tokenizer:", tokenizer_name)
+    output = []
+    for inp in inps:
+        oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
+        output.append(",".join([str(x) for x in oup]))
+        for token in oup:
+            print(f"{token} <--> {tokenizer.decode([token])}")
 
-    # print(f"{inp} is tokenized as {oup}")
-    # for token in oup:
-    #     print(f"{token} is {tokenizer.decode([token])}")
+    with open("examples/hf_tokenized_ids.txt", "w", encoding="utf-8") as f:
+        f.write("\n".join(output))
 
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Dump HF tokenizer ids for the test prompts')
+    parser.add_argument('model_name', help='Model name, e.g. all-MiniLM-L6-v2 or bge-small-en-v1.5')
+    args = parser.parse_args()
+    main(args)
\ No newline at end of file
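
The script above reads examples/test_prompts.txt, tokenizes each line with the selected HuggingFace tokenizer, and writes one comma-separated list of token ids per prompt to examples/hf_tokenized_ids.txt. A typical invocation (assuming a model name that matches the "MiniLM" branch, such as all-MiniLM-L6-v2) would be:

    python examples/test_hf_tokenizer.py all-MiniLM-L6-v2
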
diff --git a/examples/test_prompts.txt b/examples/test_prompts.txt
new file mode 100644
index 0000000..0990ab4
--- /dev/null
+++ b/examples/test_prompts.txt
@@ -0,0 +1,7 @@
+Hell12o,, world! complexness123 A a
+,大12家好gpt,我是GPT。你好。中國龍
+你好,世界!
+こんにちは、世界!
+syömme \t täällä tänään
+1231 2431431
+Hello world
\ No newline at end of file
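
Each line of test_prompts.txt is one test case, and line i of examples/hf_tokenized_ids.txt is expected to hold the reference ids for prompt i. For example, the prompt "1231 2431431" should produce the line below (ids taken from the previously hard-coded test; 101 and 102 are BERT's [CLS] and [SEP] markers):

    101,13138,2487,22884,16932,21486,102
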
diff --git a/examples/test_tokenizer.cpp b/examples/test_tokenizer.cpp
index 21c0ea4..cdd7ed7 100644
--- a/examples/test_tokenizer.cpp
+++ b/examples/test_tokenizer.cpp
@@ -7,10 +7,60 @@
 #include <cstdio>
 #include <string>
 #include <vector>
+#include <fstream>
+#include <sstream>
 
 #define ANSI_COLOR_RED "\x1b[31m"
 #define ANSI_COLOR_RESET "\x1b[0m"
 #define ANSI_COLOR_GREEN "\x1b[32m"
+
+// read a text file into a vector of lines (one test prompt per line)
+std::vector<std::string> txt2list(const std::string& filename) {
+    std::ifstream file(filename);
+    std::vector<std::string> all_lines;
+
+    if (!file.is_open()) {
+        printf("can not open file: %s\n", filename.c_str());
+        return all_lines;
+    }
+
+    std::string line;
+    while (std::getline(file, line)) {
+        all_lines.push_back(line);
+    }
+
+    file.close();
+    return all_lines;
+}
+
+// parse one comma-separated list of expected token ids per line
+std::vector<std::vector<bert_vocab_id>> read_expected_tokenids(const std::string& filename) {
+    std::ifstream file(filename);
+    std::vector<std::vector<bert_vocab_id>> all_numbers;
+
+    if (!file.is_open()) {
+        printf("can not open file: %s\n", filename.c_str());
+        return all_numbers;
+    }
+
+    std::string line;
+    while (std::getline(file, line)) {
+        std::vector<bert_vocab_id> line_numbers;
+        std::istringstream iss(line);
+        std::string number_str;
+
+        while (std::getline(iss, number_str, ',')) {
+            line_numbers.push_back(static_cast<bert_vocab_id>(std::stoi(number_str)));
+        }
+
+        all_numbers.push_back(line_numbers);
+    }
+
+    file.close();
+    return all_numbers;
+}
+
 void tokenizer_test(bert_ctx * ctx, const std::string& input, const std::vector<bert_vocab_id>& expected) {
     int N = bert_n_max_tokens(ctx);
     std::vector<bert_vocab_id> result(N);
@@ -44,7 +92,7 @@ void tokenizer_test(bert_ctx * ctx, const std::string& input, const std::vector<bert_vocab_id>& expected) {
 int main(int argc, char ** argv) {
     bert_params params;
-    params.model = "models/bge-small-en-v1.5/ggml-model-q4_0.bin";
+    params.model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";
 
     if (bert_params_parse(argc, argv, params) == false) {
         return 1;
     }
@@ -61,16 +109,31 @@ int main(int argc, char ** argv) {
         }
     }
 
+    auto expected = read_expected_tokenids("examples/hf_tokenized_ids.txt");
+    auto prompts = txt2list("examples/test_prompts.txt");
+
+    if (expected.size() == 0 || prompts.size() == 0) {
+        printf("failed to read test data\n");
+        return 1;
+    }
+
+    if (expected.size() != prompts.size()) {
+        printf("test data size mismatch\n");
+        return 1;
+    }
+
     // tokenizer tests:
+    for (size_t i = 0; i < prompts.size(); i++) {
+        tokenizer_test(bctx, prompts[i], expected[i]);
+    }
 
-    tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
-    tokenizer_test(bctx, "Québec", {101, 5447, 102});
-    tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
-    tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
-    tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});
+    // tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
+    // tokenizer_test(bctx, "Québec", {101, 5447, 102});
+    // tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
+    // tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
+    // tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});
 
-    tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
-    tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
-    tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
+    // tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
+    // tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
+    // tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
 }
\ No newline at end of file
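
Since txt2list and read_expected_tokenids open "examples/test_prompts.txt" and "examples/hf_tokenized_ids.txt" through relative paths, the test binary has to be started from the repository root, for example:

    build/bin/test_tokenizer -m models/all-MiniLM-L6-v2/ggml-model-q4_0.bin
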
diff --git a/test_tokenizer.sh b/test_tokenizer.sh
new file mode 100644
index 0000000..b1051a5
--- /dev/null
+++ b/test_tokenizer.sh
@@ -0,0 +1,7 @@
+mkdir build
+cd build
+cmake .. -DBUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release
+make
+cd ..
+python examples/test_hf_tokenizer.py $1
+build/bin/test_tokenizer -m models/$1/ggml-model-q4_0.bin
\ No newline at end of file
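
A minimal end-to-end run, assuming the GGML model has already been converted and placed at models/all-MiniLM-L6-v2/ggml-model-q4_0.bin:

    sh test_tokenizer.sh all-MiniLM-L6-v2

The script rebuilds the project, regenerates examples/hf_tokenized_ids.txt with the HuggingFace tokenizer, and then runs the C++ tokenizer over the same prompts.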