[transformer] add cross attention #2388

Merged
merged 1 commit into from Mar 6, 2024

Changes from all commits
130 changes: 102 additions & 28 deletions wenet/transformer/attention.py
@@ -55,33 +55,44 @@ def __init__(self,
self.use_sdpa = use_sdpa
self.dropout_rate = dropout_rate

def _forward_linearx(self, name: str, x: torch.Tensor) -> torch.Tensor:
assert x.ndim >= 3
if name == 'query':
x = self.linear_q(x)
elif name == 'key':
x = self.linear_k(x)
else:
assert name == 'value'
x = self.linear_v(x)
# split last dim
x_shape = x.size()
x_shape = x_shape[:-1] + torch.Size([self.h, self.d_k])
x = x.view(x_shape)
x = x.transpose(-3, -2) # (batch, ..., head, time, d_k)
return x

def forward_qkv(
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transform query, key and value.

Args:
query (torch.Tensor): Query tensor (#batch, time1, size).
key (torch.Tensor): Key tensor (#batch, time2, size).
value (torch.Tensor): Value tensor (#batch, time2, size).
query (torch.Tensor): Query tensor (#batch, ..., time1, size).
key (torch.Tensor): Key tensor (#batch, ..., time2, size).
value (torch.Tensor): Value tensor (#batch, ..., time2, size).

Returns:
torch.Tensor: Transformed query tensor, size
(#batch, n_head, time1, d_k).
(#batch, ..., n_head, time1, d_k).
torch.Tensor: Transformed key tensor, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).
torch.Tensor: Transformed value tensor, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).

"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2) # (batch, head, time1, d_k)
k = k.transpose(1, 2) # (batch, head, time2, d_k)
v = v.transpose(1, 2) # (batch, head, time2, d_k)

q = self._forward_linearx('query', query)
k = self._forward_linearx('key', key)
v = self._forward_linearx('value', value)
return q, k, v

def forward_attention(
@@ -94,42 +105,41 @@ def forward_attention(

Args:
value (torch.Tensor): Transformed value, size
(#batch, n_head, time2, d_k).
(#batch, ..., n_head, time2, d_k).
scores (torch.Tensor): Attention score, size
(#batch, n_head, time1, time2).
(#batch, ..., n_head, time1, time2).
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
(#batch, time1, time2), (0, 0, 0) means fake mask.
(#batch, ..., time1, time2), (0, ..., 0, 0) means fake mask.

Returns:
torch.Tensor: Transformed value (#batch, time1, d_model)
weighted by the attention score (#batch, time1, time2).

"""
n_batch = value.size(0)
# NOTE(xcsong): When will `if mask.size(-1) > 0` be True?
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
# 1st chunk to ease the onnx export.]
# 2. pytorch training
if mask.size(2) > 0: # time2 > 0
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
if mask.size(-1) > 0: # time2 > 0
mask = mask.unsqueeze(-3).eq(0) # (batch, ..., 1, *, time2)
# For last chunk, time2 might be larger than scores.size(-1)
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
mask = mask[..., :scores.size(-1)] # (batch, ..., 1, *, time2)
scores = scores.masked_fill(mask, -float('inf'))
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0) # (batch, head, time1, time2)
# NOTE(xcsong): When will `if mask.size(-1) > 0` be False?
# 1. onnx(16/-1, -1/-1, 16/0)
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
attn = torch.softmax(scores,
dim=-1) # (batch, ..., head, time1, time2)

p_attn = self.dropout(attn)
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
self.h * self.d_k)
) # (batch, time1, d_model)

return self.linear_out(x) # (batch, time1, d_model)
x = torch.matmul(p_attn, value) # (batch, ..., head, time1, d_k)
x = x.transpose(-3, -2).contiguous() # [batch, ..., time1, head, d_k]
x_shape = x.size()[:-2] + torch.Size([self.h * self.d_k])
x = x.view(x_shape) # (batch, ..., time1, d_model)
return self.linear_out(x) # (batch, ..., time1, d_model)
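
The reworked forward_attention above, together with _forward_linearx, generalizes the old fixed (batch, time, size) layout to tensors with any number of leading dimensions by only ever touching the last three. A minimal sketch of the shape handling, with arbitrary assumed sizes:

import torch

h, d_k = 4, 16                                   # assumed head count and per-head dim
x = torch.randn(2, 3, 10, h * d_k)               # (batch, beam, time, size)
x = x.view(x.size()[:-1] + torch.Size([h, d_k])) # (batch, beam, time, head, d_k)
x = x.transpose(-3, -2)                          # (batch, beam, head, time, d_k)
print(x.shape)                                   # torch.Size([2, 3, 4, 10, 16])

Because the head split and transpose use negative indices, the same projection code serves both plain batches and beam-expanded ones.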

def forward(
self,
@@ -369,3 +379,67 @@ def forward(
query.size(0), -1,
self.h * self.d_k)) # (batch, time1, d_model)
return self.linear_out(output), new_cache


class MultiHeadedCrossAttention(MultiHeadedAttention):

def __init__(self,
n_head: int,
n_feat: int,
dropout_rate: float,
key_bias: bool = True,
use_sdpa: bool = False):
super().__init__(n_head, n_feat, dropout_rate, key_bias, use_sdpa)

def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
pos_emb: torch.Tensor = torch.empty(0),
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
) -> Tuple[torch.Tensor, torch.Tensor]:
del pos_emb
if cache.size(0) > 0:
assert not self.training
q = self._forward_linearx('query', query)
k, v = torch.split(cache, cache.size(-1) // 2, dim=-1)

else:
q, k, v = self.forward_qkv(query, key, value)
new_cache = torch.cat((k, v), dim=-1)

B = query.size(0)
Beams = 1
if B != k.size(0):
assert not self.training
Beams = B // k.size(0)
B = k.size(0)
q = q.view(B, Beams, q.size(-3), q.size(-2), q.size(-1))
k = k.unsqueeze(1)
v = v.unsqueeze(1)
mask = mask.unsqueeze(1)

if not self.use_sdpa:
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
output = self.forward_attention(v, scores, mask)
else:
output = torch.nn.functional.scaled_dot_product_attention(
q,
k,
v,
attn_mask=mask.unsqueeze(1),
dropout_p=self.dropout_rate,
scale=1 / math.sqrt(self.d_k),
)
output = output.transpose(-2, -3).contiguous()
output_shape = output.size()[:-2] + torch.Size([self.h * self.d_k])
output = output.view(output_shape) # (batch, ..., time1, d_model)
output = self.linear_out(output)

if query.size(0) != B:
assert not self.training
output_shape = torch.Size([B * Beams]) + output.size()[2:]
output = output.view(output_shape)
return output, new_cache
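
The new MultiHeadedCrossAttention packs the projected encoder keys and values into a single cache tensor along the last dimension, so subsequent decoding steps skip the key/value projections entirely. A round-trip sketch of that packing, with assumed sizes:

import torch

k = torch.randn(1, 4, 20, 16)   # (batch, head, time2, d_k)
v = torch.randn(1, 4, 20, 16)

cache = torch.cat((k, v), dim=-1)                         # pack, as in forward()
k2, v2 = torch.split(cache, cache.size(-1) // 2, dim=-1)  # unpack on reuse
assert torch.equal(k2, k) and torch.equal(v2, v)

The Beams branch covers beam search, where the query batch is beam-expanded to B * Beams while the cached encoder key/value keeps batch size B: viewing q as (B, Beams, head, time1, d_k) and unsqueezing k, v, and mask lets the cached tensors broadcast across beams instead of being tiled.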
36 changes: 25 additions & 11 deletions wenet/transformer/decoder.py
@@ -13,7 +13,7 @@
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Decoder definition."""
from typing import Tuple, List, Optional
from typing import Dict, Tuple, List, Optional

import torch
import torch.utils.checkpoint as ckpt
@@ -101,7 +101,7 @@ def __init__(
WENET_ATTENTION_CLASSES["selfattn"](
attention_heads, attention_dim,
self_attention_dropout_rate, key_bias, use_sdpa),
WENET_ATTENTION_CLASSES["selfattn"](
WENET_ATTENTION_CLASSES["crossattn"](
attention_heads, attention_dim, src_attention_dropout_rate,
key_bias, use_sdpa) if src_attention else None,
PositionwiseFeedForward(attention_dim, linear_units,
@@ -196,8 +196,8 @@ def forward_one_step(
memory_mask: torch.Tensor,
tgt: torch.Tensor,
tgt_mask: torch.Tensor,
cache: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
cache: Dict[str, Dict[str, torch.Tensor]],
) -> torch.Tensor:
"""Forward one step.
This is only used for decoding.
Args:
@@ -213,25 +213,39 @@
y.shape` is (batch, maxlen_out, token)
"""
x, _ = self.embed(tgt)
new_cache = []
update_cross_att_cache = True
if len(cache['cross_att_cache']) != 0:
assert len(cache['cross_att_cache']) == self.num_blocks
update_cross_att_cache = False
for i, decoder in enumerate(self.decoders):
if cache is None:
c = None
else:
c = cache[i]
layer_i = 'layer_{}'.format(i)
self_att_cache = cache['self_att_cache'].get(layer_i, None)
cross_att_cache = cache['cross_att_cache'].get(layer_i, None)
c = {
'self_att_cache': self_att_cache,
'cross_att_cache': cross_att_cache,
}

x, tgt_mask, memory, memory_mask = decoder(x,
tgt_mask,
memory,
memory_mask,
cache=c)
new_cache.append(x)

# update cache dict
assert c['self_att_cache'] is not None
assert c['cross_att_cache'] is not None
cache['self_att_cache'][layer_i] = c['self_att_cache']
if update_cross_att_cache:
cache['cross_att_cache'][layer_i] = c['cross_att_cache']

if self.normalize_before:
y = self.after_norm(x[:, -1])
else:
y = x[:, -1]
if self.use_output_layer:
y = torch.log_softmax(self.output_layer(y), dim=-1)
return y, new_cache
return y
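
With the dict-based cache, the caller allocates the two sub-dicts once and passes the same object on every step; forward_one_step now updates it in place and returns only y. The self-attention entries grow by one frame per emitted token, while the cross-attention entries are computed on the first step and reused unchanged. A hypothetical greedy-decoding sketch (decoder, memory, memory_mask, sos, and max_len are assumptions, not part of this PR):

import torch

cache = {'self_att_cache': {}, 'cross_att_cache': {}}
ys = torch.tensor([[sos]], dtype=torch.long)   # (1, 1) running hypothesis
for _ in range(max_len):
    L = ys.size(1)
    tgt_mask = torch.tril(torch.ones(1, L, L, dtype=torch.bool))
    y = decoder.forward_one_step(memory, memory_mask, ys, tgt_mask, cache)
    next_tok = y.argmax(dim=-1, keepdim=True)  # greedy pick, (1, 1)
    ys = torch.cat([ys, next_tok], dim=1)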

def tie_or_clone_weights(self, jit_mode: bool = True):
"""Tie or clone module weights (between word_emb and output_layer)
46 changes: 30 additions & 16 deletions wenet/transformer/decoder_layer.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder self-attention layer definition."""
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
from torch import nn
@@ -65,7 +65,7 @@ def forward(
tgt_mask: torch.Tensor,
memory: torch.Tensor,
memory_mask: torch.Tensor,
cache: Optional[torch.Tensor] = None
cache: Optional[Dict[str, Optional[torch.Tensor]]] = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute decoded features.

@@ -87,35 +87,52 @@
torch.Tensor: Encoded memory mask (#batch, maxlen_in).

"""
if cache is not None:
att_cache = cache['self_att_cache']
cross_att_cache = cache['cross_att_cache']
else:
att_cache, cross_att_cache = None, None

residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)

if cache is None:
if att_cache is None:
tgt_q = tgt
tgt_q_mask = tgt_mask
att_cache = torch.empty(0, 0, 0, 0)
else:
# compute only the last frame query keeping dim: max_time_out -> 1
assert cache.shape == (
tgt.shape[0],
tgt.shape[1] - 1,
self.size,
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
tgt_q = tgt[:, -1:, :]
residual = residual[:, -1:, :]
tgt_q_mask = tgt_mask[:, -1:, :]

x = residual + self.dropout(
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
x, new_att_cache = self.self_attn(
tgt_q,
tgt_q,
tgt_q,
tgt_q_mask,
cache=att_cache,
)
if cache is not None:
cache['self_att_cache'] = new_att_cache
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm1(x)

if self.src_attn is not None:
residual = x
if self.normalize_before:
x = self.norm2(x)
x = residual + self.dropout(
self.src_attn(x, memory, memory, memory_mask)[0])
if cross_att_cache is None:
cross_att_cache = torch.empty(0, 0, 0, 0)
x, new_cross_cache = self.src_attn(x,
memory,
memory,
memory_mask,
cache=cross_att_cache)
if cache is not None:
cache['cross_att_cache'] = new_cross_cache
x = residual + self.dropout(x)
if not self.normalize_before:
x = self.norm2(x)

@@ -126,7 +143,4 @@ def forward(
if not self.normalize_before:
x = self.norm3(x)

if cache is not None:
x = torch.cat([cache, x], dim=1)

return x, tgt_mask, memory, memory_mask
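
During incremental decoding the layer now queries only the last frame and lets the attention cache supply the earlier keys and values. This is safe because row t of an attention output depends only on query row t, as this small check with arbitrary sizes illustrates:

import torch

q = torch.randn(1, 4, 5, 16)    # (batch, head, time1, d_k)
k = torch.randn(1, 4, 5, 16)    # (batch, head, time2, d_k)

full = torch.matmul(q, k.transpose(-2, -1))                # all-frame scores
last = torch.matmul(q[:, :, -1:, :], k.transpose(-2, -1))  # last frame only
assert torch.allclose(full[:, :, -1:, :], last)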