update test tokenizer script
xyzhang626 committed Dec 21, 2023
1 parent 242a2f8 commit d1f919f
Showing 4 changed files with 114 additions and 37 deletions.
56 changes: 28 additions & 28 deletions examples/test_hf_tokenizer.py
@@ -1,34 +1,34 @@
from ast import arg
from transformers import AutoTokenizer, AutoModel
import tiktoken
import argparse

tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
# tokenizer_name = "jinaai/jina-embeddings-v2-base-en"
# tokenizer_name = "mymusise/CPM-GPT2"
# tokenizer_name = "gpt2"
# tokenizer_name = "bert-base-chinese"
# tokenizer_name = "BAAI/llm-embedder"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
def main(args):
    # tokenizer_name = "sentence-transformers/multi-qa-MiniLM-L6-cos-v1"
    if "MiniLM" in args.model_name:
        tokenizer_name = f"sentence-transformers/{args.model_name}"
    elif "bge-" in args.model_name:
        tokenizer_name = f"BAAI/{args.model_name}"
    else:
        raise ValueError(f"Unknown model name: {args.model_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

inps = [
    "Hell12o,, world! complexness123",
    ",大12家好gpt,我是GPT。你好。中國龍",
    "你好,世界!",
    "こんにちは、世界!",
    "syömme \t täällä tänään",
    "🙂🙂🙂😒😒😒🍍🍍🍑😗⚜️🕕⛄☃️",
    "1231 2431431",
]
    with open("examples/test_prompts.txt", "r", encoding="utf-8") as f:
        inps = f.readlines()
    inps = list(map(lambda x: x.strip(), inps))

print("Using tokenizer:", tokenizer_name)
for inp in inps:
    oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
    print(f"{oup} is {tokenizer.decode(oup)}")
    for token in oup:
        print(f"{token} <--> {tokenizer.decode([token])}")
    print("\n\n")
    # print(f"{oup} is {tokenizer.decode(oup)}")
    print("Using tokenizer:", tokenizer_name)
    output = []
    for inp in inps:
        oup = tokenizer(inp, return_tensors="pt").input_ids[0].tolist()
        output.append(",".join([str(x) for x in oup]))
        for token in oup:
            print(f"{token} <--> {tokenizer.decode([token])}")

        # print(f"{inp} is tokenized as {oup}")
        # for token in oup:
        #     print(f"{token} is {tokenizer.decode([token])}")
    with open("examples/hf_tokenized_ids.txt", "w", encoding="utf-8") as f:
        f.write("\n".join(output))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Download original repo files')
    parser.add_argument('model_name', help='Name of the repo')
    args = parser.parse_args()
    main(args)
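
For reference, the contract between the two files handled here is simple: examples/test_prompts.txt holds one prompt per line, and examples/hf_tokenized_ids.txt holds the matching HuggingFace token ids as one comma-separated list per line, in the same order. As a minimal sketch (not part of this commit), the generated files can be re-read and cross-checked from Python the same way examples/test_tokenizer.cpp does:

# Illustration only, not part of the commit: re-read the generated files and
# check that they line up, mirroring the checks in examples/test_tokenizer.cpp.
with open("examples/test_prompts.txt", "r", encoding="utf-8") as f:
    prompts = [line.strip() for line in f]
with open("examples/hf_tokenized_ids.txt", "r", encoding="utf-8") as f:
    expected = [[int(tok) for tok in line.strip().split(",")] for line in f]
assert len(prompts) == len(expected), "test data size mismatch"
for prompt, ids in zip(prompts, expected):
    print(prompt, "->", ids)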
7 changes: 7 additions & 0 deletions examples/test_prompts.txt
@@ -0,0 +1,7 @@
Hell12o,, world! complexness123 A a
,大12家好gpt,我是GPT。你好。中國龍
你好,世界!
こんにちは、世界!
syömme \t täällä tänään
1231 2431431
Hello world
81 changes: 72 additions & 9 deletions examples/test_tokenizer.cpp
@@ -7,10 +7,58 @@
#include <stdio.h>
#include <string>
#include <vector>
#include <fstream>
#include <sstream>
#define ANSI_COLOR_RED "\x1b[31m"
#define ANSI_COLOR_RESET "\x1b[0m"
#define ANSI_COLOR_GREEN "\x1b[32m"


std::vector<std::string> txt2list(const std::string& filename) {
    std::ifstream file(filename);
    std::vector<std::string> all_lines;

    if (!file.is_open()) {
        printf("can not open file: %s\n", filename.c_str());
        return all_lines;
    }

    std::string line;
    while (std::getline(file, line)) {
        all_lines.push_back(line);
    }

    file.close();
    return all_lines;
}

std::vector<std::vector<int>> read_expected_tokenids(const std::string& filename) {
    std::ifstream file(filename);
    std::vector<std::vector<int>> all_numbers;

    if (!file.is_open()) {
        printf("can not open file: %s\n", filename.c_str());
        return all_numbers;
    }


    std::string line;
    while (std::getline(file, line)) {
        std::vector<int> line_numbers;
        std::istringstream iss(line);
        std::string number_str;

        while (std::getline(iss, number_str, ',')) {
            line_numbers.push_back(static_cast<bert_vocab_id>(std::stoi(number_str)));
        }

        all_numbers.push_back(line_numbers);
    }

    file.close();
    return all_numbers;
}

void tokenizer_test(bert_ctx * ctx, const std::string& input, const std::vector<bert_vocab_id>& expected) {
    int N = bert_n_max_tokens(ctx);
    std::vector<bert_vocab_id> result(N);
@@ -44,7 +92,7 @@ void tokenizer_test(bert_ctx * ctx, const std::string& input, const std::vector<
int main(int argc, char ** argv) {

    bert_params params;
    params.model = "models/bge-small-en-v1.5/ggml-model-q4_0.bin";
    params.model = "models/all-MiniLM-L6-v2/ggml-model-q4_0.bin";

    if (bert_params_parse(argc, argv, params) == false) {
        return 1;
@@ -61,16 +109,31 @@ int main(int argc, char ** argv) {
        }
    }

    auto expected = read_expected_tokenids("examples/hf_tokenized_ids.txt");
    auto prompts = txt2list("examples/test_prompts.txt");

    if (expected.size() == 0 || prompts.size() == 0) {
        printf("failed to read test data\n");
        return 1;
    }

    if (expected.size() != prompts.size()) {
        printf("test data size mismatch\n");
        return 1;
    }

    // tokenizer tests:
    for (size_t i = 0; i < prompts.size(); i++) {
        tokenizer_test(bctx, prompts[i], expected[i]);
    }

tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
tokenizer_test(bctx, "Québec", {101, 5447, 102});
tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});
// tokenizer_test(bctx, "1231 2431431", {101, 13138, 2487, 22884, 16932, 21486, 102});
// tokenizer_test(bctx, "Québec", {101, 5447, 102});
// tokenizer_test(bctx, "syömme \t täällä tänään", {101, 25353, 5358, 4168, 11937, 25425, 9092, 14634, 102});
// tokenizer_test(bctx, "I'm going to the store to buy 3 apples and a banana! You're welcome to come along if you'd like. The time is 2:30 p.m. and it's partly cloudy outside. I'll be back soon, so don't go anywhere.", {101, 1045, 1005, 1049, 2183, 2000, 1996, 3573, 2000, 4965, 1017, 18108, 1998, 1037, 15212, 999, 2017, 1005, 2128, 6160, 2000, 2272, 2247, 2065, 2017, 1005, 1040, 2066, 1012, 1996, 2051, 2003, 1016, 1024, 2382, 1052, 1012, 1049, 1012, 1998, 2009, 1005, 1055, 6576, 24706, 2648, 1012, 1045, 1005, 2222, 2022, 2067, 2574, 1010, 2061, 2123, 1005, 1056, 2175, 5973, 1012, 102});
// tokenizer_test(bctx, "\"5 2 + 3 * 4 -\"; int stack[1000], top = -1; int calculate(int a, int b, char operator) { return operator == '+' ? a + b : operator == '-' ? a - b : operator == '*' ? a * b : a / b; } void push(int x) { stack[++top] = x; } int pop() { return stack[top--]; } int evaluatePostfix(char* expression) { for (int i = 0; expression[i]; i++) { if (isdigit(expression[i])) push(expression[i] - '0'); else { int a = pop(), b = pop(); push(calculate(b, a, expression[i])); } } return pop(); } int result = evaluatePostfix(input);", {101, 1000, 1019, 1016, 1009, 1017, 1008, 1018, 1011, 1000, 1025, 20014, 9991, 1031, 6694, 1033, 1010, 2327, 1027, 1011, 1015, 1025, 20014, 18422, 1006, 20014, 1037, 1010, 20014, 1038, 1010, 25869, 6872, 1007, 1063, 2709, 6872, 1027, 1027, 1005, 1009, 1005, 1029, 1037, 1009, 1038, 1024, 6872, 1027, 1027, 1005, 1011, 1005, 1029, 1037, 1011, 1038, 1024, 6872, 1027, 1027, 1005, 1008, 1005, 1029, 1037, 1008, 1038, 1024, 1037, 1013, 1038, 1025, 1065, 11675, 5245, 1006, 20014, 1060, 1007, 1063, 9991, 1031, 1009, 1009, 2327, 1033, 1027, 1060, 1025, 1065, 20014, 3769, 1006, 1007, 1063, 2709, 9991, 1031, 2327, 1011, 1011, 1033, 1025, 1065, 20014, 16157, 19894, 8873, 2595, 1006, 25869, 1008, 3670, 1007, 1063, 2005, 1006, 20014, 1045, 1027, 1014, 1025, 3670, 1031, 1045, 1033, 1025, 1045, 1009, 1009, 1007, 1063, 2065, 1006, 2003, 4305, 23806, 1006, 3670, 1031, 1045, 1033, 1007, 1007, 5245, 1006, 3670, 1031, 1045, 1033, 1011, 1005, 1014, 1005, 1007, 1025, 2842, 1063, 20014, 1037, 1027, 3769, 1006, 1007, 1010, 1038, 1027, 3769, 1006, 1007, 1025, 5245, 1006, 18422, 1006, 1038, 1010, 1037, 1010, 3670, 1031, 1045, 1033, 1007, 1007, 1025, 1065, 1065, 2709, 3769, 1006, 1007, 1025, 1065, 20014, 2765, 1027, 16157, 19894, 8873, 2595, 1006, 7953, 1007, 1025, 102});

tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
// tokenizer_test(bctx, "Hello world!", {101, 7592, 2088, 999, 102});
// tokenizer_test(bctx, "你好,世界!", {101, 100, 100, 1989, 1745, 100, 1986, 102});
// tokenizer_test(bctx, "こんにちは、世界!", {101, 1655, 30217, 30194, 30188, 30198, 1635, 1745, 100, 1986, 102});
}
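
With these files in place the C++ test is data-driven: it reads the prompts from examples/test_prompts.txt and the expected ids from examples/hf_tokenized_ids.txt, exits with an error if either file is missing or their line counts differ, and then calls tokenizer_test once per prompt; the previously hard-coded cases remain only as comments. This means examples/test_hf_tokenizer.py must be run for the same model before the binary, which is what test_tokenizer.sh below automates.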
7 changes: 7 additions & 0 deletions test_tokenizer.sh
@@ -0,0 +1,7 @@
mkdir build
cd build
cmake .. -DBUILD_SHARED_LIBS=OFF -DCMAKE_BUILD_TYPE=Release
make
cd ..
python examples/test_hf_tokenizer.py $1
build/bin/test_tokenizer -m models/$1/ggml-model-q4_0.bin
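
Assuming the converted model already exists under models/<model_name>/ (the default in test_tokenizer.cpp is models/all-MiniLM-L6-v2/ggml-model-q4_0.bin), the whole round trip can be run as, for example, bash test_tokenizer.sh all-MiniLM-L6-v2: the script builds the test binary, regenerates examples/hf_tokenized_ids.txt with the HuggingFace tokenizer for that model, and then runs build/bin/test_tokenizer against the same prompts.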
