[ssl/w2vbert] support part of w2vbert training
Mddct committed Oct 9, 2023
1 parent 3790509 commit 08c6df0
Showing 3 changed files with 332 additions and 10 deletions.
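For orientation: w2v-BERT trains the lower encoder blocks with a wav2vec 2.0-style contrastive loss against Gumbel-softmax quantizer outputs, and the upper blocks with a masked-prediction (MLM) loss against the same quantizer's code ids. A minimal sketch of how the two terms are combined, with illustrative stand-in values and the l2 feature-regularization term omitted (see W2VBERTModel.forward in the diff below for the real code):

import torch

# Illustrative stand-ins for quantities computed inside W2VBERTModel.forward.
loss_contrastive = torch.tensor(250.0)  # sum-reduced over all masked frames
loss_diversity = torch.tensor(0.8)      # codebook diversity penalty (per frame)
loss_mlm = torch.tensor(3.2)            # masked-prediction cross-entropy (mean)
sample_size = torch.tensor(100.0)       # number of masked frames in the batch
contrastive_weight, mlm_weight, diversity_weight = 1.0, 1.0, 0.0

# Contrastive branch: add the optional diversity term, then normalize by sample size.
loss = (loss_contrastive + diversity_weight * loss_diversity * sample_size) / sample_size
# Final objective: weighted sum of the contrastive and MLM branches.
loss = contrastive_weight * loss + mlm_weight * loss_mlm
print(loss)  # tensor(5.7000)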
318 changes: 318 additions & 0 deletions wenet/ssl/w2vbert/w2vbert_model.py
@@ -0,0 +1,318 @@
import math
from typing import Optional, Tuple, Union
import torch

from wenet.ssl.bestrq.mask import compute_mask_indices_v2
from wenet.ssl.wav2vec2.quantizer import Wav2vecGumbelVectorQuantizer
from wenet.ssl.wav2vec2.wav2vec2_model import (_compute_contrastive_loss,
_sample_negative_indices)
from wenet.transformer.attention import RelPositionMultiHeadedAttention

from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.transformer.encoder_layer import ConformerEncoderLayer
from wenet.utils.mask import make_non_pad_mask


class W2VBERTModel(torch.nn.Module):

def __init__(
self,
encoder: Union[ConformerEncoder, TransformerEncoder],
embedding_dim: int = 256,
num_embeddings: int = 320,
num_codebooks: int = 1,
mask_prob: float = 0.065,
mask_length: int = 10,
min_masks: int = 2,
num_negatives: int = 100,
features_regularization_weight: float = 0.01,
max_gumbel_temperature: float = 2.0,
min_gumbel_temperature: float = 0.1,
gumbel_temperature_decay: float = 0.999995,
contrastive_logits_temperature: float = 0.1,
diversity_weight: float = 0.0,
bias: bool = True,
contrastive_blocks: int = 6,
masked_blocks: int = 6,
contrastive_weight: float = 1.0,
mlm_weight: float = 1.0,
) -> None:
""" Wrap encoder to train using W2V-BERT's style
Described in:
https://arxiv.org/pdf/2108.06209v2.pdf
Args:
encoder: wenet's encoder,
only support conformer and transformer now
embedding_dim: codebooks embedding dim
num_embeddings: numbers of each codebook
num_codebooks: numbers of codebooks i.e groups of codebook
mask_prob: probs of mask
mask_length: spans of masks
min_maks: min masks for each audio
num_negatives: numbers of negatives of each masks
features_regularization_weight: l2 regularization weight
max_gumbel_temperature: maximum temperature for gumbel softmax
min_gumbel_temperature: minimum temperature for gumbel softmax
gumbel_temperature_decay:
decay of gumbel temperature during training
contrastive_logits_temperature:
the temperature in the contrastive loss.
"""
super().__init__()
assert mask_prob > 0.0
assert (contrastive_blocks > 0 and masked_blocks > 0 and
contrastive_blocks + masked_blocks == len(encoder.encoders))
self.contrastive_blocks = contrastive_blocks
self.masked_blocks = masked_blocks

self.mask_prob = mask_prob
self.mask_length = mask_length
self.min_masks = min_masks
self.num_negatives = num_negatives

self.features_regularization_weight = features_regularization_weight
self.diversity_weight = diversity_weight

self.contrastive_weight = contrastive_weight
self.mlm_weight = mlm_weight
# encoder
self.encoder = encoder

# quantizer
self.num_codebooks = num_codebooks
self.quantizer = Wav2vecGumbelVectorQuantizer(
self.encoder.output_size(),
num_codebooks=num_codebooks,
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
hard=False,
)
self.max_gumbel_temp = max_gumbel_temperature
self.min_gumbel_temp = min_gumbel_temperature
self.gumbel_temp_decay = gumbel_temperature_decay

self.num_codevectors_per_group = num_embeddings
self.num_codevector_groups = num_codebooks

self.contrastive_logits_temp = contrastive_logits_temperature

# NOTE(Mddct): mask_emb is replaced by a random value in w2v-BERT
# self.mask_emb = torch.nn.parameter.Parameter(
# torch.empty(self.encoder.output_size()).uniform_(),
# requires_grad=True,
# )
# TODO(Mddct): support causal or lookahead mask or keep consistent with
# wenet dynamic chunk training

# num_codebooks softmax heads, one per codebook
self.encoder_top_n_out = torch.nn.parameter.Parameter(
torch.empty(num_codebooks, self.encoder.output_size(),
num_embeddings))
torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
self.bias = bias
if bias:
self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
torch.empty(num_codebooks, num_embeddings))
torch.nn.init.zeros_(self.encoder_top_n_out_bias)

# reset parameter
self.reset_encoder_parameter()

def reset_encoder_parameter(self):

def _reset_parameter(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.trunc_normal_(module.weight.data,
mean=0.0,
std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, torch.nn.Conv1d):
torch.nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups /
(module.in_channels * module.kernel_size[0]))
torch.nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, torch.Tensor):
torch.nn.init.trunc_normal_(module)
else:
raise NotImplementedError("other modules are not supported now")

encoders = self.encoder.encoders
for _, layer in enumerate(encoders):
self_attn = layer.self_attn
_reset_parameter(self_attn.linear_q)
_reset_parameter(self_attn.linear_k)
_reset_parameter(self_attn.linear_v)
_reset_parameter(self_attn.linear_out)
if isinstance(self_attn, RelPositionMultiHeadedAttention):
_reset_parameter(self_attn.pos_bias_u)
_reset_parameter(self_attn.pos_bias_v)
if isinstance(layer, ConformerEncoderLayer):
conv1, conv2 = (layer.conv_module.pointwise_conv1,
layer.conv_module.depthwise_conv)
_reset_parameter(conv1)
_reset_parameter(conv2)

@torch.jit.ignore(drop=True)
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
text: Optional[torch.Tensor] = None,
text_length: Optional[torch.Tensor] = None,
steps: Optional[int] = None,
):

assert xs.size(0) == xs_lens.size(0)
assert steps is not None

# 1 forward subsampling
# NOTE(Mddct): use subsampling as feature extraction
xs, pos_emb, masks = self._forward_subsampling(xs, xs_lens)
unmasked_xs = xs
# 2 mask features
masked_xs, masked_masks = self._apply_mask(xs, masks.squeeze(1))
# 3 forward encoder blocks
contrastive_vec, mlm_vec, out_mask = self._forward_encoder_blocks(
masked_xs, masks, pos_emb, masks)

# 4 contrastive branch
gumbel_temperature = max(
self.max_gumbel_temp * self.gumbel_temp_decay**steps,
self.min_gumbel_temp)

quantized_features, codevector_perplexity, targets_ids = self.quantizer(
unmasked_xs, masks.squeeze(1), gumbel_temperature)

sampled_negative_indices = _sample_negative_indices(
xs.size()[:-1], self.num_negatives, masked_masks.device,
masked_masks)

loss_contrastive = _compute_contrastive_loss(
quantized_features, contrastive_vec, sampled_negative_indices,
masked_masks, self.contrastive_logits_temp, self.num_negatives)
loss = loss_contrastive

# scale by sample size
# make sure that diversity loss is multiplied by `sample_size`
# since contrastive_loss is `sum`-reduced instead of averaged
sample_size = masked_masks.sum()
# higher codevector_perplexity leads to lower diversity loss
loss_diversity: Optional[torch.Tensor] = None
if self.diversity_weight != 0.0:
loss_diversity = (
self.num_codevector_groups * self.num_codevectors_per_group -
codevector_perplexity) / (self.num_codevectors_per_group *
self.num_codevector_groups)
loss_diversity = loss_diversity * sample_size
loss = loss + self.diversity_weight * loss_diversity
loss = loss / sample_size

features_pen: Optional[torch.Tensor] = None
if self.features_regularization_weight != 0.0:
features_pen = xs.pow(2).mean()
loss = loss + self.features_regularization_weight * features_pen

# 5 masked lm branch
out = mlm_vec.unsqueeze(1)
top_n_out = self.encoder_top_n_out.unsqueeze(
0) # [1, num_codebooks, dim, num_embeddings]
out = torch.matmul(out,
top_n_out) # [B, num_codebooks, T', num_embeddings]
if self.bias:
out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)

num_codes = masked_masks.sum() * self.num_codebooks
loss_mlm = self._compute_mlm_loss(out,
targets_ids,
mask=out_mask.squeeze(1) *
masked_masks)
ids_corr = out.argmax(dim=-1,
keepdim=False).transpose(1, 2) == targets_ids
codes_acc = (ids_corr * masked_masks.unsqueeze(2)).sum() / num_codes
# TODO(Mddct): support num codes used in batch, unique num codes
# used in batch like bestrq

# 6 final loss
loss = self.contrastive_weight * loss + self.mlm_weight * loss_mlm
return {
"code_ppl": codevector_perplexity.detach(),
"features_l2": features_pen,
"codes_acc": codes_acc.detach(),
"loss": loss,
"loss_contrastive": loss_contrastive / sample_size,
"loss_diversity": loss_diversity,
"loss_mlm": loss_mlm,
}

def _apply_mask(
self, xs: torch.Tensor,
xs_masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

masks = compute_mask_indices_v2(xs.size()[:-1],
~xs_masks,
self.mask_prob,
self.mask_length,
min_masks=self.min_masks,
device=xs.device)
masks_expand = masks.unsqueeze(-1) # [B, T, 1]

mask_emb = torch.normal(mean=0,
std=0.1,
size=xs.size(),
device=xs.device)
xs = torch.where(masks_expand, mask_emb, xs)

return xs, masks

def _compute_mlm_loss(self, input: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
log_probs = torch.log_softmax(input, dim=-1).transpose(
1, 2) # [B, T', num_codebooks, num_embeddings]

per_example_n_loss = -log_probs.gather(3, target.unsqueeze(3)).squeeze(
3) # [B, T', num_codebooks]

numerator = torch.sum(per_example_n_loss * mask.unsqueeze(2))
denominator = torch.sum(mask) + 1e-5
loss = numerator / (denominator * self.num_codebooks)
return loss

def _forward_subsampling(
self, xs: torch.Tensor, xs_lens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

masks = make_non_pad_mask(xs_lens).unsqueeze(1) # (B, 1, T)
if self.encoder.global_cmvn is not None:
xs = self.encoder.global_cmvn(xs)
xs, pos_emb, masks = self.encoder.embed(xs, masks)
return xs, pos_emb, masks

def _forward_encoder_blocks(
self, xs: torch.Tensor, xs_masks: torch.Tensor, pos_emb: torch.Tensor,
mask_pad: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

masks = xs_masks

xs: torch.Tensor
# forward the contrastive blocks to get context vectors for the contrastive loss
for layer in self.encoder.encoders[:self.contrastive_blocks]:
xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
contrastive_vec = xs

for layer in self.encoder.encoders[self.contrastive_blocks:]:
xs, masks, _, _ = layer(xs, xs_masks, pos_emb, mask_pad)
masked_vec = xs

if self.encoder.normalize_before:
xs = self.encoder.after_norm(xs)
masked_vec = xs
# Here we assume the mask is not changed in encoder layers, so just
# return the masks computed before the encoder layers
return contrastive_vec, masked_vec, masks
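A hedged usage sketch of the new model; the ConformerEncoder settings below are illustrative assumptions, not part of this commit, and contrastive_blocks + masked_blocks must equal the encoder's number of blocks:

import torch
from wenet.transformer.encoder import ConformerEncoder
from wenet.ssl.w2vbert.w2vbert_model import W2VBERTModel

# Assumed config: 80-dim fbank input, 256-dim output, 12 conformer blocks (6 + 6).
encoder = ConformerEncoder(input_size=80, output_size=256, num_blocks=12)
model = W2VBERTModel(encoder, contrastive_blocks=6, masked_blocks=6)

xs = torch.randn(2, 1000, 80)        # (batch, frames, feat_dim), dummy features
xs_lens = torch.tensor([1000, 800])  # valid frames per utterance
out = model(xs, xs_lens, steps=0)    # steps drives the gumbel temperature decay
print(out["loss"], out["codes_acc"], out["code_ppl"])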
14 changes: 9 additions & 5 deletions wenet/ssl/wav2vec2/quantizer.py
@@ -1,3 +1,4 @@
+from typing import Tuple
import torch


@@ -66,10 +67,12 @@ def _compute_perplexity(probs, mask=None):
marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
return perplexity

-def forward(self,
-input: torch.Tensor,
-input_mask: torch.Tensor,
-temperature: float = 1.):
+def forward(
+self,
+input: torch.Tensor,
+input_mask: torch.Tensor,
+temperature: float = 1.
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

b, t, _ = input.size()

@@ -98,6 +101,7 @@ def forward(self,
b * t, self.num_groups, -1)
perplexity = self._compute_perplexity(codevector_probs, input_mask)

+targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1)
codevector_probs = codevector_probs.reshape(b * t, -1)
# use probs to retrieve codevectors
codevectors_per_group = codevector_probs.unsqueeze(
@@ -106,4 +110,4 @@ def forward(self,
b * t, self.num_groups, self.num_codevectors_per_group, -1)

codevectors = codevectors.sum(-2).reshape(b, t, -1)
-return codevectors, perplexity
+return codevectors, perplexity, targets_idx
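For reference, a small shape sketch of the new targets_idx output (the sizes below are made up): the argmax over the last codevector dimension picks one discrete code id per group and frame, which the MLM branch above uses as its prediction targets.

import torch

b, t, num_groups, codes_per_group = 2, 5, 1, 320
# Soft assignments over codevectors, as produced by the gumbel softmax.
codevector_probs = torch.rand(b * t, num_groups, codes_per_group).softmax(dim=-1)
# One code id per (frame, group): shape (b, t, num_groups).
targets_idx = codevector_probs.argmax(-1).reshape(b, t, -1)
print(targets_idx.shape)  # torch.Size([2, 5, 1])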
10 changes: 5 additions & 5 deletions wenet/ssl/wav2vec2/wav2vec2_model.py
@@ -21,17 +21,17 @@ def _sample_negative_indices(features_shape: Tuple,
"""
batch_size, sequence_length = features_shape

-sequence_length_range = torch.arange(sequence_length)
+sequence_length_range = torch.arange(sequence_length, device=device)

# get `num_negatives` random vector indices from the same utterance
sampled_negative_indices = torch.zeros(
(batch_size, sequence_length, num_negatives),
dtype=sequence_length_range.dtype,
device=device)

-mask_time_indices = (mask_time_indices.bool() if mask_time_indices
-is not None else torch.ones(features_shape,
-dtype=torch.bool))
+mask_time_indices = (mask_time_indices.bool()
+if mask_time_indices is not None else torch.ones(
+features_shape, dtype=torch.bool, device=device))

for batch_idx in range(batch_size):
high = mask_time_indices[batch_idx].sum() - 1
@@ -243,7 +243,7 @@ def forward(
self.max_gumbel_temp * self.gumbel_temp_decay**steps,
self.min_gumbel_temp)

-quantized_features, codevector_perplexity = self.quantizer(
+quantized_features, codevector_perplexity, _ = self.quantizer(
unmasked_xs, masks.squeeze(1), gumbel_temperature)

sampled_negative_indices = _sample_negative_indices(
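The wav2vec2_model.py changes place sequence_length_range and the default mask_time_indices on the features' device, so negative sampling no longer mixes CPU and CUDA tensors during GPU training, and they unpack the extra targets_idx now returned by the quantizer. A tiny sketch of the device pattern (the device is picked at runtime):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size, sequence_length, num_negatives = 2, 10, 4

# All index helpers live on the same device as the features they index into.
sequence_length_range = torch.arange(sequence_length, device=device)
mask_time_indices = torch.ones((batch_size, sequence_length),
                               dtype=torch.bool, device=device)
sampled_negative_indices = torch.zeros((batch_size, sequence_length, num_negatives),
                                       dtype=sequence_length_range.dtype, device=device)
print(sampled_negative_indices.device == sequence_length_range.device)  # True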
