# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch.nn as nn
import torch
from torch import distributed as dist
from deepspeed.ops.sparse_attention import SparsityConfig


class SparseSelfAttention(nn.Module):
    """Implements an efficient sparse self-attention layer for Transformer models, based on `Generative Modeling with Sparse Transformers`: https://arxiv.org/abs/1904.10509

    For more information, please see TODO DeepSpeed Sparse Transformer.

    For a usage example, please see TODO DeepSpeed Sparse Transformer Tutorial; an illustrative sketch is also included at the bottom of this file.
    """
    def __init__(
            self,
            # SparsityConfig parameters need to be set accordingly
            sparsity_config=SparsityConfig(num_heads=4),
            key_padding_mask_mode='add',
            attn_mask_mode='mul',
            max_seq_length=2048):
        """Initialize the sparse self attention layer.

        Arguments:
            sparsity_config: optional: determines the sparsity pattern configuration; it is based on the SparsityConfig class.
            key_padding_mask_mode: optional: a string determining whether the key padding mask is added (`add`) or multiplied (`mul`).
            attn_mask_mode: optional: a string determining whether the attention mask is added (`add`) or multiplied (`mul`).
            max_seq_length: optional: the maximum sequence length this sparse attention module will be applied to; it controls the size of the master_layout.
        """
        super().__init__()

        # sparsity information
        self.sparsity_config = sparsity_config

        # initialize sparse layout and register as buffer
        master_layout = self.sparsity_config.make_layout(max_seq_length)
        self.register_buffer("master_layout", master_layout)
        self._need_layout_synchronization = True

        # mask modes
        self.key_padding_mask_mode = key_padding_mask_mode
        self.attn_mask_mode = attn_mask_mode

    ops = dict()
    def get_layout(self, L):
        # if the layout has not yet been synchronized across GPUs, broadcast it from global rank 0
        if self._need_layout_synchronization and dist.is_initialized():
            dist.broadcast(self.master_layout, src=0)
            self._need_layout_synchronization = False

        if (L % self.sparsity_config.block != 0):
            raise ValueError(
                f'Sequence length {L} must be divisible by block size {self.sparsity_config.block}!')

        num_blocks = L // self.sparsity_config.block
        return self.master_layout[..., :num_blocks, :num_blocks].cpu()  # layout needs to be a CPU tensor
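    # For illustration: with block=16 and L=256, the slice above is a binary
    # block mask of shape (num_heads, 16, 16); a nonzero entry marks a 16x16
    # block of the attention matrix that is actually computed.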
    # add to cache
    def get_ops(self, H, L):
        from deepspeed.ops.sparse_attention.matmul import MatMul
        from deepspeed.ops.sparse_attention.softmax import Softmax
        if L not in SparseSelfAttention.ops:
            sparsity_layout = self.get_layout(L)
            sparse_dot_sdd_nt = MatMul(sparsity_layout, self.sparsity_config.block, 'sdd', trans_a=False, trans_b=True)

            sparse_dot_dsd_nn = MatMul(sparsity_layout,
                                       self.sparsity_config.block,
                                       'dsd',
                                       trans_a=False,
                                       trans_b=False)

            sparse_softmax = Softmax(sparsity_layout, self.sparsity_config.block)

            SparseSelfAttention.ops[L] = (sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax)
        return SparseSelfAttention.ops[L]
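    # Note: `ops` is a class-level dict keyed only by sequence length, so the
    # compiled block-sparse kernels are shared by every SparseSelfAttention
    # instance that sees the same sequence length.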
    def transpose_key_for_scores(self, x, L):
        bsz, num_heads, seq_len, head_dim = x.size()
        if seq_len != L:
            return x.permute(0, 1, 3, 2)
        return x
    def transpose_mask_for_sparse(self, qtype, x, is_key_padding_mask=False):
        x = x.type(qtype)
        if is_key_padding_mask:
            xdim = x.dim()
            for d in range(xdim - 1, 0, -1):
                x = x.squeeze(dim=d)
            return x
        return x.squeeze()
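    # For illustration: a (bsz, 1, 1, seq_len) key_padding_mask is reduced to
    # (bsz, seq_len), while an attn_mask of shape (1, seq_len, seq_len) is
    # squeezed to (seq_len, seq_len).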
    # forward pass
    def forward(self, query, key, value, rpe=None, key_padding_mask=None, attn_mask=None):
        """Applies the forward pass of sparse self attention.

        Arguments:
            query: required: query tensor
            key: required: key tensor
            value: required: value tensor
            rpe: optional: a tensor of the same dimension as the input that is used as a relative position embedding
            key_padding_mask: optional: a mask tensor of size (BatchSize X SequenceLength)
            attn_mask: optional: a mask tensor of size (SequenceLength X SequenceLength); currently only 2D is supported

        Note:
            whether each mask is added or multiplied is controlled by the `key_padding_mask_mode` and `attn_mask_mode` strings passed to `__init__`, not by arguments to this method.

        Return:
            attn_output: a dense tensor containing the attention context
        """
        assert query.dtype == torch.half, "sparse attention only supports training in fp16 currently, please file a github issue if you need fp32 support"
        bsz, num_heads, tgt_len, head_dim = query.size()

        # transpose the key back if it was passed in already transposed
        key = self.transpose_key_for_scores(key, tgt_len)

        # check that the operation is supported
        if query.shape != key.shape or key.shape != value.shape:
            raise NotImplementedError('only self-attention is supported for now')

        # squeeze key_padding_mask if it is given
        if key_padding_mask is not None:
            key_padding_mask = self.transpose_mask_for_sparse(query.dtype, key_padding_mask, is_key_padding_mask=True)

        # squeeze attn_mask if it is given
        if attn_mask is not None:
            attn_mask = self.transpose_mask_for_sparse(query.dtype, attn_mask)

        # cached look-up table computations, etc.
        sparse_dot_sdd_nt, sparse_dot_dsd_nn, sparse_softmax = self.get_ops(num_heads, tgt_len)

        scaling = float(head_dim)**-0.5

        # attention scores
        attn_output_weights = sparse_dot_sdd_nt(query, key)
        attn_output_weights = sparse_softmax(attn_output_weights,
                                             scale=scaling,
                                             rpe=rpe,
                                             key_padding_mask=key_padding_mask,
                                             attn_mask=attn_mask,
                                             key_padding_mask_mode=self.key_padding_mask_mode,
                                             attn_mask_mode=self.attn_mask_mode)

        # outputs
        attn_output = sparse_dot_dsd_nn(attn_output_weights, value)
        return attn_output
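

# ----------------------------------------------------------------------------
# Minimal usage sketch. This is illustrative only: it assumes a CUDA device
# with the Triton-backed sparse kernels installed, and it picks
# FixedSparsityConfig purely as an example; any SparsityConfig subclass with a
# compatible block size can be used the same way.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    from deepspeed.ops.sparse_attention import FixedSparsityConfig

    bsz, num_heads, seq_len, head_dim = 2, 4, 256, 64

    # block-sparse pattern with 16x16 blocks; seq_len must be a multiple of the block size
    config = FixedSparsityConfig(num_heads=num_heads, block=16)
    attention = SparseSelfAttention(sparsity_config=config, max_seq_length=seq_len).cuda()

    # sparse attention currently requires fp16 inputs of shape (batch, heads, seq_len, head_dim)
    q = torch.rand(bsz, num_heads, seq_len, head_dim, dtype=torch.half, device='cuda')
    k = torch.rand_like(q)
    v = torch.rand_like(q)

    context = attention(q, k, v)
    print(context.shape)  # torch.Size([2, 4, 256, 64])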