-
Notifications
You must be signed in to change notification settings - Fork 100
/
Copy pathretrofit.py
92 lines (80 loc) · 3.15 KB
/
retrofit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import gzip
import math
import numpy
import re
import sys
from copy import deepcopy
# Any token with a digit in it counts as a "number" (search, not match).
isNumber = re.compile(r'\d+.*')


def norm_word(word):
    """Map a raw token to the key form used for lexicon lookups.

    * a token containing a digit becomes '---num---';
    * a token with no word characters left after stripping non-word
      characters (including the empty string) becomes '---punc---';
    * anything else is returned lower-cased.
    """
    lowered = word.lower()
    if isNumber.search(lowered):
        return '---num---'
    # Strip every non-word character; if nothing survives it was all punctuation.
    if not re.sub(r'\W+', '', word):
        return '---punc---'
    return lowered
def read_word_vecs(filename):
    """Read all the word vectors from a text file and normalize them.

    Each non-blank line is ``word v1 v2 ... vd`` separated by whitespace.
    The whole line is lower-cased before parsing, so the returned keys are
    lower-case. A filename ending in ``.gz`` is read through gzip. Both
    kinds of file are decoded as UTF-8 text.

    Args:
        filename: path to the (optionally gzip-compressed) vector file.

    Returns:
        dict mapping each word (str) to a 1-D float numpy array divided by
        ``sqrt(sum of squares + 1e-6)``, i.e. approximately unit length.
        If a word occurs on several lines, the last occurrence wins.
    """
    wordVectors = {}
    # Text mode in both branches: gzip.open defaults to binary, which would
    # produce ``bytes`` keys that never match the str keys of the lexicon.
    if filename.endswith('.gz'):
        fileObject = gzip.open(filename, 'rt', encoding='utf-8')
    else:
        fileObject = open(filename, 'r', encoding='utf-8')
    with fileObject:
        for line in fileObject:
            parts = line.strip().lower().split()
            if not parts:
                # tolerate blank lines (e.g. an extra newline at EOF)
                continue
            word = parts[0]
            vec = numpy.array([float(v) for v in parts[1:]], dtype=float)
            # normalize weight vector; the 1e-6 keeps an all-zero vector
            # from dividing by zero
            vec /= math.sqrt((vec ** 2).sum() + 1e-6)
            wordVectors[word] = vec
    sys.stderr.write("Vectors read from: "+filename+" \n")
    return wordVectors
def print_word_vecs(wordVectors, outFileName):
    """Write word vectors to a whitespace-separated UTF-8 text file.

    Each output line is the word followed by every vector component
    formatted as ``'%.4f'``; every field (including the last) is followed by
    a single space, then a newline.

    Args:
        wordVectors: dict mapping word (str) to an iterable of numbers.
        outFileName: destination path; an existing file is overwritten.
    """
    sys.stderr.write('\nWriting down the vectors in '+outFileName+'\n')
    # dict.iteritems() does not exist on Python 3; items() works everywhere.
    # The with-block closes the file even if a write raises.
    with open(outFileName, 'w', encoding='utf-8') as outFile:
        for word, values in wordVectors.items():
            outFile.write(word + ' ' + ''.join('%.4f ' % val for val in values) + '\n')
def read_lexicon(filename):
    """Read the PPDB word relations as a dictionary.

    Each non-blank line is ``word neighbour1 neighbour2 ...``. The line is
    lower-cased and every token is passed through ``norm_word``, so tokens
    containing digits and pure-punctuation tokens collapse to placeholder
    keys.

    Args:
        filename: path to the UTF-8 lexicon file.

    Returns:
        dict mapping the normalized head word to the list of its normalized
        neighbours. If a head word appears on more than one line, the last
        line wins.
    """
    lexicon = {}
    with open(filename, 'r', encoding='utf-8') as lexFile:
        for line in lexFile:
            words = line.lower().strip().split()
            if not words:
                # skip blank lines instead of failing on words[0]
                continue
            lexicon[norm_word(words[0])] = [norm_word(word) for word in words[1:]]
    return lexicon
def retrofit(wordVecs, lexicon, numIters):
    """Retrofit word vectors to a lexicon.

    For every word that is in both ``wordVecs`` and ``lexicon`` and has at
    least one lexicon neighbour that also has a vector, each iteration sets

        new[w] = (n * wordVecs[w] + sum(new[u] for u in N(w))) / (2 * n)

    where ``N(w)`` is that in-vocabulary neighbour set and ``n = |N(w)|``.
    The data term always uses the ORIGINAL vector, while the neighbour terms
    read the current (possibly already updated this pass) vectors. Words
    without usable neighbours keep their original vector.

    Args:
        wordVecs: dict word -> numpy array; not modified.
        lexicon: dict word -> iterable of neighbour words.
        numIters: number of update passes.

    Returns:
        A deep copy of ``wordVecs`` with the retrofitted vectors.
    """
    newWordVecs = deepcopy(wordVecs)
    wvVocab = set(newWordVecs.keys())
    loopVocab = wvVocab.intersection(set(lexicon.keys()))
    # Neighbour sets do not change between passes, so build them once.
    # Words whose neighbours are all out of vocabulary are left out: they
    # would just keep their data estimate anyway. Insertion order follows
    # loopVocab, so the update order matches the per-pass original loop.
    neighbours = {}
    for word in loopVocab:
        wordNeighbours = set(lexicon[word]).intersection(wvVocab)
        if wordNeighbours:
            neighbours[word] = wordNeighbours
    for _ in range(numIters):
        for word, wordNeighbours in neighbours.items():
            numNeighbours = len(wordNeighbours)
            # the weight of the data estimate is the number of neighbours
            # (multiplication yields a fresh array, so wordVecs is untouched)
            newVec = numNeighbours * wordVecs[word]
            # add each neighbour's current vector (weight 1)
            for ppWord in wordNeighbours:
                newVec += newWordVecs[ppWord]
            newWordVecs[word] = newVec / (2 * numNeighbours)
    return newWordVecs
if __name__=='__main__':
    parser = argparse.ArgumentParser()
    # Input, lexicon and output have no usable fallback: a None path would
    # crash later inside the reader/writer, so let argparse reject it early.
    parser.add_argument("-i", "--input", type=str, required=True, help="Input word vecs")
    parser.add_argument("-l", "--lexicon", type=str, required=True, help="Lexicon file name")
    parser.add_argument("-o", "--output", type=str, required=True, help="Output word vecs")
    parser.add_argument("-n", "--numiter", type=int, default=10, help="Num iterations")
    args = parser.parse_args()
    wordVecs = read_word_vecs(args.input)
    lexicon = read_lexicon(args.lexicon)
    # Enrich the word vectors using the lexicon and write the result out;
    # args.numiter is already an int thanks to type=int.
    print_word_vecs(retrofit(wordVecs, lexicon, args.numiter), args.output)