[ssl/w2vbert] support part of w2vbert training
Showing 3 changed files with 332 additions and 10 deletions.
@@ -0,0 +1,318 @@
import math
from typing import Optional, Tuple, Union
import torch

from wenet.ssl.bestrq.mask import compute_mask_indices_v2
from wenet.ssl.wav2vec2.quantizer import Wav2vecGumbelVectorQuantizer
from wenet.ssl.wav2vec2.wav2vec2_model import (_compute_contrastive_loss,
                                               _sample_negative_indices)
from wenet.transformer.attention import RelPositionMultiHeadedAttention

from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.mask import make_non_pad_mask


class W2VBERTModel(torch.nn.Module):

    def __init__(
        self,
        encoder: Union[ConformerEncoder, TransformerEncoder],
        embedding_dim: int = 256,
        num_embeddings: int = 320,
        num_codebooks: int = 1,
        mask_prob: float = 0.065,
        mask_length: int = 10,
        min_masks: int = 2,
        num_negatives: int = 100,
        features_regularization_weight: float = 0.01,
        max_gumbel_temperature: float = 2.0,
        min_gumbel_temperature: float = 0.1,
        gumbel_temperature_decay: float = 0.999995,
        contrastive_logits_temperature: float = 0.1,
        diversity_weight: float = 0.0,
        bias: bool = True,
        contrastive_blocks: int = 6,
        masked_blocks: int = 6,
        contrastive_weight: float = 1.0,
        mlm_weight: float = 1.0,
    ) -> None:
""" Wrap encoder to train using W2V-BERT's style | ||
Described in: | ||
https://arxiv.org/pdf/2108.06209v2.pdf | ||
Args: | ||
encoder: wenet's encoder, | ||
only support conformer and transformer now | ||
embedding_dim: codebooks embedding dim | ||
num_embeddings: numbers of each codebook | ||
num_codebooks: numbers of codebooks i.e groups of codebook | ||
mask_prob: probs of mask | ||
mask_length: spans of masks | ||
min_maks: min masks for each audio | ||
num_negatives: numbers of negatives of each masks | ||
features_regularization_weight: l2 regularization weight | ||
max_gumbel_temperature: maximum temperature for gumbel softmax | ||
min_gumbel_temperature: minimum temperature for gumbel softmax | ||
gumbel_temperature_decay: | ||
decay of gumbel temperature during training | ||
contrastive_logits_temperature: | ||
the temperature in the contrastive loss. | ||
""" | ||
        super().__init__()
        assert mask_prob > 0.0
        assert (contrastive_blocks > 0 and masked_blocks > 0 and
                contrastive_blocks + masked_blocks == len(encoder.encoders))
        self.contrastive_blocks = contrastive_blocks
        self.masked_blocks = masked_blocks

        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.min_masks = min_masks
        self.num_negatives = num_negatives

        self.features_regularization_weight = features_regularization_weight
        self.diversity_weight = diversity_weight

        self.contrastive_weight = contrastive_weight
        self.mlm_weight = mlm_weight
        # encoder
        self.encoder = encoder

        # quantizer
        self.num_codebooks = num_codebooks
        self.quantizer = Wav2vecGumbelVectorQuantizer(
            self.encoder.output_size(),
            num_codebooks=num_codebooks,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            hard=False,
        )
        self.max_gumbel_temp = max_gumbel_temperature
        self.min_gumbel_temp = min_gumbel_temperature
        self.gumbel_temp_decay = gumbel_temperature_decay

        self.num_codevectors_per_group = num_embeddings
        self.num_codevector_groups = num_codebooks

        self.contrastive_logits_temp = contrastive_logits_temperature

        # NOTE(Mddct): mask_emb is replaced by random values in W2V-BERT
        # self.mask_emb = torch.nn.parameter.Parameter(
        #     torch.empty(self.encoder.output_size()).uniform_(),
        #     requires_grad=True,
        # )
        # TODO(Mddct): support causal or lookahead mask or keep consistent with
        # wenet dynamic chunk training

        # n softmax
        self.encoder_top_n_out = torch.nn.parameter.Parameter(
            torch.empty(num_codebooks, self.encoder.output_size(),
                        num_embeddings))
        torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
        self.bias = bias
        if bias:
            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
                torch.empty(num_codebooks, num_embeddings))
            torch.nn.init.zeros_(self.encoder_top_n_out_bias)

        # reset parameter
        self.reset_encoder_parameter()

    def reset_encoder_parameter(self):

        def _reset_parameter(module: torch.nn.Module):
            if isinstance(module, torch.nn.Linear):
                torch.nn.init.trunc_normal_(module.weight.data,
                                            mean=0.0,
                                            std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, torch.nn.Conv1d):
                torch.nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    k = math.sqrt(module.groups /
                                  (module.in_channels * module.kernel_size[0]))
                    torch.nn.init.uniform_(module.bias, a=-k, b=k)
            elif isinstance(module, torch.Tensor):
                torch.nn.init.trunc_normal_(module)
            else:
                raise NotImplementedError("other module not support now")

        encoders = self.encoder.encoders
        for _, layer in enumerate(encoders):
            self_attn = layer.self_attn
            _reset_parameter(self_attn.linear_q)
            _reset_parameter(self_attn.linear_k)
            _reset_parameter(self_attn.linear_v)
            _reset_parameter(self_attn.linear_out)
            if isinstance(self_attn, RelPositionMultiHeadedAttention):
                _reset_parameter(self_attn.pos_bias_u)
                _reset_parameter(self_attn.pos_bias_v)
            if isinstance(layer, ConformerEncoderLayer):
                conv1, conv2 = (layer.conv_module.pointwise_conv1,
                                layer.conv_module.depthwise_conv)
                _reset_parameter(conv1)
                _reset_parameter(conv2)

    @torch.jit.ignore(drop=True)
    def forward(
        self,
        xs: torch.Tensor,
        xs_lens: torch.Tensor,
        text: Optional[torch.Tensor] = None,
        text_length: Optional[torch.Tensor] = None,
        steps: Optional[int] = None,
    ):

        assert xs.size(0) == xs_lens.size(0)
        assert steps is not None

        # 1 forward subsampling
        # NOTE(Mddct): use subsampling as feature extraction
        xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
        unmasked_xs = xs
        # 2 mask features
        masked_xs, masked_masks = self._apply_mask(xs, masks.squeeze(1))
        # 3 forward encoder blocks
        contrastive_vec, mlm_vec, out_mask = self._forward_encoder_blocks(
            masked_xs, masks, pos_emb, masks)

        # 4 contrastive branch
        gumbel_temperature = max(
            self.max_gumbel_temp * self.gumbel_temp_decay**steps,
            self.min_gumbel_temp)

        quantized_features, codevector_perplexity, targets_ids = self.quantizer(
            unmasked_xs, masks.squeeze(1), gumbel_temperature)

        sampled_negative_indices = _sample_negative_indices(
            xs.size()[:-1], self.num_negatives, masked_masks.device,
            masked_masks)

        loss_contrastive = _compute_contrastive_loss(
            quantized_features, contrastive_vec, sampled_negative_indices,
            masked_masks, self.contrastive_logits_temp, self.num_negatives)
        loss = loss_contrastive

        # scale by sample size
        # make sure that diversity loss is multiplied by `sample_size`
        # since contrastive_loss is `sum`-reduced instead of averaged
        sample_size = masked_masks.sum()
        # higher codevector_perplexity leads to lower diversity loss
        loss_diversity: Optional[torch.Tensor] = None
        if self.diversity_weight != 0.0:
            loss_diversity = (
                self.num_codevector_groups * self.num_codevectors_per_group -
                codevector_perplexity) / (self.num_codevectors_per_group *
                                          self.num_codevector_groups)
            loss_diversity = loss_diversity * sample_size
            loss = loss + self.diversity_weight * loss_diversity
        loss = loss / sample_size

        features_pen: Optional[torch.Tensor] = None
        if self.features_regularization_weight != 0.0:
            features_pen = xs.pow(2).mean()
            loss = loss + self.features_regularization_weight * features_pen

        # 5 masked lm branch
        out = mlm_vec.unsqueeze(1)
        top_n_out = self.encoder_top_n_out.unsqueeze(
            0)  # [1, num_codebooks, dim, num_embeddings]
        out = torch.matmul(out,
                           top_n_out)  # [B, num_codebooks, T', num_embeddings]
        if self.bias:
            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)

        num_codes = masked_masks.sum() * self.num_codebooks
        loss_mlm = self._compute_mlm_loss(out,
                                          targets_ids,
                                          mask=out_mask.squeeze(1) *
                                          masked_masks)
        ids_corr = out.argmax(dim=-1,
                              keepdim=False).transpose(1, 2) == targets_ids
        codes_acc = (ids_corr * masked_masks.unsqueeze(2)).sum() / num_codes
        # TODO(Mddct): support num codes used in batch, unique num codes
        # used in batch like bestrq

        # 6 final loss
        loss = self.contrastive_weight * loss + self.mlm_weight * loss_mlm
        return {
            "code_ppl": codevector_perplexity.detach(),
            "features_l2": features_pen,
            "codes_acc": codes_acc.detach(),
            "loss": loss,
            "loss_contrastive": loss_contrastive / sample_size,
            "loss_diversity": loss_diversity,
            "loss_mlm": loss_mlm,
        }

    def _apply_mask(
            self, xs: torch.Tensor,
            xs_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        masks = compute_mask_indices_v2(xs.size()[:-1],
                                        ~xs_masks,
                                        self.mask_prob,
                                        self.mask_length,
                                        min_masks=self.min_masks,
                                        device=xs.device)
        masks_expand = masks.unsqueeze(-1)  # [B, T, 1]

        # NOTE: masked frames are replaced by random values rather than a
        # learnable mask embedding, following W2V-BERT
        mask_emb = torch.normal(mean=0,
                                std=0.1,
                                size=xs.size(),
                                device=xs.device)
        xs = torch.where(masks_expand, mask_emb, xs)

        return xs, masks

    def _compute_mlm_loss(self, input: torch.Tensor, target: torch.Tensor,
                          mask: torch.Tensor) -> torch.Tensor:
        log_probs = torch.log_softmax(input, dim=-1).transpose(
            1, 2)  # [B, T', num_codebooks, num_embeddings]

        per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
            3)  # [B, T', num_codebooks]

        numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
        denominator = torch.sum(mask) + 1e-5
        loss = numerator / (denominator * self.num_codebooks)
        return loss

    def _forward_subsampling(
        self, xs: torch.Tensor, xs_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        masks = make_non_pad_mask(xs_lens).unsqueeze(1)  # (B, 1, T)
        if self.encoder.global_cmvn is not None:
            xs = self.encoder.global_cmvn(xs)
        xs, pos_emb, masks = self.encoder.embed(xs, masks)
        return xs, pos_emb, masks

    def _forward_encoder_blocks(
        self, xs: torch.Tensor, xs_masks: torch.Tensor, pos_emb: torch.Tensor,
        mask_pad: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

        masks = xs_masks

        xs: torch.Tensor
        # forward contrastive layers get context vector for Contrastive Loss
        for layer in self.encoder.encoders[:self.contrastive_blocks]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        contrastive_vec = xs

        for layer in self.encoder.encoders[self.contrastive_blocks:]:
            xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
        masked_vec = xs

        if self.encoder.normalize_before:
            xs = self.encoder.after_norm(xs)
        masked_vec = xs
        # Here we assume the mask is not changed in encoder layers, so just
        # return the masks before encoder layers, and the masks will be used
        # for cross attention with decoder later
        return contrastive_vec, masked_vec, masks
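
Below is a minimal sketch of how this module might be driven in a training loop. It is illustrative only: the import path for W2VBERTModel, the encoder hyper-parameters, the batch shapes, and the optimizer are assumptions, not part of this commit.

import torch

from wenet.transformer.encoder import ConformerEncoder
# NOTE: module path is assumed for illustration; use whatever path this file
# is added under in the repository.
from wenet.ssl.w2vbert.w2vbert_model import W2VBERTModel

# 12 conformer blocks so that contrastive_blocks + masked_blocks
# matches len(encoder.encoders), as asserted in __init__.
encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=12)
model = W2VBERTModel(encoder, contrastive_blocks=6, masked_blocks=6)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# fake batch: 4 utterances of 80-dim fbank features (shapes are illustrative)
xs = torch.randn(4, 1000, 80)
xs_lens = torch.tensor([1000, 960, 800, 640])

model.train()
for step in range(10):
    # `steps` drives the gumbel temperature decay schedule
    out = model(xs, xs_lens, steps=step)
    optimizer.zero_grad()
    out['loss'].backward()
    optimizer.step()

Besides 'loss', the returned dict exposes 'loss_contrastive', 'loss_mlm', 'loss_diversity', 'codes_acc', 'code_ppl' and 'features_l2' for logging.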