# bert_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel, BertModel

from util import masked_softmax, weighted_sum, sort_by_seq_lens


class BertForCL(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForCL, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Projection head: hidden_size (768 for bert-base) -> num_labels.
        self.fc = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None, labels=None):
        """
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss);
            if :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        # outputs[1] is the pooled [CLS] representation.
        pooled_output = outputs[1]
        pooled_output = self.dropout(pooled_output)
        logits = self.fc(pooled_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression.
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        output = (logits,) + outputs[2:]
        return ((loss,) + output) if loss is not None else output
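

# Usage sketch (illustrative, not part of the original training pipeline):
# builds a BertForCL from a standard config and runs one labelled step.
# The checkpoint name "bert-base-uncased", num_labels=128, and the recent
# transformers tokenizer __call__ API are assumptions; only the config is
# loaded here, so the weights are randomly initialized.
def _example_bert_for_cl():
    from transformers import BertConfig, BertTokenizer

    config = BertConfig.from_pretrained("bert-base-uncased", num_labels=128)
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertForCL(config)

    # Encode a premise/hypothesis pair as a single BERT input.
    batch = tokenizer(["a cat sits on the mat"],
                      ["an animal is on the mat"],
                      padding=True, return_tensors="pt")
    labels = torch.tensor([0])
    outputs = model(input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    token_type_ids=batch["token_type_ids"],
                    labels=labels)
    loss, logits = outputs[0], outputs[1]   # logits: (1, num_labels)
    return loss, logits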


class LinearClassifier(nn.Module):
    """Linear classifier."""

    def __init__(self, encoder, num_classes=3):
        super(LinearClassifier, self).__init__()
        # Input width read from the encoder's projection head, i.e. hidden_size.
        dim_mlp = encoder.fc.weight.shape[1]
        # hidden_size (768) -> num_classes (3)
        self.fc = nn.Linear(dim_mlp, num_classes)

    def forward(self, features):
        return self.fc(features)
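

# Sketch of how LinearClassifier pairs with BertForCL (config values are
# illustrative): the classifier reads its input width from encoder.fc, i.e.
# the encoder's hidden_size (768 for bert-base), and maps features to 3
# NLI classes.
def _example_linear_classifier():
    from transformers import BertConfig

    config = BertConfig(num_labels=128)             # num_labels here is illustrative
    encoder = BertForCL(config)
    classifier = LinearClassifier(encoder, num_classes=3)

    features = torch.randn(4, config.hidden_size)   # e.g. pooled BERT features
    logits = classifier(features)                   # shape (4, 3)
    return logits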


class PairSupConBert(nn.Module):
    def __init__(self, encoder, dropout=0.5, is_train=True):
        super(PairSupConBert, self).__init__()
        self.encoder = encoder.bert
        self.dim_mlp = encoder.fc.weight.shape[1]
        self.dropout = dropout
        self.is_train = is_train
        self.attention = SoftmaxAttention()
        self.projection = nn.Sequential(
            nn.Linear(4 * self.dim_mlp, self.dim_mlp),
            nn.ReLU())
        self.pooler = nn.Sequential(
            nn.Linear(4 * self.dim_mlp, self.dim_mlp),
            encoder.bert.pooler)
        self.head = nn.Sequential(
            nn.Linear(self.dim_mlp, self.dim_mlp),
            nn.ReLU(inplace=True))
        self.fc_sup = encoder.fc
        self.fc_ce = nn.Linear(self.dim_mlp, 3)
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, inputs_embeds=None):
        # token_type_ids is 1 on hypothesis tokens, so this splits the joint
        # input into hypothesis-only and premise-only id tensors.
        input_ids2 = input_ids * token_type_ids
        input_ids1 = input_ids - input_ids2
        feat1 = self.encoder(input_ids1,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
                             inputs_embeds=inputs_embeds)
        feat2 = self.encoder(input_ids2,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids,
                             position_ids=position_ids,
                             head_mask=head_mask,
                             inputs_embeds=inputs_embeds)
        encoded_premises = feat1[0]
        encoded_hypotheses = feat2[0]
        attended_premises, attended_hypotheses = self.attention(
            encoded_premises, attention_mask, encoded_hypotheses, attention_mask)
        enhanced_premises = torch.cat([encoded_premises,
                                       attended_premises,
                                       encoded_premises - attended_premises,
                                       encoded_premises * attended_premises], dim=-1)
        enhanced_hypotheses = torch.cat([encoded_hypotheses,
                                         attended_hypotheses,
                                         encoded_hypotheses - attended_hypotheses,
                                         encoded_hypotheses * attended_hypotheses], dim=-1)
        projected_premises = self.projection(enhanced_premises)
        projected_hypotheses = self.projection(enhanced_hypotheses)
        pair_embeds = torch.cat([projected_premises,
                                 projected_hypotheses,
                                 projected_premises - projected_hypotheses,
                                 projected_premises * projected_hypotheses], dim=-1)
        pair_output = self.pooler(pair_embeds)
        if self.is_train:
            feat = self.head(pair_output)
            return F.normalize(self.fc_ce(feat), dim=1), F.normalize(self.fc_sup(feat), dim=1)
        else:
            return pair_output
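

# Sketch of the premise/hypothesis split done at the top of
# PairSupConBert.forward (toy ids, only the special tokens 101/102 are real):
# token_type_ids is 1 on hypothesis tokens, so multiplying keeps only the
# hypothesis ids and subtracting keeps only the premise ids.
def _example_pair_split():
    input_ids      = torch.tensor([[101, 7592, 2088, 102, 2023, 2003, 102, 0]])
    token_type_ids = torch.tensor([[  0,    0,    0,   0,    1,    1,   1, 0]])
    hypothesis_ids = input_ids * token_type_ids   # [[0, 0, 0, 0, 2023, 2003, 102, 0]]
    premise_ids = input_ids - hypothesis_ids      # [[101, 7592, 2088, 102, 0, 0, 0, 0]]
    return premise_ids, hypothesis_ids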


class SoftmaxAttention(nn.Module):
    """
    Attention layer taking premises and hypotheses encoded by an RNN (or any
    sequence encoder) as input and computing the soft attention between their
    elements.

    The dot product of the encoded vectors in the premises and hypotheses is
    first computed. The softmax of the result is then used in a weighted sum
    of the vectors of the premises for each element of the hypotheses, and
    conversely for the elements of the premises.
    """

    def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask):
        """
        Args:
            premise_batch: A batch of sequences of vectors representing the
                premises in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            premise_mask: A mask for the sequences in the premise batch, to
                ignore padding data in the sequences during the computation of
                the attention.
            hypothesis_batch: A batch of sequences of vectors representing the
                hypotheses in some NLI task. The batch is assumed to have the
                size (batch, sequences, vector_dim).
            hypothesis_mask: A mask for the sequences in the hypotheses batch,
                to ignore padding data in the sequences during the computation
                of the attention.

        Returns:
            attended_premises: The sequences of attention vectors for the
                premises in the input batch.
            attended_hypotheses: The sequences of attention vectors for the
                hypotheses in the input batch.
        """
        # Dot product between premises and hypotheses in each sequence of
        # the batch.
        similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1)
                                                              .contiguous())
        # Softmax attention weights.
        prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask)
        hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2).contiguous(),
                                       premise_mask)
        # Weighted sums of the hypotheses for the premises attention,
        # and vice-versa for the attention of the hypotheses.
        attended_premises = weighted_sum(hypothesis_batch, prem_hyp_attn, premise_mask)
        attended_hypotheses = weighted_sum(premise_batch, hyp_prem_attn, hypothesis_mask)

        return attended_premises, attended_hypotheses
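

# Usage sketch for SoftmaxAttention with toy tensors (sizes are illustrative).
# The exact masking behaviour depends on masked_softmax/weighted_sum in util,
# but each attended output keeps the shape of its corresponding input batch.
def _example_softmax_attention():
    attention = SoftmaxAttention()
    premises = torch.randn(2, 5, 768)       # (batch, premise_len, dim)
    hypotheses = torch.randn(2, 7, 768)     # (batch, hypothesis_len, dim)
    premise_mask = torch.ones(2, 5)         # 1 = real token, 0 = padding
    hypothesis_mask = torch.ones(2, 7)

    attended_premises, attended_hypotheses = attention(
        premises, premise_mask, hypotheses, hypothesis_mask)
    # attended_premises: (2, 5, 768); attended_hypotheses: (2, 7, 768)
    return attended_premises, attended_hypotheses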


class Seq2SeqEncoder(nn.Module):
    """
    RNN taking variable length padded sequences of vectors as input and
    encoding them into padded sequences of vectors of the same length.

    This module is useful to handle batches of padded sequences of vectors
    that have different lengths and that need to be passed through an RNN.
    The sequences are sorted in descending order of their lengths, packed,
    passed through the RNN, and the resulting sequences are then padded and
    permuted back to the original order of the input sequences.
    """

    def __init__(self,
                 rnn_type,
                 input_size,
                 hidden_size,
                 num_layers=1,
                 bias=True,
                 dropout=0.0,
                 bidirectional=False):
        """
        Args:
            rnn_type: The type of RNN to use as encoder in the module.
                Must be a class inheriting from torch.nn.RNNBase
                (such as torch.nn.LSTM, for example).
            input_size: The number of expected features in the input of the
                module.
            hidden_size: The number of features in the hidden state of the RNN
                used as encoder by the module.
            num_layers: The number of recurrent layers in the encoder of the
                module. Defaults to 1.
            bias: If False, the encoder does not use bias weights b_ih and
                b_hh. Defaults to True.
            dropout: If non-zero, introduces a dropout layer on the outputs
                of each layer of the encoder except the last one, with dropout
                probability equal to 'dropout'. Defaults to 0.0.
            bidirectional: If True, the encoder of the module is bidirectional.
                Defaults to False.
        """
        assert issubclass(rnn_type, nn.RNNBase), \
            "rnn_type must be a class inheriting from torch.nn.RNNBase"

        super(Seq2SeqEncoder, self).__init__()

        self.rnn_type = rnn_type
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional

        self._encoder = rnn_type(input_size,
                                 hidden_size,
                                 num_layers=num_layers,
                                 bias=bias,
                                 batch_first=True,
                                 dropout=dropout,
                                 bidirectional=bidirectional)

    def forward(self, sequences_batch, sequences_lengths):
        """
        Args:
            sequences_batch: A batch of variable length sequences of vectors.
                The batch is assumed to be of size
                (batch, sequence, vector_dim).
            sequences_lengths: A 1D tensor containing the sizes of the
                sequences in the input batch.

        Returns:
            reordered_outputs: The outputs (hidden states) of the encoder for
                the sequences in the input batch, in the same order.
        """
        # Sort by decreasing length so the batch can be packed.
        sorted_batch, sorted_lengths, _, restoration_idx = \
            sort_by_seq_lens(sequences_batch, sequences_lengths)
        packed_batch = nn.utils.rnn.pack_padded_sequence(sorted_batch,
                                                         sorted_lengths,
                                                         batch_first=True)
        outputs, _ = self._encoder(packed_batch, None)
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs,
                                                      batch_first=True)
        # Restore the original order of the sequences in the batch.
        reordered_outputs = outputs.index_select(0, restoration_idx)

        return reordered_outputs
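

# Usage sketch for Seq2SeqEncoder with an LSTM (sizes and lengths are
# illustrative). It relies on sort_by_seq_lens from util to sort the batch
# and to provide the indices that restore the original order.
def _example_seq2seq_encoder():
    encoder = Seq2SeqEncoder(nn.LSTM,
                             input_size=300,
                             hidden_size=128,
                             bidirectional=True)
    batch = torch.randn(3, 10, 300)         # (batch, max_seq_len, vector_dim)
    lengths = torch.tensor([10, 7, 4])      # true length of each sequence
    outputs = encoder(batch, lengths)       # (3, 10, 256), original order restored
    return outputs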