Support single forward data input for BackTranslationAug #146
makcedward committed Sep 19, 2020
1 parent 13fab78 commit 818ae70
Showing 2 changed files with 60 additions and 20 deletions.
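In short: this commit adds an `is_load_from_github` switch (plus default WMT19 model names) to `BackTranslationAug`, so translation models can be loaded either from fairseq's GitHub releases or from local directories. A minimal sketch of the default usage after this change, assuming nlpaug's standard `augment` entry point:

import nlpaug.augmenter.word as naw

# Defaults after this commit: transformer.wmt19.en-de / transformer.wmt19.de-en,
# fetched from fairseq's GitHub releases (is_load_from_github=True).
aug = naw.BackTranslationAug()
augmented = aug.augment('The quick brown fox jumps over the lazy dog')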
29 changes: 19 additions & 10 deletions nlpaug/augmenter/word/back_translation.py
@@ -12,15 +12,21 @@


 def init_back_translatoin_model(from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
-                                tokenzier_name, bpe_name, device, force_reload=False):
+                                tokenzier_name, bpe_name, is_load_from_github, device, force_reload=False):
     global BACK_TRANSLATION_MODELS

     model_name = '_'.join([from_model_name, to_model_name])
     if model_name in BACK_TRANSLATION_MODELS and not force_reload:
+        BACK_TRANSLATION_MODELS[model_name].tokenzier_name = tokenzier_name
+        BACK_TRANSLATION_MODELS[model_name].bpe_name = bpe_name
+        BACK_TRANSLATION_MODELS[model_name].is_load_from_github = is_load_from_github
+        BACK_TRANSLATION_MODELS[model_name].device = device
+
         return BACK_TRANSLATION_MODELS[model_name]
     model = nml.Fairseq(from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
                         to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
-                        tokenzier_name=tokenzier_name, bpe_name=bpe_name, device=device)
+                        tokenzier_name=tokenzier_name, bpe_name=bpe_name, is_load_from_github=is_load_from_github,
+                        device=device)

     BACK_TRANSLATION_MODELS[model_name] = model
     return model
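Note that `init_back_translatoin_model` memoizes loaded model pairs in the module-level `BACK_TRANSLATION_MODELS` dict, keyed by the joined model names; on a cache hit it only refreshes the tokenizer, BPE, source, and device settings. A hedged illustration of the expected behavior (not part of this commit):

# Two augmenters built with the same model pair should share one cached
# nml.Fairseq instance, unless force_reload=True is passed.
aug1 = naw.BackTranslationAug()
aug2 = naw.BackTranslationAug()
assert aug1.model is aug2.model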
@@ -41,6 +47,8 @@ class BackTranslationAug(WordAugmenter):
     :param str bpe: Default value is 'fastbpe'
     :param str device: Use either cpu or gpu. Default value is None; GPU is used if available. Possible values are
         'cuda' and 'cpu'.
+    :param bool is_load_from_github: Default is True. If True, translation models will be loaded from fairseq's
+        github. Otherwise, provide a model directory for both the `from_model_name` and `to_model_name` parameters.
     :param bool force_reload: Force reload the translation models into memory when initializing the class.
         Default value is False; keep it as False if performance is a consideration.
     :param str name: Name of this augmenter
@@ -49,31 +57,32 @@
     >>> aug = naw.BackTranslationAug()
     """

-    def __init__(self, from_model_name, to_model_name, from_model_checkpt='model1.pt', to_model_checkpt='model1.pt',
-                 tokenizer='moses', bpe='fastbpe', name='BackTranslationAug', device=None, force_reload=False, verbose=0):
+    def __init__(self, from_model_name='transformer.wmt19.en-de', to_model_name='transformer.wmt19.de-en',
+                 from_model_checkpt='model1.pt', to_model_checkpt='model1.pt', tokenizer='moses', bpe='fastbpe',
+                 is_load_from_github=True, name='BackTranslationAug', device=None, force_reload=False, verbose=0):
         super().__init__(
             # TODO: does not support include detail
             action='substitute', name=name, aug_p=None, aug_min=None, aug_max=None, tokenizer=None,
-            device=device, verbose=verbose, include_detail=False)
+            device=device, verbose=verbose, include_detail=False, parallelable=True)

         self.model = self.get_model(
             from_model_name=from_model_name, from_model_checkpt=from_model_checkpt,
             to_model_name=to_model_name, to_model_checkpt=to_model_checkpt,
-            tokenzier_name=tokenizer, bpe_name=bpe, device=device
+            tokenzier_name=tokenizer, bpe_name=bpe, device=device,
+            is_load_from_github=is_load_from_github
         )
         self.device = self.model.device
+        self.is_load_from_github = is_load_from_github

     def substitute(self, data):
         augmented_text = self.model.predict(data)
         return augmented_text

     @classmethod
     def get_model(cls, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
-                  tokenzier_name, bpe_name, device='cuda', force_reload=False):
+                  tokenzier_name, bpe_name, device='cuda', is_load_from_github=True, force_reload=False):
         return init_back_translatoin_model(from_model_name, from_model_checkpt,
                                            to_model_name, to_model_checkpt, tokenzier_name, bpe_name,
-                                           device, force_reload
+                                           is_load_from_github, device, force_reload
                                            )

     @classmethod
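Putting the new parameter to use, a hedged sketch of loading from local checkpoint directories instead of GitHub (the paths are hypothetical placeholders; each directory would need to contain the checkpoint plus the dictionary and BPE files fairseq expects):

aug = naw.BackTranslationAug(
    from_model_name='/data/models/wmt19.en-de',  # hypothetical local directory
    to_model_name='/data/models/wmt19.de-en',    # hypothetical local directory
    from_model_checkpt='model1.pt',
    to_model_checkpt='model1.pt',
    is_load_from_github=False)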
51 changes: 41 additions & 10 deletions nlpaug/model/lang_models/fairseq.py
@@ -1,3 +1,5 @@
+import os
+
 try:
     import torch
 except ImportError:
@@ -9,31 +11,60 @@


 class Fairseq(LanguageModels):
-    def __init__(self, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt, tokenzier_name='moses', bpe_name='fastbpe',
+    def __init__(self, from_model_name, from_model_checkpt, to_model_name, to_model_checkpt,
+                 is_load_from_github=True, tokenzier_name='moses', bpe_name='fastbpe',
                  device='cuda'):
         super().__init__(device, temperature=None, top_k=None, top_p=None)

         try:
             import fairseq
+            from fairseq.models.transformer import TransformerModel
         except ModuleNotFoundError:
             raise ModuleNotFoundError('Missing the fairseq library. Install fairseq from https://github.com/pytorch/fairseq')

         self.from_model_name = from_model_name
         self.from_model_checkpt = from_model_checkpt
         self.to_model_name = to_model_name
         self.to_model_checkpt = to_model_checkpt
+        self.is_load_from_github = is_load_from_github
         self.tokenzier_name = tokenzier_name
         self.bpe_name = bpe_name

-        # TODO: enhance to support custom model. https://github.com/pytorch/fairseq/tree/master/examples/translation
-        self.from_model = torch.hub.load(
-            github='pytorch/fairseq', model=from_model_name,
-            checkpoint_file=from_model_checkpt,
-            tokenizer=tokenzier_name, bpe=bpe_name)
-        self.to_model = torch.hub.load(
-            github='pytorch/fairseq', model=to_model_name,
-            checkpoint_file=to_model_checkpt,
-            tokenizer=tokenzier_name, bpe=bpe_name)
+        if is_load_from_github:
+            self.from_model = torch.hub.load(
+                github='pytorch/fairseq', model=from_model_name,
+                checkpoint_file=from_model_checkpt,
+                tokenizer=tokenzier_name, bpe=bpe_name)
+            self.to_model = torch.hub.load(
+                github='pytorch/fairseq', model=to_model_name,
+                checkpoint_file=to_model_checkpt,
+                tokenizer=tokenzier_name, bpe=bpe_name)
+        else:
+            try:
+                self.from_model = TransformerModel.from_pretrained(
+                    model_name_or_path=os.path.join(from_model_name, ''),
+                    checkpoint_file=from_model_checkpt,
+                    tokenizer=tokenzier_name, bpe=bpe_name)
+            except TypeError:
+                err_msg = 'Cannot load model from local path. Please check that the following parameters are correct.'
+                err_msg += ' Model Directory: ' + from_model_name
+                err_msg += ', Checkpoint File Name: ' + from_model_checkpt
+                err_msg += ', Tokenizer Name: ' + tokenzier_name
+                err_msg += ', BPE Name: ' + bpe_name
+                raise ValueError(err_msg)
+
+            try:
+                self.to_model = TransformerModel.from_pretrained(
+                    model_name_or_path=os.path.join(to_model_name, ''),
+                    checkpoint_file=to_model_checkpt,
+                    tokenizer=tokenzier_name, bpe=bpe_name)
+            except TypeError:
+                err_msg = 'Cannot load model from local path. Please check that the following parameters are correct.'
+                err_msg += ' Model Directory: ' + to_model_name
+                err_msg += ', Checkpoint File Name: ' + to_model_checkpt
+                err_msg += ', Tokenizer Name: ' + tokenzier_name
+                err_msg += ', BPE Name: ' + bpe_name
+                raise ValueError(err_msg)

         self.from_model.eval()
         self.to_model.eval()
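For context on what the two models are for: back translation round-trips text through a forward and a reverse translation model to produce a paraphrase. A hedged sketch of that idea using fairseq's torch.hub interface directly (the `translate` method is the documented entry point for fairseq's WMT19 hub models; this illustrates the concept, not this class's internal `predict`, which the diff does not show):

import torch

# Load the forward (en->de) and reverse (de->en) WMT19 models from fairseq's hub.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en',
                       checkpoint_file='model1.pt', tokenizer='moses', bpe='fastbpe')

german = en2de.translate('Machine learning is great!')
round_trip = de2en.translate(german)  # an English paraphrase of the input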
