diff --git a/wenet/ssl/bestrq/mask.py b/wenet/ssl/bestrq/mask.py
index f5e62eea0..c8905953e 100644
--- a/wenet/ssl/bestrq/mask.py
+++ b/wenet/ssl/bestrq/mask.py
@@ -1,53 +1,54 @@
-import torch
-
-
-def _sampler(pdf: torch.Tensor, num_samples: int,
-             device=torch.device('cpu')) -> torch.Tensor:
-    size = pdf.size()
-    z = -torch.log(torch.rand(size, device=device))
-    _, indices = torch.topk(pdf + z, num_samples)
-    return indices
-
-
-def compute_mask_indices(
-        size: torch.Size,
-        mask_prob: float,
-        mask_length: int,
-        min_masks: int = 0,
-        device=torch.device('cpu')) -> torch.Tensor:
-
-    assert len(size) == 2
-    batch_size, seq_length = size
-
-    # compute number of masked span in batch
-    num_masked_spans = mask_prob * float(seq_length) / float(
-        mask_length) + torch.rand(1)[0]
-    num_masked_spans = int(num_masked_spans)
-    num_masked_spans = max(num_masked_spans, min_masks)
-
-    # num_masked <= seq_length
-    if num_masked_spans * mask_length > seq_length:
-        num_masked_spans = seq_length // mask_length
-
-    pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
-    mask_idxs = _sampler(pdf, num_masked_spans, device=device)
-
-    mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
-        batch_size,
-        num_masked_spans * mask_length)  # [B,num_masked_spans*mask_length]
-
-    offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
-        1, num_masked_spans, 1)  # [1,num_masked_spans,mask_length]
-    offset = offset.view(1, num_masked_spans * mask_length)
-
-    mask_idxs = mask_idxs + offset  # [B,num_masked_spans, mask_length]
-
-    ones = torch.ones(batch_size,
-                      seq_length,
-                      dtype=torch.bool,
-                      device=mask_idxs.device)
-    # masks to fill
-    full_mask = torch.zeros_like(ones,
-                                 dtype=torch.bool,
-                                 device=mask_idxs.device)
-    return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)
+import torch
+
+
+def _sampler(pdf: torch.Tensor, num_samples: int,
+             device=torch.device('cpu')) -> torch.Tensor:
+    size = pdf.size()
+    z = -torch.log(torch.rand(size, device=device))
+    _, indices = torch.topk(pdf + z, num_samples)
+    return indices
+
+
+def compute_mask_indices(
+    size: torch.Size,
+    mask_prob: float,
+    mask_length: int,
+    min_masks: int = 0,
+    device=torch.device('cpu'),
+) -> torch.Tensor:
+
+    assert len(size) == 2
+    batch_size, seq_length = size
+
+    # compute number of masked span in batch
+    num_masked_spans = mask_prob * float(seq_length) / float(
+        mask_length) + torch.rand(1)[0]
+    num_masked_spans = int(num_masked_spans)
+    num_masked_spans = max(num_masked_spans, min_masks)
+
+    # num_masked <= seq_length
+    if num_masked_spans * mask_length > seq_length:
+        num_masked_spans = seq_length // mask_length
+
+    pdf = torch.ones(batch_size, seq_length - (mask_length - 1), device=device)
+    mask_idxs = _sampler(pdf, num_masked_spans, device=device)
+
+    mask_idxs = mask_idxs.unsqueeze(-1).repeat(1, 1, mask_length).view(
+        batch_size,
+        num_masked_spans * mask_length)  # [B,num_masked_spans*mask_length]
+
+    offset = torch.arange(mask_length, device=device).view(1, 1, -1).repeat(
+        1, num_masked_spans, 1)  # [1,num_masked_spans,mask_length]
+    offset = offset.view(1, num_masked_spans * mask_length)
+
+    mask_idxs = mask_idxs + offset  # [B,num_masked_spans, mask_length]
+
+    ones = torch.ones(batch_size,
+                      seq_length,
+                      dtype=torch.bool,
+                      device=mask_idxs.device)
+    # masks to fill
+    full_mask = torch.zeros_like(ones,
+                                 dtype=torch.bool,
+                                 device=mask_idxs.device)
+    return torch.scatter(full_mask, dim=1, index=mask_idxs, src=ones)
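
For context, a minimal usage sketch of compute_mask_indices as it stands after this change. The batch size, sequence length, and masking parameters below are illustrative values chosen for the example, not taken from the repository; only the import path and function signature come from the diff above.

    import torch

    from wenet.ssl.bestrq.mask import compute_mask_indices

    # Illustrative values: 4 utterances of 100 frames each, masking roughly
    # half the frames in spans of 10 consecutive frames, at least one span.
    masks = compute_mask_indices(torch.Size([4, 100]),
                                 mask_prob=0.5,
                                 mask_length=10,
                                 min_masks=1)

    print(masks.shape)       # torch.Size([4, 100]), dtype=torch.bool
    print(masks.sum(dim=1))  # number of masked frames per utterance

Each row of the returned boolean tensor marks the frames covered by the sampled spans; span start positions are drawn via the Gumbel-top-k trick in _sampler, so different utterances in the batch get different masks.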