attention_intervention_structural.py
"""Performs attention intervention on Winobias samples and saves results to JSON file."""
import json
import os
import random
import sys
from pandas import DataFrame
from transformers import (
GPT2Tokenizer, TransfoXLTokenizer, XLNetTokenizer
)
from attention_utils import perform_interventions, get_odds_ratio
from experiment_num_agreement import Model, Intervention
from vocab_utils import get_nouns, get_nouns2, get_verbs, get_verbs2, get_prepositions, \
get_preposition_nouns, get_adv1s, get_adv2s
import vocab_utils as vocab
def construct_templates(attractor):
    """Build sentence templates for the given attractor type.

    '{}' marks the slot that Intervention later fills with the subject noun.
    """
    templates = []
    if attractor in ('singular', 'plural'):
        for p in get_prepositions():
            for ppns, ppnp in get_preposition_nouns():
                ppn = ppns if attractor == 'singular' else ppnp
                templates.append(' '.join(['The', '{}', p, 'the', ppn]))
    elif attractor in ('rc_singular', 'rc_plural', 'rc_singular_no_that', 'rc_plural_no_that'):
        # A relative clause follows the subject; '_no_that' variants drop the complementizer.
        for noun2s, noun2p in get_nouns2():
            noun2 = noun2s if attractor.startswith('rc_singular') else noun2p
            for verb2s, verb2p in get_verbs2():
                verb2 = verb2s if attractor.startswith('rc_singular') else verb2p
                if attractor.endswith('no_that'):
                    template = ' '.join(['The', '{}', 'the', noun2, verb2])
                else:
                    template = ' '.join(['The', '{}', 'that', 'the', noun2, verb2])
                templates.append(template)
    elif attractor in ('within_rc_singular', 'within_rc_plural',
                       'within_rc_singular_no_that', 'within_rc_plural_no_that'):
        # Here the '{}' slot sits inside the relative clause instead.
        for ns, np in vocab.get_nouns():
            noun = ns if attractor.startswith('within_rc_singular') else np
            if attractor.endswith('no_that'):
                template = ' '.join(['The', noun, 'the', '{}'])
            else:
                template = ' '.join(['The', noun, 'that', 'the', '{}'])
            templates.append(template)
    elif attractor == 'distractor':
        for adv1 in get_adv1s():
            for adv2 in get_adv2s():
                templates.append(' '.join(['The', '{}', adv1, 'and', adv2]))
    elif attractor == 'distractor_1':
        for adv1 in get_adv1s():
            templates.append(' '.join(['The', '{}', adv1]))
    else:
        templates = ['The {}']
    return templates
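
# Illustrative output (vocabulary items here are assumptions; the real ones
# come from vocab_utils): construct_templates('singular') might yield strings
# like 'The {} near the cabinet', and construct_templates('distractor')
# strings like 'The {} quietly and quickly'.
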
def load_structural_interventions(tokenizer, device, attractor, seed, examples):
    """Build the intervention dictionary, skipping word pairs the tokenizer rejects."""
    interventions = {}
    all_word_count = 0
    used_word_count = 0
    templates = construct_templates(attractor)
    for temp in templates:
        if attractor.startswith('within_rc'):
            # For within-RC attractors the intervened noun is the second noun.
            for noun2s, noun2p in get_nouns2():
                for v_singular, v_plural in vocab.get_verbs():
                    all_word_count += 1
                    try:
                        intervention_name = '_'.join([temp, noun2s, v_singular])
                        interventions[intervention_name] = Intervention(
                            tokenizer,
                            temp,
                            [noun2s, noun2p],
                            [v_singular, v_plural],
                            device=device)
                        used_word_count += 1
                    except Exception:
                        # Intervention raises when the tokenizer cannot handle the pair.
                        pass
        else:
            for ns, np in vocab.get_nouns():
                for v_singular, v_plural in vocab.get_verbs():
                    all_word_count += 1
                    try:
                        intervention_name = '_'.join([temp, ns, v_singular])
                        interventions[intervention_name] = Intervention(
                            tokenizer,
                            temp,
                            [ns, np],
                            [v_singular, v_plural],
                            device=device)
                        used_word_count += 1
                    except Exception:
                        pass
    print(f"\t Only used {used_word_count}/{all_word_count} nouns due to tokenizer")
    if examples > 0 and len(interventions) >= examples:
        random.seed(seed)
        # random.sample needs a sequence, so materialize the items first.
        interventions = dict(random.sample(list(interventions.items()), examples))
    return interventions
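
# Illustrative key format (the word choices are assumptions): for the template
# 'The {} near the cabinet' and pairs ('author', 'authors') / ('laughs', 'laugh'),
# the entry would be keyed 'The {} near the cabinet_author_laughs'.
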
def get_interventions_structural(gpt2_version, do_filter, model, tokenizer,
                                 device='cuda', filter_quantile=0.25, seed=3,
                                 attractor=None, examples=100):
    interventions = load_structural_interventions(tokenizer, device, attractor, seed, examples)
    intervention_list = list(interventions.values())
    interventions = intervention_list
    json_data = {'model_version': gpt2_version,
                 'do_filter': do_filter,
                 'num_examples_loaded': len(interventions)}
    if do_filter:
        df = DataFrame({'odds_ratio': [get_odds_ratio(intervention, model)
                                       for intervention in intervention_list]})
        df_expected = df[df.odds_ratio > 1]
        threshold = df_expected.odds_ratio.quantile(filter_quantile)
        filtered_interventions = []
        assert len(intervention_list) == len(df)
        for i in range(len(intervention_list)):
            intervention = intervention_list[i]
            odds_ratio = df.iloc[i].odds_ratio
            if odds_ratio > threshold:
                filtered_interventions.append(intervention)
        print(f'Num examples with odds ratio > 1: {len(df_expected)} / {len(intervention_list)}')
        print(f'Num examples with odds ratio > {threshold:.4f} ({filter_quantile} quantile): '
              f'{len(filtered_interventions)} / {len(intervention_list)}')
        json_data['num_examples_aligned'] = len(df_expected)
        json_data['filter_quantile'] = filter_quantile
        json_data['threshold'] = threshold
        interventions = filtered_interventions
    json_data['num_examples_analyzed'] = len(interventions)
    return interventions, json_data
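
# Filtering rationale, as implemented above: examples with odds_ratio > 1 are
# those where the model's agreement preference points the expected way (these
# are counted in 'num_examples_aligned'); the threshold is the filter_quantile
# quantile of that aligned subset, so filtering also discards the weakest
# aligned examples.
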
def intervene_attention(gpt2_version, do_filter, attractor, device='cuda',
                        filter_quantile=0.25, examples=100, seed=3, random_weights=False):
    model = Model(output_attentions=True, gpt2_version=gpt2_version,
                  device=device, random_weights=random_weights)
    tokenizer = (GPT2Tokenizer if model.is_gpt2 else
                 TransfoXLTokenizer if model.is_txl else
                 XLNetTokenizer  # fall back to XLNet for the remaining model type
                 ).from_pretrained(gpt2_version)
    interventions, json_data = get_interventions_structural(gpt2_version, do_filter,
                                                            model, tokenizer,
                                                            device, filter_quantile,
                                                            seed=seed, attractor=attractor,
                                                            examples=examples)
    results = perform_interventions(interventions, model)
    results_df = DataFrame(results)  # build once rather than once per statistic
    json_data['mean_total_effect'] = results_df.total_effect.mean()
    json_data['mean_model_indirect_effect'] = results_df.indirect_effect_model.mean()
    json_data['mean_model_direct_effect'] = results_df.direct_effect_model.mean()
    filter_name = 'filtered' if do_filter else 'unfiltered'
    if random_weights:
        gpt2_version += '_random'
    fname = f"attention_results/{attractor}/attention_intervention_{gpt2_version}_{filter_name}.json"
    os.makedirs(os.path.dirname(fname), exist_ok=True)
    json_data['results'] = results
    with open(fname, 'w') as f:
        json.dump(json_data, f)
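
# Output location, derived from the f-string above: e.g. attractor='singular',
# gpt2_version='gpt2', do_filter=False writes to
# 'attention_results/singular/attention_intervention_gpt2_unfiltered.json'.
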
if __name__ == "__main__":
    model = sys.argv[1]
    device = sys.argv[2]
    filter_quantile = float(sys.argv[3])
    random_weights = sys.argv[4] == 'random'
    attractor = sys.argv[5]
    seed = int(sys.argv[6])
    examples = int(sys.argv[7])
    # To run the filtered variant instead:
    # intervene_attention(model, True, attractor, device=device,
    #                     filter_quantile=filter_quantile, examples=examples,
    #                     seed=seed, random_weights=random_weights)
    intervene_attention(model, False, attractor, device=device, filter_quantile=0.0,
                        examples=examples, seed=seed, random_weights=random_weights)
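
# Example invocation (illustrative; any argv[4] value other than 'random'
# keeps the pretrained weights):
#   python attention_intervention_structural.py gpt2 cuda 0.25 pretrained singular 3 100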