Commit
Summary: Pull Request resolved: fairinternal/fairseq-py#757
Differential Revision: D16418305
Pulled By: myleott
fbshipit-source-id: 25f293a2792509f7a75c688e4bf8cff02e6bba2e
1 parent 51ba352 · commit 654affc
Showing 11 changed files with 569 additions and 0 deletions.
@@ -0,0 +1,53 @@

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from torch.utils.data._utils.collate import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, 'collater'):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, 'supports_prefetch', False)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self.dataset.set_epoch(epoch)
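BaseWrapperDataset forwards every FairseqDataset hook (length, sizes, collation, prefetching, epoch updates) to the wrapped dataset, so a subclass only overrides what it changes. A minimal sketch, assuming the class above is importable; the AddOffsetDataset name and the offset transform are hypothetical and not part of this commit:

# Hypothetical example, not part of this commit: a wrapper that adds a fixed
# offset to every item while inheriting __len__, sizes, collater, etc.
class AddOffsetDataset(BaseWrapperDataset):

    def __init__(self, dataset, offset=1):
        super().__init__(dataset)
        self.offset = offset

    def __getitem__(self, index):
        # delegate to the wrapped dataset, then transform the returned tensor
        return self.dataset[index] + self.offset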
@@ -0,0 +1,22 @@

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch

from . import FairseqDataset


class IdDataset(FairseqDataset):

    def __getitem__(self, index):
        return index

    def __len__(self):
        return 0

    def collater(self, samples):
        return torch.tensor(samples)
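IdDataset simply echoes back the index it is asked for, and its collater stacks a list of indices into a tensor, so a collated batch can carry the ids of the examples it was built from. A small sketch (the indices below are made up for illustration):

# Illustrative only, not part of this commit.
ids = IdDataset()
ids[7]                   # -> 7: __getitem__ returns the index unchanged
ids.collater([3, 0, 5])  # -> tensor([3, 0, 5])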
@@ -0,0 +1,24 @@

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from functools import lru_cache

from . import BaseWrapperDataset


class LRUCacheDataset(BaseWrapperDataset):

    def __init__(self, dataset, token=None):
        super().__init__(dataset)

    @lru_cache(maxsize=8)
    def __getitem__(self, index):
        return self.dataset[index]

    @lru_cache(maxsize=8)
    def collater(self, samples):
        return self.dataset.collater(samples)
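Because functools.lru_cache decorates the methods in the class body, the cache key includes self, so lookups are memoized per (instance, index) with at most eight entries per method. A minimal sketch, assuming the classes above; the list-backed stand-in dataset is hypothetical and not part of this commit:

# Hypothetical stand-in: any indexable dataset works for illustration.
base = BaseWrapperDataset([10, 20, 30])
cached = LRUCacheDataset(base)

first = cached[1]   # fetches from the wrapped dataset (20)
second = cached[1]  # served from the lru_cache, no second lookup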
@@ -0,0 +1,174 @@

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from functools import lru_cache

import numpy as np
import torch

from fairseq.data import data_utils, Dictionary

from . import BaseWrapperDataset, LRUCacheDataset


class MaskTokensDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for masked language modeling.

    Input items are masked according to the specified masking probability.

    Args:
        dataset: Dataset to wrap.
        sizes: Sentence lengths.
        vocab: Dictionary with the vocabulary and special tokens.
        pad_idx: Id of pad token in vocab.
        mask_idx: Id of mask token in vocab.
        return_masked_tokens: controls whether to return the non-masked tokens
            (the default) or to return a tensor with the original masked token
            IDs (and *pad_idx* elsewhere). The latter is useful as targets for
            masked LM training.
        seed: Seed for random number generator for reproducibility.
        mask_prob: probability of replacing a token with *mask_idx*.
        leave_unmasked_prob: probability that a masked token is unmasked.
        random_token_prob: probability of replacing a masked token with a
            random token from the vocabulary.
        freq_weighted_replacement: sample random replacement words based on
            word frequencies in the vocab.
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        bpe: BPE to use for whole-word masking.
    """

    @classmethod
    def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs):
        """Return the source and target datasets for masked LM training."""
        dataset = LRUCacheDataset(dataset)
        return (
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)),
            LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)),
        )

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        vocab: Dictionary,
        pad_idx: int,
        mask_idx: int,
        return_masked_tokens: bool = False,
        seed: int = 1,
        mask_prob: float = 0.15,
        leave_unmasked_prob: float = 0.1,
        random_token_prob: float = 0.1,
        freq_weighted_replacement: bool = False,
        mask_whole_words: torch.Tensor = None,
    ):
        assert 0.0 < mask_prob < 1.0
        assert 0.0 <= random_token_prob <= 1.0
        assert 0.0 <= leave_unmasked_prob <= 1.0
        assert random_token_prob + leave_unmasked_prob <= 1.0

        self.dataset = dataset
        self.vocab = vocab
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.return_masked_tokens = return_masked_tokens
        self.seed = seed
        self.mask_prob = mask_prob
        self.leave_unmasked_prob = leave_unmasked_prob
        self.random_token_prob = random_token_prob
        self.mask_whole_words = mask_whole_words

        if random_token_prob > 0.0:
            if freq_weighted_replacement:
                weights = np.array(self.vocab.count)
            else:
                weights = np.ones(len(self.vocab))
            weights[:self.vocab.nspecial] = 0
            self.weights = weights / weights.sum()

        self.epoch = 0

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    @lru_cache(maxsize=8)
    def __getitem__(self, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)

            assert self.mask_idx not in item, \
                'Dataset contains mask_idx (={}), this is not expected!'.format(
                    self.mask_idx,
                )

            if self.mask_whole_words is not None:
                word_begins_mask = self.mask_whole_words.gather(0, item)
                word_begins_idx = word_begins_mask.nonzero().view(-1)
                sz = len(word_begins_idx)
                words = np.split(word_begins_mask, word_begins_idx)[1:]
                assert len(words) == sz
                word_lens = list(map(len, words))

            # decide elements to mask
            mask = np.full(sz, False)
            num_mask = int(
                # add a random number for probabilistic rounding
                self.mask_prob * sz + np.random.rand()
            )
            mask[np.random.choice(sz, num_mask, replace=False)] = True

            if self.return_masked_tokens:
                # exit early if we're just returning the masked tokens
                # (i.e., the targets for masked LM training)
                if self.mask_whole_words is not None:
                    mask = np.repeat(mask, word_lens)
                new_item = np.full(len(mask), self.pad_idx)
                new_item[mask] = item[torch.from_numpy(mask)]
                return torch.from_numpy(new_item)

            # decide unmasking and random replacement
            rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob
            if rand_or_unmask_prob > 0.0:
                rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
                if self.random_token_prob == 0.0:
                    unmask = rand_or_unmask
                    rand_mask = None
                elif self.leave_unmasked_prob == 0.0:
                    unmask = None
                    rand_mask = rand_or_unmask
                else:
                    unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob
                    decision = np.random.rand(sz) < unmask_prob
                    unmask = rand_or_unmask & decision
                    rand_mask = rand_or_unmask & (~decision)
            else:
                unmask = rand_mask = None

            if unmask is not None:
                mask = mask ^ unmask

            if self.mask_whole_words is not None:
                mask = np.repeat(mask, word_lens)

            new_item = np.copy(item)
            new_item[mask] = self.mask_idx
            if rand_mask is not None:
                num_rand = rand_mask.sum()
                if num_rand > 0:
                    if self.mask_whole_words is not None:
                        rand_mask = np.repeat(rand_mask, word_lens)
                        num_rand = rand_mask.sum()

                    new_item[rand_mask] = np.random.choice(
                        len(self.vocab),
                        num_rand,
                        p=self.weights,
                    )

            return torch.from_numpy(new_item)
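A usage sketch of apply_mask (the token_dataset and vocab names below are illustrative, not part of this commit): the classmethod wraps one shared, LRU-cached token dataset and returns two views over it, a masked source and a padded target. The views stay in sync because each __getitem__ reseeds numpy with (seed, epoch, index), so both draw the same random mask for a given example.

# Illustrative only: `token_dataset` is any dataset of 1-D LongTensors and
# `vocab` a fairseq Dictionary; the '<mask>' symbol is added here for the example.
mask_idx = vocab.add_symbol('<mask>')
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
    token_dataset,
    vocab=vocab,
    pad_idx=vocab.pad(),
    mask_idx=mask_idx,
    seed=1,
    mask_prob=0.15,
)
# src_dataset[i]: the sentence with ~15% of positions masked (mostly <mask>)
# tgt_dataset[i]: the original token IDs at those positions, pad_idx elsewhere

With the defaults, roughly mask_prob (15%) of positions are selected per sentence; of those, leave_unmasked_prob (10%) keep the original token, random_token_prob (10%) are replaced by a random vocabulary word, and the remaining ~80% become mask_idx, i.e. the familiar 80/10/10 split used in BERT-style masked LM training.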