-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgenerate_c2w_data.py
74 lines (61 loc) · 2.68 KB
/
generate_c2w_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os, sys
import collections
import numpy as np
import re, string
# Width, in characters, of the raw window that generate_pair() cuts out of the
# corpus (including one padding space on each side).
MAX_LINE_SIZE = 80
# Not referenced by the code visible in this file; kept so that anything
# importing it keeps working.
MAX_WORDS_IN_LINE = 20

# Load the whole corpus as one long line (newlines become spaces) so that
# fixed-width slices can be taken at arbitrary offsets.
with open('pride-and-prejudice.txt') as f:
    all_chars = f.read().replace('\n', ' ')

# A "word" is any run of two or more lowercase ASCII letters.
all_words = re.findall('[a-z]{2,}', all_chars.lower())

# Unique words, ordered from most to least frequent (ties keep first-seen
# order). The previous list(set(all_words)) had a different order on every run
# because of string hash randomization, so the vocabulary -- words[:VOCAB_SIZE]
# -- was arbitrary and could not be reproduced between runs.
words = [word for word, _ in collections.Counter(all_words).most_common()]
def _replace_word(text, word, replacement):
    """Replace every standalone occurrence of *word* in *text*, ignoring case.

    An occurrence is standalone when it is not directly preceded or followed
    by a lowercase ASCII letter (after lowercasing the text). Lookarounds are
    used so the neighbouring characters are not consumed by a match; that way
    back-to-back occurrences separated by a single character ("had had") are
    all found.
    """
    pattern = '(?<![a-z])' + re.escape(word) + '(?![a-z])'
    # Spans are located on the lowercased text and applied to the original;
    # this assumes lower() keeps the string length unchanged (true for ASCII).
    spans = [m.span() for m in re.finditer(pattern, text.lower())]
    # Splice from the right so earlier spans stay valid as the length changes.
    for begin, end in reversed(spans):
        text = text[:begin] + replacement + text[end:]
    return text


def generate_pair():
    """Build one (query, words) sample from a random window of the corpus.

    Reads the module globals ``all_chars``, ``words``, ``MAX_LINE_SIZE`` and
    ``VOCAB_SIZE``. In the code visible in this file ``VOCAB_SIZE`` is only
    assigned by ``generate_data``, so that must have run first.

    Returns:
        A tuple ``(query, list_of_words)``: the stripped character sequence
        and the list of words (runs of 2+ lowercase letters) found in it.
    """
    # Grab a slice of the input text, padded with one space on each side so
    # the final line occupies MAX_LINE_SIZE characters before stripping.
    index = np.random.randint(0, len(all_chars) - MAX_LINE_SIZE)
    cquery = ' ' + all_chars[index:index + MAX_LINE_SIZE - 2] + ' '

    # Build the vocabulary lookup once per call instead of re-slicing and
    # linearly scanning the list for every word.
    known = set(words[:VOCAB_SIZE])
    # Replacements come from the first half of the vocabulary. Integer
    # division keeps the bound an int; the clamp keeps the index valid.
    pool_size = max(1, min(VOCAB_SIZE // 2, len(words)))

    # Replace unknown words with known ones; ALL occurrences of an unknown
    # word get the same replacement word.
    wquery = set(re.findall('[a-z]{2,}', cquery.lower()))
    for w in wquery:
        if w not in known:
            other = words[np.random.randint(0, pool_size)]
            cquery = _replace_word(cquery, w, other)

    # Replacements may have lengthened the line: cut at the last space inside
    # the window and pad back to MAX_LINE_SIZE with spaces.
    if len(cquery) >= MAX_LINE_SIZE:
        last_sp = cquery[:MAX_LINE_SIZE].rfind(' ')
        cquery = cquery[:last_sp] + ' ' * (MAX_LINE_SIZE - last_sp)

    # OK, now that we have the sequence of chars, find its sequence of words
    # [TODO] Remember to remove stop words
    list_of_words = re.findall('[a-z]{2,}', cquery.lower())
    return cquery.strip(), list_of_words
def _write_pairs(count, x_path, y_path):
    """Write *count* generated samples to two parallel files.

    Each sample contributes one line to *x_path* (the query text) and one line
    to *y_path* (its words, comma-separated). Both files are overwritten.
    """
    with open(x_path, 'w') as fx, open(y_path, 'w') as fy:
        for _ in range(count):
            query, ans = generate_pair()
            fx.write(query + '\n')
            fy.write(','.join(ans) + '\n')


def generate_data(ntrain, nval, vocab_size, data_folder, train_x, train_y, val_x, val_y):
    """Generate the training and validation sets and write them to disk.

    Args:
        ntrain: number of samples in the training set.
        nval: number of samples in the validation set.
        vocab_size: vocabulary size; stored in the module global
            ``VOCAB_SIZE``, which ``generate_pair`` reads.
        data_folder: directory to create if it does not exist yet.
        train_x, train_y: output paths for training queries / word lists.
        val_x, val_y: output paths for validation queries / word lists.
    """
    global VOCAB_SIZE
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(data_folder, exist_ok=True)
    VOCAB_SIZE = vocab_size

    _write_pairs(ntrain, train_x, train_y)
    _write_pairs(nval, val_x, val_y)
def main(argv=None):
    """Command-line entry point.

    Expected arguments (after the program name):
        [1]: number of samples in training set
        [2]: number of samples in validation set
        [3]: vocabulary size

    The output folder is ``c2w_data_<vocabulary size>``, using the vocabulary
    size argument exactly as typed.

    Args:
        argv: full argument vector including the program name; defaults to
            ``sys.argv``.

    Exits with a usage message (non-zero status) when arguments are missing
    or are not integers, instead of failing with a bare IndexError/ValueError.
    """
    if argv is None:
        argv = sys.argv

    prog = argv[0] if argv else 'generate_c2w_data.py'
    usage = 'usage: ' + prog + ' NTRAIN NVAL VOCAB_SIZE'
    if len(argv) < 4:
        sys.exit(usage)
    try:
        ntrain, nval, vocab_size = (int(arg) for arg in argv[1:4])
    except ValueError:
        sys.exit(usage + ' (all three arguments must be integers)')

    data_folder = 'c2w_data' + "_" + argv[3]
    train_x = os.path.join(data_folder, 'train_x.txt')
    train_y = os.path.join(data_folder, 'train_y.txt')
    val_x = os.path.join(data_folder, 'val_x.txt')
    val_y = os.path.join(data_folder, 'val_y.txt')

    generate_data(ntrain, nval, vocab_size, data_folder, train_x, train_y, val_x, val_y)
# Run the generator only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()