Commit
[transformer] add cross attention
Mddct committed Mar 5, 2024
1 parent 3320377 commit e8a6e6d
Showing 5 changed files with 172 additions and 66 deletions.
130 changes: 102 additions & 28 deletions wenet/transformer/attention.py
@@ -55,33 +55,44 @@ def __init__(self,
self.use_sdpa = use_sdpa
self.dropout_rate = dropout_rate

def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:
assert x.ndim >= 3
if name == 'query':
x = self.linear_q(x)
elif name == 'key':
x = self.linear_k(x)
else:
assert name == 'value'
x = self.linear_v(x)
# split last dim
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
x = x.view(x_shape)
x = x.transpose(-3, -2) # (batch, ..., head, time, d_k)
return x

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.
Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
query (torch.Tensor): Query tensor (#batch, ..., time1, size).
key (torch.Tensor): Key tensor (#batch, ..., time2, size).
value (torch.Tensor): Value tensor (#batch, ..., time2, size).
Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
(#batch, ..., n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

q = self._forward_linearx('query', query)
k = self._forward_linearx('key', key)
v = self._forward_linearx('value', value)
return q, k, v

def forward_attention(
@@ -94,42 +105,41 @@ def forward_attention(
Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
(#batch, ..., n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
(#batch, ..., time1, time2), (0, ..., 0, 0) means fake mask.
Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).
"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
if mask.size(-1) > 0: # time2 > 0
mask = mask.unsqueeze(-3).eq(0) # (batch, .., 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
mask = mask[..., :scores.size(-1)] # (batch, 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores,
dim=-1) # (batch, ..., head, time1, time2)

p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)
x = torch.matmul(p_attn, value) # (batch, ..., head, time1, d_k)
x = x.transpose(-3, -2).contiguous() # [batch, ..., time1, head, d_k]
x_shape = x.size()[:-2] + torch.Size([self.h * self.d_k])
x = x.view(x_shape) # (batch, ..., time1, d_model)
return self.linear_out(x) # (batch, ..., time1, d_model)

def forward(
self,
@@ -369,3 +379,67 @@ def forward(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache


class MultiHeadedCrossAttention(MultiHeadedAttention):

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False):
super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
del pos_emb
if cache.size(0) > 0:
assert not self.training
q = self._forward_linearx('query', query)
k, v = torch.split(cache, cache.size(-1) // 2, dim=-1)

else:
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1)

B = query.size(0)
Beams = 1
if B != k.size(0):
assert not self.training
Beams = B // k.size(0)
B = k.size(0)
q = q.view(B, Beams, q.size(-3), q.size(-2), q.size(-1))
k = k.unsqueeze(1)
v = v.unsqueeze(1)
mask = mask.unsqueeze(1)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
output = self.forward_attention(v, scores, mask)
else:
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask.unsqueeze(1),
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = output.transpose(-2, -3).contiguous()
output_shape = output.size()[:-2] + torch.Size([self.h * self.d_k])
output = output.view(output_shape) # (batch, ..., time1, d_model)
output = self.linear_out(output)

if query.size(0) != B:
assert not self.training
output_shape = torch.Size([B * Beams]) + output.size()[2:]
output = output.view(output_shape)
return output, new_cache
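
For orientation, here is a minimal pure-PyTorch sketch (not part of the diff; all sizes are made-up example values) of the shape bookkeeping the new MultiHeadedCrossAttention relies on: the encoder key/value projections are packed into a single cache tensor along the last dimension, and when beam search expands the decoder batch to batch*beams, the query is folded into a (batch, beams, head, time1, d_k) view so the cached k/v broadcast across beams instead of being tiled.

import math

import torch

B, Beams, H, d_k, T1, T2 = 2, 4, 8, 32, 1, 50

k = torch.randn(B, H, T2, d_k)                        # projected encoder keys
v = torch.randn(B, H, T2, d_k)                        # projected encoder values
cache = torch.cat((k, v), dim=-1)                     # packed cache, (B, H, T2, 2*d_k)
k2, v2 = torch.split(cache, cache.size(-1) // 2, dim=-1)
assert torch.equal(k, k2) and torch.equal(v, v2)      # split recovers both halves

q = torch.randn(B * Beams, H, T1, d_k)                # beam-expanded decoder queries
q = q.view(B, Beams, H, T1, d_k)                      # fold the beams into an extra dim
scores = torch.matmul(q, k.unsqueeze(1).transpose(-2, -1)) / math.sqrt(d_k)
attn = torch.softmax(scores, dim=-1)                  # (B, Beams, H, T1, T2)
out = torch.matmul(attn, v.unsqueeze(1))              # (B, Beams, H, T1, d_k)
out = out.reshape(B * Beams, H, T1, d_k)              # unfold back to the expanded batch

The generalized _forward_linearx and forward_attention above handle the same extra leading dimension, which is why the new class can reuse them unchanged.
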
36 changes: 25 additions & 11 deletions wenet/transformer/decoder.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional
from typing import Dict, Tuple, List, Optional

import torch
import torch.utils.checkpoint as ckpt
@@ -101,7 +101,7 @@ def __init__(
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, key_bias, use_sdpa),
WENET_ATTENTION_CLASSES["selfattn"](
WENET_ATTENTION_CLASSES["crossattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
key_bias, use_sdpa) if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
@@ -196,8 +196,8 @@ def forward_one_step(
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
cache: Dict[str, Dict[str, torch.Tensor]],
) -> torch.Tensor:
"""Forward one step.
This is only used for decoding.
Args:
@@ -213,25 +213,39 @@
y.shape` is (batch, maxlen_out, token)
"""
x, _ = self.embed(tgt)
new_cache = []
update_cross_att_cache = True
if len(cache['cross_att_cache']) != 0:
assert len(cache['cross_att_cache']) == self.num_blocks
update_cross_att_cache = False
for i, decoder in enumerate(self.decoders):
if cache is None:
c = None
else:
c = cache[i]
layer_i = 'layer_{}'.format(i)
self_att_cache = cache['self_att_cache'].get(layer_i, None)
cross_att_cache = cache['cross_att_cache'].get(layer_i, None)
c = {
'self_att_cache': self_att_cache,
'cross_att_cache': cross_att_cache,
}

x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
cache=c)
new_cache.append(x)

# update cache dict
assert c['self_att_cache'] is not None
assert c['cross_att_cache'] is not None
cache['self_att_cache'][layer_i] = c['self_att_cache']
if update_cross_att_cache:
cache['cross_att_cache'][layer_i] = c['cross_att_cache']

if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.use_output_layer:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
return y

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
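
Relatedly, a short sketch (an assumption about intended usage, not code from this commit) of the cache layout the rewritten forward_one_step now expects: instead of the old flat List[torch.Tensor] of per-layer outputs, the caller owns a dict with one sub-dict per attention type, keyed by 'layer_{i}', which the decoder updates in place; the cross-attention entries are filled on the first step and then reused, and the method now returns just y rather than (y, cache).

from typing import Dict

import torch

cache: Dict[str, Dict[str, torch.Tensor]] = {
    'self_att_cache': {},   # grows by one frame per layer on every decoding step
    'cross_att_cache': {},  # filled once from the encoder memory, then reused
}

# Hypothetical decoding loop; decoder, memory, memory_mask, hyps and hyps_mask
# are placeholders for objects built elsewhere in the recognizer.
# for _ in range(max_decode_len):
#     logp = decoder.forward_one_step(memory, memory_mask, hyps, hyps_mask, cache)
#     ...pick the next token from logp, then extend hyps and hyps_mask...
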
46 changes: 30 additions & 16 deletions wenet/transformer/decoder_layer.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
from torch import nn
@@ -65,7 +65,7 @@ def forward(
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
cache: Optional[Dict[str, Optional[torch.Tensor]]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.
@@ -87,35 +87,52 @@ def forward(
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
"""
if cache is not None:
att_cache = cache['self_att_cache']
cross_att_cache = cache['cross_att_cache']
else:
att_cache, cross_att_cache = None, None

residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)

if cache is None:
if att_cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
att_cache = torch.empty(0, 0, 0, 0)
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = tgt_mask[:, -1:, :]

x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
x, new_att_cache = self.self_attn(
tgt_q,
tgt_q,
tgt_q,
tgt_q_mask,
cache=att_cache,
)
if cache is not None:
cache['self_att_cache'] = new_att_cache
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm1(x)

if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0])
if cross_att_cache is None:
cross_att_cache = torch.empty(0, 0, 0, 0)
x, new_cross_cache = self.src_attn(x,
memory,
memory,
memory_mask,
cache=cross_att_cache)
if cache is not None:
cache['cross_att_cache'] = new_cross_cache
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm2(x)

@@ -126,7 +143,4 @@ def forward(
if not self.normalize_before:
x = self.norm3(x)

if cache is not None:
x = torch.cat([cache, x], dim=1)

return x, tgt_mask, memory, memory_mask
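
To make the per-layer contract concrete, here is a small pure-torch illustration (again an assumption, not code from the diff) of the incremental self-attention cache the reworked layer leans on: each step only the newest frame is projected, its key/value pair is appended to the packed cache, and splitting the cache then exposes keys and values covering the whole prefix for the single-frame query tgt[:, -1:, :].

import torch

B, H, d_k = 1, 4, 64
att_cache = torch.zeros(0, 0, 0, 0)               # empty cache on the first step

for step in range(3):
    k_new = torch.randn(B, H, 1, d_k)             # key of the newest frame only
    v_new = torch.randn(B, H, 1, d_k)             # value of the newest frame only
    kv_new = torch.cat((k_new, v_new), dim=-1)    # pack along the feature dim
    if att_cache.size(0) > 0:
        att_cache = torch.cat((att_cache, kv_new), dim=2)   # append along time
    else:
        att_cache = kv_new
    k, v = torch.split(att_cache, att_cache.size(-1) // 2, dim=-1)
    # k and v now cover frames 0..step, so attention for the last query sees
    # the whole prefix even though only one new frame was projected this step.

In the layer itself this bookkeeping is delegated to self.self_attn via cache=att_cache, and the returned tensor is simply stored back into cache['self_att_cache'].
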