text_utils.py
import re
import ftfy
import json
import spacy
from tqdm import tqdm
from typing import List, Dict
from spacy.tokens import Doc


def get_pairs(word):
    """
    Return the set of symbol pairs in a word.
    A word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
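
# Illustrative example (not part of the original module): for the symbol tuple
# ('l', 'o', 'w', 'e', 'r</w>'), get_pairs returns
# {('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r</w>')}.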

def text_standardize(text):
    """
    Fixes some issues the spaCy tokenizer had on the books corpus,
    and also does some whitespace standardization.
    """
    text = text.replace('—', '-')
    text = text.replace('–', '-')
    text = text.replace('―', '-')
    text = text.replace('…', '...')
    text = text.replace('´', "'")
    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
    text = re.sub(r'\s*\n\s*', ' \n ', text)
    text = re.sub(r'[^\S\n]+', ' ', text)
    return text.strip()
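
# Illustrative example (not part of the original module):
# text_standardize('hello—world…') -> 'hello - world...'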

class TextEncoder(object):
    """
    Mostly a wrapper for a public python BPE tokenizer.
    """

    def __init__(self, encoder_path, bpe_path):
        self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
        with open(encoder_path) as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}
        # the merges file is assumed to have a header line and a trailing empty
        # line, which the [1:-1] slice drops
        with open(bpe_path, encoding='utf-8') as f:
            merges = f.read().split('\n')[1:-1]
        merges = [tuple(merge.split()) for merge in merges]
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {}

    def bpe(self, token):
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        if token in self.cache:
            return self.cache[token]
        pairs = get_pairs(word)
        if not pairs:
            return token + '</w>'
        while True:
            # pick the pair with the lowest merge rank; unranked pairs compare as infinity
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        if word == '\n  </w>':
            word = '\n</w>'
        self.cache[token] = word
        return word
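
    # Illustrative sketch (not from the original file): for token 'low' the
    # initial symbol tuple is ('l', 'o', 'w</w>'). If ('l', 'o') is the
    # best-ranked pair it is merged, giving ('lo', 'w</w>'); if ('lo', 'w</w>')
    # is also ranked the loop merges again and bpe('low') returns 'low</w>',
    # otherwise it returns the space-joined subwords 'lo w</w>'.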

    def encode(self, texts, verbose=True, use_tokenizer=False, special_tokens=None):
        texts_tokens = []
        iterator = tqdm(texts, ncols=80, leave=False) if verbose else texts
        for text in iterator:
            if use_tokenizer:
                # run the spaCy pipeline on the raw string
                text = self.nlp(text_standardize(ftfy.fix_text(text)))
            else:
                # text is assumed to be pre-tokenized; wrap the words in a spaCy Doc
                words = []
                for token in text:
                    if special_tokens is not None and token.lower() in special_tokens:
                        words.append(token)
                    else:
                        words.append(text_standardize(ftfy.fix_text(token)))
                text = Doc(self.nlp.vocab, words=words)
            text_tokens = []
            for token in text:
                if special_tokens is not None and token.text.lower() in special_tokens:
                    # special tokens map directly to a single ID
                    text_tokens.append(self.encoder.get(token.text.lower(), 0))
                else:
                    text_tokens.extend(
                        self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' '))
            texts_tokens.append(text_tokens)
        return texts_tokens
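
# Example usage (illustrative; the file paths are placeholders, not shipped
# with this module):
#   text_encoder = TextEncoder('encoder_bpe_40000.json', 'vocab_40000.bpe')
#   ids = text_encoder.encode(['Hello world!'], verbose=False, use_tokenizer=True)
#   # ids is a list with one inner list of BPE token IDs, one entry per input text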

class Dictionary:
    """
    This class holds a dictionary that maps strings to IDs, used to generate one-hot encodings of strings.
    """

    def __init__(self, add_unk=True):
        # init dictionaries; items are stored UTF-8 encoded, so the maps hold bytes
        self.item2idx: Dict[bytes, int] = {}
        self.idx2item: List[bytes] = []
        # in order to deal with unknown tokens, add <unk>
        if add_unk:
            self.add_item('<unk>')

    def add_item(self, item: str) -> int:
        """
        Add a string to the dictionary. If it is already present, return its ID;
        otherwise assign it a new ID.
        :param item: a string for which to assign an ID
        :return: ID of the string
        """
        item = item.encode('utf-8')
        if item not in self.item2idx:
            self.idx2item.append(item)
            self.item2idx[item] = len(self.idx2item) - 1
        return self.item2idx[item]

    def get_idx_for_item(self, item: str) -> int:
        """
        Return the ID of the given string, or 0 (the <unk> index) if it is not in the dictionary.
        :param item: string for which the ID is requested
        :return: ID of the string, otherwise 0
        """
        item = item.encode('utf-8')
        if item in self.item2idx:
            return self.item2idx[item]
        else:
            return 0

    def get_items(self) -> List[str]:
        items = []
        for item in self.idx2item:
            items.append(item.decode('UTF-8'))
        return items

    def __len__(self) -> int:
        return len(self.idx2item)

    def get_item_for_index(self, idx):
        return self.idx2item[idx].decode('UTF-8')

    def save(self, savefile):
        import pickle
        with open(savefile, 'wb') as f:
            mappings = {
                'idx2item': self.idx2item,
                'item2idx': self.item2idx
            }
            pickle.dump(mappings, f)

    @classmethod
    def load_from_file(cls, filename: str):
        import pickle
        dictionary: Dictionary = Dictionary()
        with open(filename, 'rb') as f:
            mappings = pickle.load(f, encoding='latin1')
            idx2item = mappings['idx2item']
            item2idx = mappings['item2idx']
            dictionary.item2idx = item2idx
            dictionary.idx2item = idx2item
        return dictionary

    @classmethod
    def load(cls, name: str):
        return Dictionary.load_from_file(name)

class LabelEncoder(Dictionary):
    pass
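
# Minimal usage sketch (illustrative, not part of the original module):
#   tag_dict = Dictionary(add_unk=True)
#   tag_dict.add_item('PER')          # -> 1 ('<unk>' already occupies index 0)
#   tag_dict.get_idx_for_item('PER')  # -> 1
#   tag_dict.get_idx_for_item('LOC')  # -> 0, the <unk> index
#   len(tag_dict)                     # -> 2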