# combineCreateFinalTrainData.py
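"""
Combine sentence pairs, pre-computed similarity scores, and per-sentence
idiomaticity predictions into STS-style training data ( sentence1, sentence2,
similarity ), writing one CSV per idiom-tokenisation strategy: 'none', 'all',
and 'select' (tokenise only sentences the classifier predicted as idiomatic).
"""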
import os
import csv
import sys

sys.path.append( '../../Utils/' )
from utils import tokenise_idiom as idiom_tokeniser
from utils import match_idioms, create_idiom_word_dict
from utils import _load_csv as load_csv
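
## Expected inputs, inferred from the column lookups below:
##   trainToPredict.csv       : columns 'original', 'correct', 'incorrect', 'idiom'
##   similatiries.csv         : headerless; the similarity score is in column 0
##   predict_results_None.txt : tab-separated, with a 'prediction' column
##                              where 0 marks an idiomatic usage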

## Helper to pull the 'correct' / 'incorrect' sentence columns from a CSV.
def _get_sents( location ) :
    headers, data = load_csv( location, delimiter="," )
    correct   = list()
    incorrect = list()
    for row in data :
        correct.append( row[ headers.index( 'correct'   ) ] )
        incorrect.append( row[ headers.index( 'incorrect' ) ] )
    return correct, incorrect

def create_train_data( sent_location, similarities, is_not_idiom_info, out_location, tokenise_strategy ) :
    assert tokenise_strategy.lower() in [ 'select', 'all', 'none' ]

    sent_header, sents               = load_csv( sent_location    , ","  )
    row_0, sims_data                 = load_csv( similarities     , ","  )
    not_idiom_header, not_idiom_data = load_csv( is_not_idiom_info, "\t" )

    ## The similarities file has no header row, so load_csv consumes the first
    ## data row as a header; prepend it back.
    sims_data = [ row_0 ] + sims_data
    assert len( sents ) == len( sims_data ) == len( not_idiom_data )

    train_data = [ [ 'sentence1', 'sentence2', 'similarity' ] ]
    for index in range( len( sims_data ) ) :
        this_sentence  = sents[ index ][ sent_header.index( 'original'  ) ]
        this_correct   = sents[ index ][ sent_header.index( 'correct'   ) ]
        this_incorrect = sents[ index ][ sent_header.index( 'incorrect' ) ]
        this_idiom     = sents[ index ][ sent_header.index( 'idiom'     ) ]
        this_sim       = sims_data[ index ][ 0 ]
        this_pred      = not_idiom_data[ index ][ not_idiom_header.index( 'prediction' ) ]

        ## Tokenise the idiom when the strategy is 'all', or when it is 'select'
        ## and this sentence was predicted to be idiomatic (0 is idiomatic).
        if ( int( this_pred ) == 0 and tokenise_strategy.lower() == 'select' ) or tokenise_strategy.lower() == 'all' :
            idiom_word_dict = create_idiom_word_dict( [ this_idiom ] )
            matched_idioms  = match_idioms( idiom_word_dict, this_sentence )
            if len( matched_idioms ) == 0 :
                print( "NO IDIOM: could not match '{}' in: {}".format( this_idiom, this_sentence ) )
                continue
            this_sentence = this_sentence.replace( this_idiom, idiom_tokeniser( this_idiom ) )

        positive_example = [ this_sentence, this_correct  , 1        ]
        negative_example = [ this_sentence, this_incorrect, this_sim ]
        train_data.append( positive_example )
        train_data.append( negative_example )

    outfile = os.path.join( out_location, 'sts_train_data.csv' )
    with open( outfile, 'w', newline='' ) as csvfile :
        writer = csv.writer( csvfile )
        writer.writerows( train_data )
    print( "Wrote STS train data to:", outfile )
    return
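
## A minimal sketch of how the emitted CSV could be consumed downstream, e.g.
## with the sentence-transformers library. This is an illustrative assumption,
## not something this script itself does; `load_sts_examples` is a hypothetical
## helper that exists only here.
def load_sts_examples( csv_location ) :
    from sentence_transformers import InputExample
    examples = []
    with open( csv_location ) as csvfile :
        reader = csv.DictReader( csvfile )
        for row in reader :
            ## Each row becomes a sentence pair scored by its similarity label.
            examples.append( InputExample( texts=[ row[ 'sentence1' ], row[ 'sentence2' ] ],
                                           label=float( row[ 'similarity' ] ) ) )
    return examples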

if __name__ == '__main__' :
    ## Generate one training set per tokenisation strategy.
    for tokenise_strategy, out_location in [
            ( 'none'  , 'trainDataNotTokenised'    ),
            ( 'all'   , 'trainDataAllTokenised'    ),
            ( 'select', 'trainDataSelectTokenised' )
    ] :
        params = {
            'sent_location'     : 'trainData/trainToPredict.csv' ,
            ## Note: the on-disk file name keeps its original spelling.
            'similarities'      : 'trainData/similatiries.csv' ,
            'is_not_idiom_info' : 'trainData/predictions/predict_results_None.txt' ,
            'out_location'      : out_location ,
            'tokenise_strategy' : tokenise_strategy ,
        }
        os.makedirs( out_location, exist_ok=True )
        create_train_data( **params )