diff --git a/README.md b/README.md
index c10b3cad98..ec6cb67fc9 100644
--- a/README.md
+++ b/README.md
@@ -70,7 +70,7 @@ We also provide [pre-trained models for translation and language modeling](#pre-
 with a convenient `torch.hub` interface:
 ```python
 en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')
-en2de.translate('Hello world', beam=5)
+en2de.translate(['Hello world'], beam=5)
 # 'Hallo Welt'
 ```
 See the PyTorch Hub tutorials for [translation](https://pytorch.org/hub/pytorch_fairseq_translation/)
diff --git a/examples/backtranslation/README.md b/examples/backtranslation/README.md
index bc32675de7..79b05b3d3e 100644
--- a/examples/backtranslation/README.md
+++ b/examples/backtranslation/README.md
@@ -33,7 +33,7 @@ len(en2de_ensemble.models)
 # 5
 
 # Translate
-en2de_ensemble.translate('Hello world!')
+en2de_ensemble.translate(['Hello world!'])
 # 'Hallo Welt!'
 ```
 
diff --git a/examples/language_model/README.md b/examples/language_model/README.md
index 3992e2ca1b..bfa69b2996 100644
--- a/examples/language_model/README.md
+++ b/examples/language_model/README.md
@@ -28,18 +28,18 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]
 en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
 
 # Sample from the language model
-en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
-# "Barack Obama is coming to Sydney and New Zealand (...)"
+en_lm.sample(['Barack Obama'], beam=1, sampling=True, sampling_topk=10, temperature=0.8)
+# ["Barack Obama is coming to Sydney and New Zealand (...)"]
 
 # Compute perplexity for a sequence
-en_lm.score('Barack Obama is coming to Sydney and New Zealand')['positional_scores'].mean().neg().exp()
+en_lm.score(['Barack Obama is coming to Sydney and New Zealand'])[0]['positional_scores'].mean().neg().exp()
 # tensor(15.1474)
 
 # The same interface can be used with custom models as well
 from fairseq.models.transformer_lm import TransformerLanguageModel
 custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
-custom_lm.sample('Barack Obama', beam=5)
-# "Barack Obama (...)"
+custom_lm.sample(['Barack Obama'], beam=5)
+# ["Barack Obama (...)"]
 ```
 
 ## Training a transformer language model with the CLI tools
diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md
index 3fb93b23d1..21d5dfb7bb 100644
--- a/examples/pay_less_attention_paper/README.md
+++ b/examples/pay_less_attention_paper/README.md
@@ -70,7 +70,7 @@ zh2en = torch.hub.load('pytorch/fairseq', 'lightconv.glu.wmt17.zh-en', tokenizer
 assert isinstance(zh2en.models[0], fairseq.models.lightconv.LightConvModel)
 
 # Translate a sentence
-zh2en.translate('你好 世界')
+zh2en.translate(['你好 世界'])
 # 'Hello World'
 ```
 
@@ -84,7 +84,7 @@ en2fr = LightConvModel.from_pretrained(
     bpe='subword_nmt',
     bpe_codes='data-bin/wmt14_en_fr/en.code'
 )
-en2fr.translate('Hello world!')
+en2fr.translate(['Hello world!'])
 # 'Bonjour le monde'
 ```
 
diff --git a/examples/translation/README.md b/examples/translation/README.md
index db1844df55..ef0283cb58 100644
--- a/examples/translation/README.md
+++ b/examples/translation/README.md
@@ -39,7 +39,7 @@ en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='
 assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
 
 # Translate a sentence
-en2de.translate('Hello world!')
+en2de.translate(['Hello world!'])
 # 'Hallo Welt!'
 ```
 
diff --git a/examples/wmt19/README.md b/examples/wmt19/README.md
index 3c59851264..de8505d3c7 100644
--- a/examples/wmt19/README.md
+++ b/examples/wmt19/README.md
@@ -31,22 +31,22 @@ import torch
 # English to German translation
 en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
   tokenizer='moses', bpe='fastbpe')
-en2de.translate("Machine learning is great!") # 'Maschinelles Lernen ist großartig!'
+en2de.translate(["Machine learning is great!"]) # ['Maschinelles Lernen ist großartig!']
 
 # German to English translation
 de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
   tokenizer='moses', bpe='fastbpe')
-de2en.translate("Maschinelles Lernen ist großartig!") # 'Machine learning is great!'
+de2en.translate(["Maschinelles Lernen ist großartig!"]) # ['Machine learning is great!']
 
 # English to Russian translation
 en2ru = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-ru', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
   tokenizer='moses', bpe='fastbpe')
-en2ru.translate("Machine learning is great!") # 'Машинное обучение - это здорово!'
+en2ru.translate(["Machine learning is great!"]) # ['Машинное обучение - это здорово!']
 
 # Russian to English translation
 ru2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.ru-en', checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt',
   tokenizer='moses', bpe='fastbpe')
-ru2en.translate("Машинное обучение - это здорово!") # 'Machine learning is great!'
+ru2en.translate(["Машинное обучение - это здорово!"]) # ['Machine learning is great!']
 ```
 
 #### Language Modeling
@@ -54,15 +54,15 @@ ru2en.translate("Машинное обучение - это здорово!") #
 ```python
 # Sample from the English LM
 en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
-en_lm.sample("Machine learning is") # 'Machine learning is the future of computing, says Microsoft boss Satya Nadella ...'
+en_lm.sample(["Machine learning is"]) # ['Machine learning is the future of computing, says Microsoft boss Satya Nadella ...']
 
 # Sample from the German LM
 de_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.de', tokenizer='moses', bpe='fastbpe')
-de_lm.sample("Maschinelles lernen ist") # 'Maschinelles lernen ist das A und O (neues-deutschland.de) Die Arbeitsbedingungen für Lehrerinnen und Lehrer sind seit Jahren verbesserungswürdig ...'
+de_lm.sample(["Maschinelles lernen ist"]) # ['Maschinelles lernen ist das A und O (neues-deutschland.de) Die Arbeitsbedingungen für Lehrerinnen und Lehrer sind seit Jahren verbesserungswürdig ...']
 
 # Sample from the Russian LM
 ru_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.ru', tokenizer='moses', bpe='fastbpe')
-ru_lm.sample("машинное обучение это") # 'машинное обучение это то, что мы называем "искусственным интеллектом".'
+ru_lm.sample(["машинное обучение это"]) # ['машинное обучение это то, что мы называем "искусственным интеллектом".']
 ```
 
 ## Citation
diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py
index 7fa5a4fffb..59c5e4ff9f 100644
--- a/fairseq/hub_utils.py
+++ b/fairseq/hub_utils.py
@@ -7,6 +7,7 @@
 import argparse
 import copy
 import os
+from typing import List, Dict, Iterator, Tuple, Any
 
 import torch
 from torch import nn
@@ -106,6 +107,10 @@ def __init__(self, args, task, models):
         self.tokenizer = encoders.build_tokenizer(args)
         self.bpe = encoders.build_bpe(args)
 
+        self.max_positions = utils.resolve_max_positions(
+            self.task.max_positions(), *[model.max_positions() for model in models]
+        )
+
         # this is useful for determining the device
         self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
 
@@ -113,22 +118,20 @@ def __init__(self, args, task, models):
     @property
     def device(self):
         return self._float_tensor.device
 
-    def translate(self, sentence: str, beam: int = 5, verbose: bool = False, **kwargs) -> str:
-        return self.sample(sentence, beam, verbose, **kwargs)
+    def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+        return self.sample(sentences, beam, verbose, **kwargs)
 
-    def sample(self, sentence: str, beam: int = 1, verbose: bool = False, **kwargs) -> str:
-        input = self.encode(sentence)
-        hypo = self.generate(input, beam, verbose, **kwargs)[0]['tokens']
-        return self.decode(hypo)
+    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
+        return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]
 
-    def score(self, sentence: str, **kwargs):
+    def score(self, sentences: List[str], **kwargs):
         # NOTE: this doesn't support translation tasks currently
-        input = self.encode(sentence)
-        return self.generate(input, score_reference=True, **kwargs)[0]
-
-    def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]
+    def generate(self, tokenized_sentences: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> List[List[Dict[str, torch.Tensor]]]:
         # build generator using current args as well as any kwargs
         gen_args = copy.copy(self.args)
         gen_args.beam = beam
@@ -136,30 +139,36 @@ def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = Fals
             setattr(gen_args, k, v)
         generator = self.task.build_generator(gen_args)
 
-        translations = self.task.inference_step(generator, self.models, sample)
-
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
-
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
+        results = []
+        for batch in self._build_batches(tokenized_sentences):
+            for k, input_tensor in batch["net_input"].items():
+                batch["net_input"][k] = input_tensor.to(self.device)
+            translations = self.task.inference_step(
+                generator, self.models, batch
+            )
+            for (iden, hypos) in zip(batch["id"].tolist(), translations):
+                results.append((iden, hypos))
+
+        # sort output to match input order
+        outputs = [hypos for (_, hypos) in sorted(results, key=lambda x: x[0])]
 
-        # Process top predictions
-        hypos = translations[0]
         if verbose:
-            for hypo in hypos:
-                hypo_str = self.decode(hypo['tokens'])
-                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-                print('P\t{}'.format(
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
-                ))
-                if hypo['alignment'] is not None and getarg('print_alignment', False):
-                    print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+            def getarg(name, default):
+                return getattr(gen_args, name, getattr(self.args, name, default))
+            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
+                src_str_with_unk = self.string(source_tokens)
+                print('S\t{}'.format(src_str_with_unk))
+                for hypo in target_hypotheses:
+                    hypo_str = self.decode(hypo['tokens'])
+                    print('H\t{}\t{}'.format(hypo['score'], hypo_str))
+                    print('P\t{}'.format(
+                        ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                     ))
-
-        return hypos
+                    if hypo['alignment'] is not None and getarg('print_alignment', False):
+                        print('A\t{}'.format(
+                            ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+                        ))
+        return outputs
 
     def encode(self, sentence: str) -> torch.LongTensor:
         sentence = self.tokenize(sentence)
@@ -196,16 +205,16 @@ def binarize(self, sentence: str) -> torch.LongTensor:
 
     def string(self, tokens: torch.LongTensor) -> str:
         return self.tgt_dict.string(tokens)
-
-    def _build_sample(self, src_tokens: torch.LongTensor):
-        assert torch.is_tensor(src_tokens)
-        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        sample = utils.apply_to_sample(
-            lambda tensor: tensor.to(self.device),
-            sample
-        )
-        return sample
+
+    def _build_batches(self, tokens: List[List[int]]) -> Iterator[Dict[str, Any]]:
+        lengths = torch.LongTensor([t.numel() for t in tokens])
+        batch_iterator = self.task.get_batch_iterator(
+            dataset=self.task.build_dataset_for_inference(tokens, lengths),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=self.max_positions,
+        ).next_epoch_itr(shuffle=False)
+        return batch_iterator
 
 
 class BPEHubInterface(object):
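
For reference, a minimal usage sketch of the batched hub interface above. The checkpoint names mirror the README snippets in this patch, and the commented outputs are indicative rather than guaranteed:

```python
import torch

# Checkpoint names below come from the README changes in this patch.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model')

# translate()/sample() now take a list of sentences and return one string per input;
# the sentences are batched internally (respecting max_tokens/max_sentences) and moved
# to the model's device before inference.
en2de.translate(['Hello world!', 'Machine learning is great!'], beam=5)
# e.g. ['Hallo Welt!', 'Maschinelles Lernen ist großartig!']

# score() likewise returns one top-hypothesis dict per input sentence; its positional
# scores yield a perplexity, exactly as in the language model README above.
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
en_lm.score(['Barack Obama is coming to Sydney and New Zealand'])[0]['positional_scores'].mean().neg().exp()
# tensor(15.1474)
```

Outputs come back in input order because `generate()` sorts the finished hypotheses by the original sample ids before returning them.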