growing_bart.py
import torch
import torch.nn as nn
from transformers import BartModel, RobertaModel
from transformers.activations import ACT2FN


def Linear(in_features, out_features, bias=True):
    # Linear layer with a near-zero Xavier gain: its outputs start essentially
    # at zero, so the generated adapter parameters are initially (almost) inactive.
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight, gain=0.0000001)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
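
# Illustrative check (not used by the model code): with the tiny gain above,
# a freshly constructed layer maps any input to numerically negligible values, e.g.
#
#   layer = Linear(768, 16)
#   out = layer(torch.randn(4, 768))
#   assert out.abs().max().item() < 1e-3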


class SimpleGenerator(nn.Module):
    # Takes an encoded task description and generates the parameters of one adapter.
    def __init__(self, config):
        super().__init__()
        self.input_dim = 768  # config.d_model
        self.hidden_dim = config.generator_hdim
        # One flat output vector: down weight + up weight + down bias + up bias.
        self.output_dim = config.d_model * config.adapter_dim * 2 + config.d_model + config.adapter_dim
        if config.adapt_layer_norm:
            self.output_dim += 2 * config.d_model

        self.linear1 = Linear(self.input_dim, self.hidden_dim)
        self.activation_fn = ACT2FN[config.activation_function]
        self.linear2 = Linear(self.hidden_dim, self.output_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        return x.view(-1)
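
# Sizing example (illustrative): with a hypothetical d_model=768, adapter_dim=64
# and adapt_layer_norm=False, each SimpleGenerator emits
#   768 * 64 * 2 (down + up weights) + 64 (down bias) + 768 (up bias) = 99136
# values, i.e. one flat vector holding all parameters of a single adapter layer.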


class ParameterGenerator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Task-description encoder; kept in eval mode unless unfreeze_hyper_encoder is set.
        self.encoder = RobertaModel.from_pretrained('roberta-base')
        if self.config.unfreeze_hyper_encoder:
            self.encoder.train()
        else:
            self.encoder.eval()

        # One SimpleGenerator per BART layer (encoder layers first, then decoder layers).
        self.decoders = nn.ModuleList([
            SimpleGenerator(config) for _ in range(config.encoder_layers + config.decoder_layers)
        ])

    def encode(self, input_ids, attention_mask=None, encoder_outputs=None,
               decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
               use_cache=False, is_training=False):
        # To save memory, the RoBERTa encoder is run without gradients unless it is unfrozen.
        if self.config.unfreeze_hyper_encoder:
            outputs = self.encoder(
                input_ids,
                attention_mask=attention_mask,
            )
        else:
            with torch.no_grad():
                outputs = self.encoder(
                    input_ids,
                    attention_mask=attention_mask,
                )
        x = outputs[0]   # last hidden state
        x = x[:, 0, :]   # take <s> token (equiv. to [CLS])

        # eos_mask = input_ids.eq(self.config.eos_token_id)
        # if len(torch.unique(eos_mask.sum(1))) > 1:
        #     raise ValueError("All examples must have the same number of <eos> tokens.")
        # sentence_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
        # return sentence_representation
        return x

    def decode(self, sr):
        # Run every per-layer generator on the sentence representation.
        return [one_decoder(sr) for one_decoder in self.decoders]

    def forward(self, input_ids, attention_mask=None, encoder_outputs=None,
                decoder_input_ids=None, decoder_attention_mask=None, decoder_cached_states=None,
                use_cache=False, is_training=False):
        h = self.encode(input_ids, attention_mask, encoder_outputs, decoder_input_ids,
                        decoder_attention_mask, decoder_cached_states, use_cache, is_training)
        params = self.decode(h)
        return params
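
# Usage sketch (illustrative only; `config` is assumed to carry the fields read
# above: d_model, adapter_dim, generator_hdim, activation_function,
# adapt_layer_norm, encoder_layers, decoder_layers, unfreeze_hyper_encoder):
#
#   from transformers import RobertaTokenizer
#   tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
#   meta_model = ParameterGenerator(config)
#   enc = tokenizer(["map a country name to its capital city"], return_tensors="pt")
#   params = meta_model(enc["input_ids"], attention_mask=enc["attention_mask"])
#   # `params` is a list with one flat tensor per BART encoder/decoder layer,
#   # consumed below by GrowingBart.apply_params_to_adapters.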


class GrowingBart(nn.Module):
    def __init__(self, model, meta_model, config):
        super().__init__()
        self.config = config
        self.model = model
        self.meta_model = meta_model

    def set_relation(self, rel_ids, rel_masks):
        # generate adapter parameters from the task description
        generated_params = self.meta_model(rel_ids, attention_mask=rel_masks)
        # apply the parameters to the adapters
        self.apply_params_to_adapters(generated_params)

    def forward(self, rel_ids, rel_masks, input_ids, input_masks, output_ids, output_masks, is_training=False):
        # generate adapter parameters from the task description
        generated_params = self.meta_model(rel_ids, attention_mask=rel_masks)
        # apply the parameters to the adapters
        self.apply_params_to_adapters(generated_params)

        # run zero-shot inference with the adapted model
        ret = self.model(input_ids, attention_mask=input_masks,
                         decoder_input_ids=output_ids,
                         decoder_attention_mask=output_masks,
                         is_training=is_training
                         )
        return ret

    def apply_params_to_adapters(self, generated_params):
        # The first encoder_layers vectors parameterize the encoder adapters,
        # the remaining ones the decoder adapters.
        encoder_params, decoder_params = generated_params[:self.config.encoder_layers], generated_params[self.config.encoder_layers:]

        d_model = self.config.d_model
        d_adapter = self.config.adapter_dim

        for p, encoder_layer in zip(encoder_params, self.model.encoders()):
            # dw, db: down-projection weight and bias
            # uw, ub: up-projection weight and bias
            dw, uw, db, ub = p[0:d_model*d_adapter], \
                p[d_model*d_adapter:d_model*d_adapter*2], \
                p[d_model*d_adapter*2:d_model*d_adapter*2+d_adapter], \
                p[d_model*d_adapter*2+d_adapter:d_model*d_adapter*2+d_adapter+d_model]

            encoder_layer.adapter_down_weight = dw.view(d_model, d_adapter)
            encoder_layer.adapter_down_bias = db.view(d_adapter)
            encoder_layer.adapter_up_weight = uw.view(d_adapter, d_model)
            encoder_layer.adapter_up_bias = ub.view(d_model)

            if self.config.adapt_layer_norm:
                encoder_layer.self_attn_layer_norm.weight.data = encoder_layer.self_attn_layer_norm.weight.data + p[-2*d_model:-1*d_model]
                encoder_layer.self_attn_layer_norm.bias.data = encoder_layer.self_attn_layer_norm.bias.data + p[-1*d_model:]

        for p, decoder_layer in zip(decoder_params, self.model.decoders()):
            dw, uw, db, ub = p[0:d_model*d_adapter], \
                p[d_model*d_adapter:d_model*d_adapter*2], \
                p[d_model*d_adapter*2:d_model*d_adapter*2+d_adapter], \
                p[d_model*d_adapter*2+d_adapter:d_model*d_adapter*2+d_adapter+d_model]

            decoder_layer.adapter_down_weight = dw.view(d_model, d_adapter)
            decoder_layer.adapter_down_bias = db.view(d_adapter)
            decoder_layer.adapter_up_weight = uw.view(d_adapter, d_model)
            decoder_layer.adapter_up_bias = ub.view(d_model)

            if self.config.adapt_layer_norm:
                decoder_layer.self_attn_layer_norm.weight.data = decoder_layer.self_attn_layer_norm.weight.data + p[-2*d_model:-1*d_model]
                decoder_layer.self_attn_layer_norm.bias.data = decoder_layer.self_attn_layer_norm.bias.data + p[-1*d_model:]

        # a = self.model.decoders()[-4]
        # print(a.adapter_down_weight)
        # print(a.adapter_down_bias)
        # print(a.adapter_up_weight)
        # print(a.adapter_up_bias)
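

# Minimal self-check sketch (not part of the original training pipeline). It
# builds a SimpleGenerator from a hypothetical config and verifies that the
# flat vector it emits can be sliced into the shapes expected by
# GrowingBart.apply_params_to_adapters. The hyperparameters below
# (d_model=768, adapter_dim=64, generator_hdim=128, gelu) are placeholders,
# not necessarily the repository's actual settings.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        d_model=768,
        adapter_dim=64,
        generator_hdim=128,
        activation_function="gelu",
        adapt_layer_norm=False,
    )
    gen = SimpleGenerator(cfg)
    task_embedding = torch.randn(1, 768)  # stands in for the RoBERTa <s> vector
    flat = gen(task_embedding)
    assert flat.numel() == cfg.d_model * cfg.adapter_dim * 2 + cfg.d_model + cfg.adapter_dim

    d_model, d_adapter = cfg.d_model, cfg.adapter_dim
    dw = flat[:d_model * d_adapter].view(d_model, d_adapter)                         # down-projection weight
    uw = flat[d_model * d_adapter:2 * d_model * d_adapter].view(d_adapter, d_model)  # up-projection weight
    db = flat[2 * d_model * d_adapter:2 * d_model * d_adapter + d_adapter]           # down-projection bias
    ub = flat[2 * d_model * d_adapter + d_adapter:]                                  # up-projection bias
    print(dw.shape, uw.shape, db.shape, ub.shape)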