//
// BertTokenizer.swift
// CoreMLBert
//
// Created by Julien Chaumond on 27/06/2019.
// Copyright © 2019 Hugging Face. All rights reserved.
//

import Foundation

enum TokenizerError: Error {
    case tooLong(String)
}
class BertTokenizer {
    private let basicTokenizer = BasicTokenizer()
    private let wordpieceTokenizer: WordpieceTokenizer
    private let maxLen = 512

    private let vocab: [String: Int]
    private let ids_to_tokens: [Int: String]

    init() {
        let url = Bundle.main.url(forResource: "vocab", withExtension: "txt")!
        let vocabTxt = try! String(contentsOf: url)
        let tokens = vocabTxt.split(separator: "\n").map { String($0) }
        var vocab: [String: Int] = [:]
        var ids_to_tokens: [Int: String] = [:]
        for (i, token) in tokens.enumerated() {
            vocab[token] = i
            ids_to_tokens[i] = token
        }
        self.vocab = vocab
        self.ids_to_tokens = ids_to_tokens
        self.wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab)
    }

    func tokenize(text: String) -> [String] {
        var tokens: [String] = []
        for token in basicTokenizer.tokenize(text: text) {
            for subToken in wordpieceTokenizer.tokenize(word: token) {
                tokens.append(subToken)
            }
        }
        return tokens
    }
    private func convertTokensToIds(tokens: [String]) throws -> [Int] {
        if tokens.count > maxLen {
            throw TokenizerError.tooLong(
                """
                Token indices sequence length is longer than the specified maximum
                sequence length for this BERT model (\(tokens.count) > \(maxLen)). Running this
                sequence through BERT will result in indexing errors.
                """
            )
        }
        return tokens.map { vocab[$0]! }
    }
    /// Main entry point
    func tokenizeToIds(text: String) -> [Int] {
        return try! convertTokensToIds(tokens: tokenize(text: text))
    }

    func tokenToId(token: String) -> Int {
        return vocab[token]!
    }

    /// Un-tokenization: get tokens from tokenIds
    func unTokenize(tokens: [Int]) -> [String] {
        return tokens.map { ids_to_tokens[$0]! }
    }
    /// Un-tokenization: merge wordpiece sub-tokens ("##...") back into whitespace-separated words.
    func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String {
        var tokenList: [String] = []
        var individualToken: String = ""
        for token in wordpieceTokenList {
            if token.starts(with: "##") {
                individualToken += String(token.suffix(token.count - 2))
            } else {
                if individualToken.count > 0 {
                    tokenList.append(individualToken)
                }
                individualToken = token
            }
        }
        tokenList.append(individualToken)
        return tokenList.joined(separator: " ")
    }
}
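
// A minimal usage sketch, assuming a BERT-style `vocab.txt` is present in the main
// bundle (as the initializer above expects) and that every sub-token of the input
// is covered by that vocab. The function name is hypothetical and illustrative only.
func exampleBertTokenizerUsage() {
    let tokenizer = BertTokenizer()
    let text = "Hello, world!"
    let tokens = tokenizer.tokenize(text: text)    // wordpiece strings, e.g. ["hello", ",", "world", "!"]
    let ids = tokenizer.tokenizeToIds(text: text)  // the same tokens mapped to their vocab indices
    // Round trip: ids -> tokens -> whitespace-joined string ("hello , world !").
    let decoded = tokenizer.convertWordpieceToBasicTokenList(tokenizer.unTokenize(tokens: ids))
    print(tokens, ids, decoded)
}
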
class BasicTokenizer {
    let neverSplit = [
        "[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"
    ]

    func tokenize(text: String) -> [String] {
        let splitTokens = text.folding(options: .diacriticInsensitive, locale: nil)
            .components(separatedBy: NSCharacterSet.whitespaces)
        let tokens = splitTokens.flatMap({ (token: String) -> [String] in
            if neverSplit.contains(token) {
                return [token]
            }
            var toks: [String] = []
            var currentTok = ""
            for c in token.lowercased() {
                if c.isLetter || c.isNumber || c == "°" {
                    currentTok += String(c)
                } else if currentTok.count > 0 {
                    toks.append(currentTok)
                    toks.append(String(c))
                    currentTok = ""
                } else {
                    toks.append(String(c))
                }
            }
            if currentTok.count > 0 {
                toks.append(currentTok)
            }
            return toks
        })
        return tokens
    }
}
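
// A minimal sketch of what BasicTokenizer produces: it lowercases, strips diacritics,
// splits punctuation into standalone tokens, and passes never-split special tokens
// through unchanged. The function name is hypothetical and illustrative only.
func exampleBasicTokenizerUsage() {
    let basic = BasicTokenizer()
    print(basic.tokenize(text: "Héllo, World!"))   // ["hello", ",", "world", "!"]
    print(basic.tokenize(text: "[CLS] hi [SEP]"))  // ["[CLS]", "hi", "[SEP]"]
}
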
class WordpieceTokenizer {
    private let unkToken = "[UNK]"
    private let maxInputCharsPerWord = 100
    private let vocab: [String: Int]

    init(vocab: [String: Int]) {
        self.vocab = vocab
    }

    /// `word`: A single token.
    /// Warning: this differs from the `pytorch-transformers` implementation.
    /// This should have already been passed through `BasicTokenizer`.
    func tokenize(word: String) -> [String] {
        if word.count > maxInputCharsPerWord {
            return [unkToken]
        }
        var outputTokens: [String] = []
        var isBad = false
        var start = 0
        var subTokens: [String] = []
        while start < word.count {
            var end = word.count
            var cur_substr: String? = nil
            while start < end {
                var substr = Utils.substr(word, start..<end)!
                if start > 0 {
                    substr = "##\(substr)"
                }
                if vocab[substr] != nil {
                    cur_substr = substr
                    break
                }
                end -= 1
            }
            if cur_substr == nil {
                isBad = true
                break
            }
            subTokens.append(cur_substr!)
            start = end
        }
        if isBad {
            outputTokens.append(unkToken)
        } else {
            outputTokens.append(contentsOf: subTokens)
        }
        return outputTokens
    }
}
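
// A minimal sketch of the greedy longest-match-first wordpiece split, using a tiny
// hypothetical vocabulary so the expected output is deterministic; a real run would
// use the vocab loaded by BertTokenizer. The function name is illustrative only.
func exampleWordpieceTokenizerUsage() {
    let toyVocab: [String: Int] = ["un": 0, "##aff": 1, "##able": 2, "[UNK]": 3]
    let wordpiece = WordpieceTokenizer(vocab: toyVocab)
    print(wordpiece.tokenize(word: "unaffable"))   // ["un", "##aff", "##able"]
    print(wordpiece.tokenize(word: "xyz"))         // ["[UNK]"], since no prefix of "xyz" is in the vocab
}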