models.py
#! -*- coding: utf-8 -*-
# GAU-α model implementation

import bert4keras
from bert4keras.models import *


class GAU_alpha(RoFormerV2):
    """GAU-α
    Change: the basic Transformer block is replaced with a GAU (Gated Attention Unit).
    Reference: https://kexue.fm/archives/9052
    """
    def initializer(self, shape, dtype=None, order=3, gain=1.0):
        # Same initializer as the parent RoFormerV2; only the default `order` is changed to 3.
        return super(GAU_alpha, self).initializer(shape, dtype, order, gain)
    def apply_main_layers(self, inputs, index):
        """The main body of GAU-α is a block built on the Gated Attention Unit.
        Order: GAU --> Add --> LN
        """
        x = inputs

        attention_name = 'Transformer-%d-GatedAttentionUnit' % index
        attention_mask = self.compute_attention_bias(index)
        position_bias = self.compute_position_bias(x)

        # Self Attention: a single GAU layer with rotary position bias
        xi = x
        x = [x, position_bias]
        arguments = {'a_bias': None, 'p_bias': 'rotary'}
        if attention_mask is not None:
            arguments['a_bias'] = True
            x.insert(1, attention_mask)
        x = self.apply(
            inputs=x,
            layer=GatedAttentionUnit,
            arguments=arguments,
            units=self.intermediate_size,
            key_size=self.attention_key_size,
            activation=self.hidden_act,
            use_bias=False,
            normalization='softmax_plus',
            attention_dropout=self.attention_dropout_rate,
            kernel_initializer=self.initializer,
            name=attention_name
        )
        x = self.apply(
            inputs=x,
            layer=Dropout,
            rate=self.dropout_rate,
            name='%s-Dropout' % attention_name
        )
        # Residual connection followed by an RMS-style LayerNorm (no mean-centering, no scale/offset)
        x = self.apply(
            inputs=[xi, x], layer=Add, name='%s-Add' % attention_name
        )
        x = self.apply(
            inputs=x,
            layer=LayerNormalization,
            zero_mean=False,
            scale=False,
            offset=False,
            name='%s-Norm' % attention_name
        )

        return x
    def variable_mapping(self):
        """Redefine the mapping from layer names to checkpoint weight names.
        """
        mapping = {
            'Embedding-Token': ['bert/embeddings/word_embeddings'],
            'Embedding-Segment': ['bert/embeddings/token_type_embeddings'],
        }

        # The GatedAttentionUnit weight order changed in bert4keras 0.10.4; compare
        # parsed version tuples rather than raw strings, which sort lexicographically.
        version = tuple(
            int(v) for v in bert4keras.__version__.split('.')[:3] if v.isdigit()
        )

        for i in range(self.num_hidden_layers):
            prefix = 'GAU_alpha/encoder/layer_%d/' % i
            if version >= (0, 10, 4):
                # bias/beta weights are commented out: the GAU layer is built with
                # use_bias=False and its scale-offsets only use gamma.
                mapping['Transformer-%d-GatedAttentionUnit' % i] = [
                    prefix + 'gau/i_dense/kernel',
                    # prefix + 'gau/i_dense/bias',
                    # prefix + 'gau/q_scaleoffset/beta',
                    prefix + 'gau/q_scaleoffset/gamma',
                    # prefix + 'gau/k_scaleoffset/beta',
                    prefix + 'gau/k_scaleoffset/gamma',
                    prefix + 'gau/o_dense/kernel',
                    # prefix + 'gau/o_dense/bias',
                ]
            else:
                mapping['Transformer-%d-GatedAttentionUnit' % i] = [
                    prefix + 'gau/i_dense/kernel',
                    # prefix + 'gau/i_dense/bias',
                    prefix + 'gau/o_dense/kernel',
                    # prefix + 'gau/o_dense/bias',
                    # prefix + 'gau/q_scaleoffset/beta',
                    prefix + 'gau/q_scaleoffset/gamma',
                    # prefix + 'gau/k_scaleoffset/beta',
                    prefix + 'gau/k_scaleoffset/gamma',
                ]

        return mapping
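

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes that
# bert4keras's build_transformer_model accepts a model class through its
# `model` argument, that the released GAU-α checkpoint follows the usual
# token/segment two-input encoder layout, and that the paths below are
# placeholders pointing at a locally downloaded config/checkpoint/vocab.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    from bert4keras.models import build_transformer_model
    from bert4keras.tokenizers import Tokenizer

    # Placeholder paths to a downloaded GAU-α checkpoint.
    config_path = 'chinese_GAU-alpha-char_L-24_H-768/bert_config.json'
    checkpoint_path = 'chinese_GAU-alpha-char_L-24_H-768/bert_model.ckpt'
    dict_path = 'chinese_GAU-alpha-char_L-24_H-768/vocab.txt'

    tokenizer = Tokenizer(dict_path, do_lower_case=True)

    # Pass the GAU_alpha class directly so that build_transformer_model uses it
    # instead of one of the built-in model names.
    model = build_transformer_model(
        config_path=config_path,
        checkpoint_path=checkpoint_path,
        model=GAU_alpha,
        return_keras_model=True
    )

    # Encode a sample sentence and run the encoder once.
    token_ids, segment_ids = tokenizer.encode(u'科学空间')
    features = model.predict([np.array([token_ids]), np.array([segment_ids])])
    print(features.shape)  # expected: (1, seq_len, hidden_size)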