From 27568a7ebed1a35f08ac0390f35b3de9b8dad0dd Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Wed, 13 Nov 2019 09:10:52 -0800 Subject: [PATCH] Merge TracingCompliantTransformer and regular Transformer, fix NAT tests Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/899 Differential Revision: D18373060 Pulled By: myleott fbshipit-source-id: bb5510ec15799a0a10a7c0669e76d8200e1ba479 --- fairseq/criterions/nat_loss.py | 2 +- fairseq/iterative_refinement_generator.py | 56 +- fairseq/models/cmlm_transformer.py | 16 +- fairseq/models/insertion_transformer.py | 19 +- ...iterative_nonautoregressive_transformer.py | 1 + fairseq/models/levenshtein_transformer.py | 697 +++++++++--------- fairseq/models/model_utils.py | 131 +--- fairseq/models/nonautoregressive_ensembles.py | 13 +- .../models/nonautoregressive_transformer.py | 50 +- fairseq/models/roberta/model.py | 2 +- .../models/tracing_compliant_transformer.py | 625 ---------------- fairseq/models/transformer.py | 64 +- fairseq/modules/mean_pool_gating_network.py | 11 +- tests/test_binaries.py | 55 +- 14 files changed, 551 insertions(+), 1191 deletions(-) delete mode 100644 fairseq/models/tracing_compliant_transformer.py diff --git a/fairseq/criterions/nat_loss.py b/fairseq/criterions/nat_loss.py index 174b1203cc..7e7dd07351 100644 --- a/fairseq/criterions/nat_loss.py +++ b/fairseq/criterions/nat_loss.py @@ -48,7 +48,7 @@ def mean_ds(x: Tensor, dim=None) -> Tensor: if masks is not None: outputs, targets = outputs[masks], targets[masks] - if not masks.any(): + if masks is not None and not masks.any(): nll_loss = torch.tensor(0) loss = nll_loss else: diff --git a/fairseq/iterative_refinement_generator.py b/fairseq/iterative_refinement_generator.py index 885e7c81b4..3af79fd2be 100644 --- a/fairseq/iterative_refinement_generator.py +++ b/fairseq/iterative_refinement_generator.py @@ -3,11 +3,20 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import namedtuple + import torch + from fairseq import utils -from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel -from fairseq.models.model_utils import script_skip_tensor_list, skip_tensors as _skip -from fairseq.models.nonautoregressive_ensembles import EnsembleLevT + + +DecoderOut = namedtuple('IterativeRefinementDecoderOut', [ + 'output_tokens', + 'output_scores', + 'attn', + 'step', + 'max_step', +]) class IterativeRefinementGenerator(object): @@ -88,6 +97,8 @@ def generate_batched_itr( @torch.no_grad() def generate(self, models, sample, prefix_tokens=None): + from fairseq.models.levenshtein_transformer import LevenshteinTransformerModel + from fairseq.models.nonautoregressive_ensembles import EnsembleLevT if len(models) == 1: # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this. 
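The state threading in this generator is the heart of the refactor: positional indexing (`prev_decoder_out[0]`, `prev_decoder_out[3] = step`) gives way to named fields on the `DecoderOut` tuple defined above. A minimal sketch of the pattern with toy tensors (not the fairseq classes), showing how `_replace` advances loop state without in-place mutation:

```python
from collections import namedtuple

import torch

DecoderOut = namedtuple('IterativeRefinementDecoderOut', [
    'output_tokens', 'output_scores', 'attn', 'step', 'max_step',
])

# roughly what initialize_output_tokens builds: a <bos> <eos> canvas
state = DecoderOut(
    output_tokens=torch.tensor([[0, 2]]),
    output_scores=torch.zeros(1, 2),
    attn=None,
    step=0,
    max_step=0,
)

# _replace returns a fresh tuple; fields not named carry over untouched,
# so the refinement loop can bump the counters without cloning tensors
state = state._replace(step=state.step + 1, max_step=10)
assert state.step == 1 and state.attn is None
```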
@@ -110,7 +121,7 @@ def generate(self, models, sample, prefix_tokens=None): # initialize buffers (very model specific, with length prediction or not) prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens) - prev_output_tokens = prev_decoder_out[0].clone() + prev_output_tokens = prev_decoder_out.output_tokens.clone() finalized = [[] for _ in range(bsz)] @@ -150,8 +161,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): "max_ratio": self.max_ratio, "decoding_format": self.decoding_format, } - prev_decoder_out[3] = step - prev_decoder_out[4] = self.max_iter + 1 + prev_decoder_out = prev_decoder_out._replace( + step=step, + max_step=self.max_iter + 1, + ) decoder_out = model.forward_decoder( prev_decoder_out, encoder_out, **decoder_options @@ -160,24 +173,26 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): if self.adaptive: # terminate if there is a loop terminated, out_tokens, out_scores, out_attn = is_a_loop( - prev_output_tokens, decoder_out[0], decoder_out[1], decoder_out[2] + prev_output_tokens, decoder_out.output_tokens, decoder_out.output_scores, decoder_out.attn + ) + decoder_out = decoder_out._replace( + output_tokens=out_tokens, + output_scores=out_scores, + attn=out_attn, ) - decoder_out[0] = out_tokens - decoder_out[1] = out_scores - decoder_out[2] = out_attn else: - terminated = decoder_out[0].new_zeros(decoder_out[0].size(0)).bool() + terminated = decoder_out.output_tokens.new_zeros(decoder_out.output_tokens.size(0)).bool() if step == self.max_iter: # reach last iteration, terminate terminated.fill_(1) # collect finalized sentences finalized_idxs = sent_idxs[terminated] - finalized_tokens = decoder_out[0][terminated] - finalized_scores = decoder_out[1][terminated] + finalized_tokens = decoder_out.output_tokens[terminated] + finalized_scores = decoder_out.output_scores[terminated] finalized_attn = ( - None if decoder_out[2] is None else decoder_out[2][terminated] + None if decoder_out.attn is None else decoder_out.attn[terminated] ) for i in range(finalized_idxs.size(0)): @@ -194,10 +209,15 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn): break # for next step - prev_decoder_out = _skip(decoder_out, ~terminated) - encoder_out = script_skip_tensor_list(encoder_out, ~terminated) - sent_idxs = _skip(sent_idxs, ~terminated) + not_terminated = ~terminated + prev_decoder_out = decoder_out._replace( + output_tokens=decoder_out.output_tokens[not_terminated], + output_scores=decoder_out.output_scores[not_terminated], + attn=decoder_out.attn[not_terminated] if decoder_out.attn is not None else None, + ) + encoder_out = model.encoder.reorder_encoder_out(encoder_out, not_terminated.nonzero().squeeze()) + sent_idxs = sent_idxs[not_terminated] - prev_output_tokens = prev_decoder_out[0].clone() + prev_output_tokens = prev_decoder_out.output_tokens.clone() return finalized diff --git a/fairseq/models/cmlm_transformer.py b/fairseq/models/cmlm_transformer.py index 91a5a48a66..f3ba15dd98 100644 --- a/fairseq/models/cmlm_transformer.py +++ b/fairseq/models/cmlm_transformer.py @@ -10,9 +10,9 @@ arXiv preprint arXiv:1904.09324 (2019). 
""" -from fairseq.utils import new_arange from fairseq.models import register_model, register_model_architecture from fairseq.models.nonautoregressive_transformer import NATransformerModel +from fairseq.utils import new_arange def _skeptical_unmasking(output_scores, output_masks, p): @@ -55,11 +55,11 @@ def forward( def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): - step = decoder_out["step"] - max_step = decoder_out["max_step"] + step = decoder_out.step + max_step = decoder_out.max_step - output_tokens = decoder_out["output_tokens"] - output_scores = decoder_out["output_scores"] + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores # execute the decoder output_masks = output_tokens.eq(self.unk) @@ -78,7 +78,11 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar output_tokens.masked_fill_(skeptical_mask, self.unk) output_scores.masked_fill_(skeptical_mask, 0.0) - return {"output_tokens": output_tokens, "output_scores": output_scores} + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + ) @register_model_architecture("cmlm_transformer", "cmlm_transformer") diff --git a/fairseq/models/insertion_transformer.py b/fairseq/models/insertion_transformer.py index 1657bd0b1b..6a1e077fa8 100644 --- a/fairseq/models/insertion_transformer.py +++ b/fairseq/models/insertion_transformer.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn.functional as F -from fairseq.utils import new_arange + from fairseq.models import register_model, register_model_architecture from fairseq.models.levenshtein_transformer import ( LevenshteinTransformerDecoder, @@ -14,6 +14,7 @@ ) from fairseq.models.transformer import Linear, TransformerModel from fairseq.modules.transformer_sentence_encoder import init_bert_params +from fairseq.utils import new_arange class NegativeDistanceScore(object): @@ -116,8 +117,8 @@ def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, paddi @register_model("insertion_transformer") class InsertionTransformerModel(LevenshteinTransformerModel): - def __init__(self, encoder, decoder): - super().__init__(encoder, decoder) + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) @staticmethod def add_args(parser): @@ -169,8 +170,8 @@ def forward_decoder( self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs ): - output_tokens = decoder_out["output_tokens"] - output_scores = decoder_out["output_scores"] + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores # TODO: decoding for InsertionTransformer word_ins_out = self.decoder.forward_word_ins( output_tokens, encoder_out=encoder_out @@ -187,7 +188,11 @@ def forward_decoder( cut_off = output_tokens.ne(self.pad).sum(1).max() output_tokens = output_tokens[:, :cut_off] output_scores = output_scores[:, :cut_off] - return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None} + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + ) class InsertionTransformerDecoder(LevenshteinTransformerDecoder): @@ -206,7 +211,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.label_tau = getattr(args, "label_tau", None) def forward_word_ins(self, prev_output_tokens, encoder_out=None): - features, _ = self.extract_features(prev_output_tokens, encoder_out=encoder_out) + features = 
self.extract_features(prev_output_tokens, encoder_out=encoder_out)[0] features = self.pool_out( torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) ) diff --git a/fairseq/models/iterative_nonautoregressive_transformer.py b/fairseq/models/iterative_nonautoregressive_transformer.py index 73585db354..b104a7f8b4 100644 --- a/fairseq/models/iterative_nonautoregressive_transformer.py +++ b/fairseq/models/iterative_nonautoregressive_transformer.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import torch + from fairseq.models import register_model, register_model_architecture from fairseq.models.nonautoregressive_transformer import NATransformerModel diff --git a/fairseq/models/levenshtein_transformer.py b/fairseq/models/levenshtein_transformer.py index 4ed30ead14..42adc60a53 100644 --- a/fairseq/models/levenshtein_transformer.py +++ b/fairseq/models/levenshtein_transformer.py @@ -1,52 +1,109 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2017-present, Facebook, Inc. -# All rights reserved. +# Copyright (c) Facebook, Inc. and its affiliates. # -# This source code is licensed under the license found in the LICENSE file in -# the root directory of this source tree. An additional grant of patent rights -# can be found in the PATENTS file in the same directory. - - -from __future__ import absolute_import, division, print_function, unicode_literals - -from typing import Optional +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. import torch import torch.nn as nn import torch.nn.functional as F + +from fairseq.iterative_refinement_generator import DecoderOut from fairseq.models import register_model, register_model_architecture -from fairseq.models.tracing_compliant_transformer import ( - TracingTransformerDecoder, - TracingTransformerEncoder, - TracingTransformerModel, - TransformerDecoderLayer, -) -from fairseq.models.model_utils import ( - fill_tensors as _fill, - script_skip_tensor, - script_skip_tensor_list, +from fairseq.models.transformer import ( + Embedding, + TransformerDecoder, + TransformerEncoder, + TransformerModel, + TransformerDecoderLayer ) -from fairseq.models.transformer import Embedding from fairseq.modules.transformer_sentence_encoder import init_bert_params -from torch import Tensor +from fairseq.utils import new_arange + +# -------------- Helper Functions --------------------------------------------------- # +def _skip(x, mask): + """ + Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors. + """ + if isinstance(x, int): + return x + + if x is None: + return None + + if isinstance(x, torch.Tensor): + if x.size(0) == mask.size(0): + return x[mask] + elif x.size(1) == mask.size(0): + return x[:, mask] + + if isinstance(x, list): + return [_skip(x_i, mask) for x_i in x] + + if isinstance(x, dict): + return {k: _skip(v, mask) for k, v in x.items()} + + raise NotImplementedError + + +def _skip_encoder_out(encoder, encoder_out, mask): + if not mask.any(): + return encoder_out + else: + return encoder.reorder_encoder_out(encoder_out, mask.nonzero().squeeze()) + + +def _fill(x, mask, y, padding_idx): + """ + Filling tensor x with y at masked positions (dim=0). 
+ """ + if x is None: + return y + assert x.dim() == y.dim() and mask.size(0) == x.size(0) + assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) + n_selected = mask.sum() + assert n_selected == y.size(0) + + if n_selected == x.size(0): + return y + + if x.size(1) < y.size(1): + dims = [x.size(0), y.size(1) - x.size(1)] + if x.dim() == 3: + dims.append(x.size(2)) + x = torch.cat([x, x.new_zeros(*dims).fill_(padding_idx)], 1) + x[mask] = y + elif x.size(1) > y.size(1): + x[mask] = padding_idx + if x.dim() == 2: + x[mask, :y.size(1)] = y + else: + x[mask, :y.size(1), :] = y + else: + x[mask] = y + return x -def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): +def load_libnat(): try: from fairseq import libnat except ImportError as e: import sys - sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n") raise e + return libnat + + +def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): + libnat = load_libnat() + in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) in_tokens_list = [ [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) ] out_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) ] full_labels = libnat.suggested_ed2_path( @@ -71,28 +128,27 @@ def _get_ins_targets(in_tokens, out_tokens, padding_idx, unk_idx): ] # transform to tensor - masked_tgt_masks = torch.tensor(masked_tgt_masks, device=out_tokens.device).bool() + masked_tgt_masks = torch.tensor( + masked_tgt_masks, device=out_tokens.device + ).bool() mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) masked_tgt_tokens = out_tokens.masked_fill(masked_tgt_masks, unk_idx) return masked_tgt_masks, masked_tgt_tokens, mask_ins_targets def _get_del_targets(in_tokens, out_tokens, padding_idx): - try: - from fairseq import libnat - except ImportError as e: - import sys + libnat = load_libnat() - sys.stderr.write("ERROR: missing libnat. run `pip install --editable .`\n") - raise e out_seq_len = out_tokens.size(1) - in_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) - ] - out_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) - ] + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] full_labels = libnat.suggested_ed2_path( in_tokens_list, out_tokens_list, padding_idx @@ -104,26 +160,23 @@ def _get_del_targets(in_tokens, out_tokens, padding_idx): ] # transform to tensor - word_del_targets = torch.tensor(word_del_targets) + word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device) return word_del_targets def _get_del_ins_targets(in_tokens, out_tokens, padding_idx): - try: - from fairseq import libnat - except ImportError as e: - import sys + libnat = load_libnat() - sys.stderr.write("ERROR: missing libnat. 
run `pip install --editable .`\n") - raise e in_seq_len, out_seq_len = in_tokens.size(1), out_tokens.size(1) - in_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) - ] - out_tokens_list = [ - [t for t in s if t != padding_idx] for i, s in enumerate(out_tokens.tolist()) - ] + with torch.cuda.device_of(in_tokens): + in_tokens_list = [ + [t for t in s if t != padding_idx] for i, s in enumerate(in_tokens.tolist()) + ] + out_tokens_list = [ + [t for t in s if t != padding_idx] + for i, s in enumerate(out_tokens.tolist()) + ] full_labels = libnat.suggested_ed2_path( in_tokens_list, out_tokens_list, padding_idx @@ -144,15 +197,101 @@ def _get_del_ins_targets(in_tokens, out_tokens, padding_idx): ] # transform to tensor - mask_ins_targets = torch.tensor(mask_ins_targets) - word_del_targets = torch.tensor(word_del_targets) + mask_ins_targets = torch.tensor(mask_ins_targets, device=in_tokens.device) + word_del_targets = torch.tensor(word_del_targets, device=out_tokens.device) return word_del_targets, mask_ins_targets +def _apply_ins_masks( + in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx +): + + in_masks = in_tokens.ne(padding_idx) + in_lengths = in_masks.sum(1) + + # HACK: hacky way to shift all the paddings to eos first. + in_tokens.masked_fill_(~in_masks, eos_idx) + mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0) + + out_lengths = in_lengths + mask_ins_pred.sum(1) + out_max_len = out_lengths.max() + out_masks = ( + new_arange(out_lengths, out_max_len)[None, :] + < out_lengths[:, None] + ) + + reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1) + out_tokens = ( + in_tokens.new_zeros(in_tokens.size(0), out_max_len) + .fill_(padding_idx) + .masked_fill_(out_masks, unk_idx) + ) + out_tokens[:, 0] = in_tokens[:, 0] + out_tokens.scatter_(1, reordering, in_tokens[:, 1:]) + + out_scores = None + if in_scores is not None: + in_scores.masked_fill_(~in_masks, 0) + out_scores = in_scores.new_zeros(*out_tokens.size()) + out_scores[:, 0] = in_scores[:, 0] + out_scores.scatter_(1, reordering, in_scores[:, 1:]) + + return out_tokens, out_scores + + +def _apply_ins_words( + in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx +): + word_ins_masks = in_tokens.eq(unk_idx) + out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks]) + + if in_scores is not None: + out_scores = in_scores.masked_scatter( + word_ins_masks, word_ins_scores[word_ins_masks] + ) + else: + out_scores = None + + return out_tokens, out_scores + + +def _apply_del_words( + in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx +): + # apply deletion to a tensor + in_masks = in_tokens.ne(padding_idx) + bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx) + + max_len = in_tokens.size(1) + word_del_pred.masked_fill_(~in_masks, 1) + word_del_pred.masked_fill_(bos_eos_masks, 0) + + reordering = ( + new_arange(in_tokens) + .masked_fill_(word_del_pred, max_len) + .sort(1)[1] + ) + + out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering) + + out_scores = None + if in_scores is not None: + out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering) + + out_attn = None + if in_attn is not None: + _mask = word_del_pred[:, :, None].expand_as(in_attn) + _reordering = reordering[:, :, None].expand_as(in_attn) + out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering) + + return out_tokens, out_scores, out_attn + +# 
------------------------------------------------------------------------------------- # + @register_model("levenshtein_transformer") -class LevenshteinTransformerModel(TracingTransformerModel): - def __init__(self, encoder, decoder): - super().__init__(encoder, decoder) +class LevenshteinTransformerModel(TransformerModel): + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) self.tgt_dict = decoder.dictionary self.bos = decoder.dictionary.bos() self.eos = decoder.dictionary.eos() @@ -161,7 +300,7 @@ def __init__(self, encoder, decoder): @staticmethod def add_args(parser): - TracingTransformerModel.add_args(parser) + TransformerModel.add_args(parser) parser.add_argument( "--apply-bert-init", action="store_true", @@ -171,31 +310,27 @@ def add_args(parser): "--early-exit", default="6,6,6", type=str, - help="number of decoder layers for del_word, ins_mask, ins_word", + help="number of decoder layers before word_del, mask_ins, word_ins", ) parser.add_argument( "--no-share-discriminator", action="store_true", - help="addtional decoder-layers to learn deletion", + help="separate parameters for discriminator", ) parser.add_argument( "--no-share-maskpredictor", action="store_true", - help="addtional decoder-layers to learn predicting masks", + help="separate parameters for mask-predictor", ) parser.add_argument( - "--sampling-for-deletion", + "--share-discriminator-maskpredictor", action="store_true", - help="instead of argmax, use sampling to predict the tokens", + help="share the parameters for both mask-predictor and discriminator", ) - # Added for compatibility parser.add_argument( - "--decoder-out-embed-dim", - default=None, - type=int, - metavar="N", - help="decoder output embedding dimension (bottleneck layer before" - "output layer if specified.)", + "--sampling-for-deletion", + action='store_true', + help='instead of argmax, use sampling to predict the tokens' ) @classmethod @@ -207,7 +342,7 @@ def build_decoder(cls, args, tgt_dict, embed_tokens): @classmethod def build_encoder(cls, args, src_dict, embed_tokens): - encoder = TracingTransformerEncoder(args, src_dict, embed_tokens) + encoder = TransformerEncoder(args, src_dict, embed_tokens) if getattr(args, "apply_bert_init", False): encoder.apply(init_bert_params) return encoder @@ -238,8 +373,8 @@ def forward( # make online prediction if self.decoder.sampling_for_deletion: word_predictions = torch.multinomial( - F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1 - ).view(word_ins_out.size(0), -1) + F.softmax(word_ins_out, -1).view(-1, word_ins_out.size(-1)), 1).view( + word_ins_out.size(0), -1) else: word_predictions = F.log_softmax(word_ins_out, dim=-1).max(2)[1] @@ -249,7 +384,10 @@ def forward( # generate training labels for deletion word_del_targets = _get_del_targets(word_predictions, tgt_tokens, self.pad) - word_del_out, _ = self.decoder.forward_word_del(word_predictions, encoder_out) + word_del_out, _ = self.decoder.forward_word_del( + word_predictions, encoder_out) + word_del_masks = word_predictions.ne(self.pad) + return { "mask_ins_out": mask_ins_out, "mask_ins_tgt": mask_ins_targets, @@ -259,7 +397,7 @@ def forward( "word_ins_mask": masked_tgt_masks, "word_del_out": word_del_out, "word_del_tgt": word_del_targets, - "word_del_mask": word_predictions.ne(self.pad), + "word_del_mask": word_del_masks, } def forward_encoder(self, encoder_inputs): @@ -269,248 +407,123 @@ def forward_decoder( self, decoder_out, encoder_out, eos_penalty=0.0, max_ratio=None, **kwargs ): - output_tokens = 
decoder_out[0]
-        output_scores = decoder_out[1]
-        attn = decoder_out[2]
-
-        if max_ratio is not None and encoder_out[1] is not None:
-            max_lengths = ((~encoder_out[1]).sum(1) * max_ratio).clamp(min=10)
+        output_tokens = decoder_out.output_tokens
+        output_scores = decoder_out.output_scores
+        attn = decoder_out.attn
+        bsz = output_tokens.size(0)
+        if max_ratio is None:
+            max_lens = torch.zeros_like(output_tokens).fill_(255)
         else:
-            max_lengths = torch.zeros(output_tokens.size(0)).fill_(255)
-
-        @torch.jit.script
-        def del_word(
-            output_tokens,
-            output_scores,
-            attn: Tensor,
-            word_del_attn: Optional[Tensor],
-            word_del_out,
-            can_del_word,
-            pad_idx: int,
-            bos_idx: int,
-            eos_idx: int,
-        ):
-            # delete words
-            # do not delete tokens if it is <s> </s>
-            if can_del_word.sum() != 0:  # we cannot delete, skip
-                word_del_score = F.log_softmax(word_del_out, 2)
-                word_del_pred = torch.jit.Attribute(word_del_score.max(-1)[1], bool)
-                in_tokens = output_tokens[can_del_word]
-                in_scores = output_scores[can_del_word]
-                # apply deletion to a tensor
-                in_masks = in_tokens.ne(pad_idx)
-                bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx)
-
-                max_len = in_tokens.size(1)
-                word_del_pred.masked_fill_(~in_masks, 1)
-                word_del_pred.masked_fill_(bos_eos_masks, 0)
-
-                reordering = (
-                    torch.arange(max_len)[None, :]
-                    .expand_as(in_tokens)
-                    .contiguous()
-                    .masked_fill(word_del_pred, max_len)
-                    .sort(1)[1]
-                )
-
-                _tokens = in_tokens.masked_fill(word_del_pred, pad_idx).gather(
-                    1, reordering
-                )
-
-                _scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering)
-                if word_del_attn is not None:
-                    _mask = word_del_pred[:, :, None].expand_as(word_del_attn)
-                    _reordering = reordering[:, :, None].expand_as(word_del_attn)
-                    _attn = word_del_attn.masked_fill(_mask, 0.0).gather(1, _reordering)
-                    attn = _fill(attn, can_del_word, _attn, 0)
-
-                output_tokens = _fill(output_tokens, can_del_word, _tokens, pad_idx)
-                output_scores = _fill(output_scores, can_del_word, _scores, 0)
-            return output_tokens, output_scores, attn
-
-        @torch.jit.script
-        def ins_placeholders(
-            output_tokens,
-            output_scores,
-            mask_ins_out,
-            can_ins_mask,
-            pad_idx: int,
-            unk_idx: int,
-            eos_idx: int,
-            max_ratio: float,
-            max_lengths,
-        ):
-            # insert placeholders
-            if can_ins_mask.sum() != 0:
-                mask_ins_score = F.log_softmax(mask_ins_out, 2)
-                if eos_penalty > 0.0:
-                    mask_ins_score[:, :, 0] -= eos_penalty
-                mask_ins_pred = mask_ins_score.max(-1)[1]
-                if max_ratio is not None and encoder_out[1] is not None:
-                    mask_ins_pred = torch.min(
-                        mask_ins_pred, max_lengths[can_ins_mask, None].expand_as(mask_ins_pred)
-                    )
-                in_tokens = output_tokens[can_ins_mask]
-                in_scores = output_scores[can_ins_mask]
-                in_masks = in_tokens.ne(pad_idx)
-                in_lengths = in_masks.sum(1)
-
-                # HACK: hacky way to shift all the paddings to eos first. 
-                in_tokens.masked_fill_(~in_masks, eos_idx)
-                mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0)
-
-                out_lengths = in_lengths + mask_ins_pred.sum(1)
-                out_max_len = out_lengths.max()
-                out_masks = (
-                    torch.arange(out_max_len)[None, :].long() < out_lengths[:, None]
-                )
-
-                reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
-                out_tokens = (
-                    torch.zeros(in_tokens.size()[0], out_max_len)
-                    .fill_(pad_idx)
-                    .masked_fill_(out_masks, unk_idx)
-                )
-                out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
-                out_tokens.scatter_(1, reordering, in_tokens[:, 1:].float())
-
-                if in_scores is not None:
-                    in_scores.masked_fill_(~in_masks, 0)
-                    out_scores = torch.zeros_like(out_tokens).to(in_scores)
-                    out_tokens = torch.cat([in_tokens[:, :1], out_tokens[:, 1:]], 1)
-                    out_scores.scatter_(1, reordering, in_scores[:, 1:])
-                else:
-                    out_scores = None
-                output_tokens = _fill(output_tokens, can_ins_mask, out_tokens, pad_idx)
-                output_scores = _fill(output_scores, can_ins_mask, out_scores, 0)
-            return output_tokens, output_scores
-
-        @torch.jit.script
-        def ins_words(
-            output_tokens,
-            output_scores,
-            attn: Tensor,
-            word_ins_attn,
-            word_ins_out,
-            can_ins_word,
-            pad_idx: int,
-            unk_idx: int,
-        ):
-            # insert words
-            if can_ins_word.sum() != 0:
-                word_ins_scores = F.log_softmax(word_ins_out, 2)
-                word_ins_pred = word_ins_scores.max(-1)[1]
-                in_tokens = output_tokens[can_ins_word]
-                in_scores = output_scores[can_ins_word]
-                word_ins_masks = in_tokens.eq(unk_idx)
-                out_tokens = in_tokens.masked_scatter(
-                    word_ins_masks, word_ins_pred[word_ins_masks].float()
-                )
-
-                if in_scores is not None:
-                    out_scores = in_scores.masked_scatter(
-                        word_ins_masks, word_ins_scores[word_ins_masks]
-                    )
-                else:
-                    out_scores = None
-                output_tokens = _fill(output_tokens, can_ins_word, out_tokens, pad_idx)
-                output_scores = _fill(output_scores, can_ins_word, out_scores, 0)
-                attn = _fill(attn, can_ins_word, word_ins_attn, 0)
-            return output_tokens, output_scores, attn
-
+        if encoder_out.encoder_padding_mask is None:
+            max_src_len = encoder_out.encoder_out.size(0)
+            src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len)
+        else:
+            src_lens = (~encoder_out.encoder_padding_mask).sum(1)
+        max_lens = (src_lens * max_ratio).clamp(min=10).long()
+
+        # delete words
+        # do not delete tokens if it is <s> </s>
         can_del_word = output_tokens.ne(self.pad).sum(1) > 2
-        word_del_out, word_del_attn = self.decoder.forward_word_del(
-            script_skip_tensor(output_tokens, can_del_word),
-            script_skip_tensor_list(list(encoder_out), can_del_word),
-        )
-
-        output_tokens, output_scores, attn = del_word(
-            output_tokens,
-            output_scores,
-            attn,
-            word_del_attn,
-            word_del_out,
-            can_del_word,
-            self.pad,
-            self.bos,
-            self.eos,
-        )
+        if can_del_word.sum() != 0:  # we cannot delete, skip
+            word_del_out, word_del_attn = self.decoder.forward_word_del(
+                _skip(output_tokens, can_del_word),
+                _skip_encoder_out(self.encoder, encoder_out, can_del_word)
+            )
+            word_del_score = F.log_softmax(word_del_out, 2)
+            word_del_pred = word_del_score.max(-1)[1].bool()
+
+            _tokens, _scores, _attn = _apply_del_words(
+                output_tokens[can_del_word],
+                output_scores[can_del_word],
+                word_del_attn,
+                word_del_pred,
+                self.pad,
+                self.bos,
+                self.eos,
+            )
+            output_tokens = _fill(output_tokens, can_del_word, _tokens, self.pad)
+            output_scores = _fill(output_scores, can_del_word, _scores, 0)
+            attn = _fill(attn, can_del_word, _attn, 0.)
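The rewritten deletion branch above delegates the tensor surgery to `_apply_del_words`, whose core trick is loop-free compaction: every position marked for deletion is given the sort key `max_len`, so a stable sort pushes it past all kept tokens, and the gathered result is then padded. A self-contained toy run of that trick (hypothetical token ids; the bos/eos guard from the helper is omitted):

```python
import torch

pad = 1
in_tokens = torch.tensor([[0, 5, 6, 7, 2]])              # <bos> a b c <eos>
word_del_pred = torch.tensor([[0, 0, 1, 0, 0]]).bool()   # delete "b"

max_len = in_tokens.size(1)
# deleted slots get sort key max_len, so they sink to the right
reordering = (
    torch.arange(max_len)[None, :]
    .expand_as(in_tokens)
    .masked_fill(word_del_pred, max_len)
    .sort(1)[1]
)
out = in_tokens.masked_fill(word_del_pred, pad).gather(1, reordering)
# out == tensor([[0, 5, 7, 2, 1]]): kept tokens compacted, pad trails
```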
+
+        # insert placeholders
+        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lens
+        if can_ins_mask.sum() != 0:
+            mask_ins_out, _ = self.decoder.forward_mask_ins(
+                _skip(output_tokens, can_ins_mask),
+                _skip_encoder_out(self.encoder, encoder_out, can_ins_mask)
+            )
+            mask_ins_score = F.log_softmax(mask_ins_out, 2)
+            if eos_penalty > 0.0:
+                mask_ins_score[:, :, 0] = mask_ins_score[:, :, 0] - eos_penalty
+            mask_ins_pred = mask_ins_score.max(-1)[1]
+            mask_ins_pred = torch.min(
+                mask_ins_pred, max_lens[can_ins_mask, None].expand_as(mask_ins_pred)
+            )
 
-        can_ins_mask = output_tokens.ne(self.pad).sum(1) < max_lengths
-        mask_ins_out, _ = self.decoder.forward_mask_ins(
-            script_skip_tensor(output_tokens, can_ins_mask),
-            script_skip_tensor_list(encoder_out, can_ins_mask),
-        )
-        output_tokens, output_scores = ins_placeholders(
-            output_tokens,
-            output_scores,
-            mask_ins_out,
-            can_ins_mask,
-            self.pad,
-            self.unk,
-            self.eos,
-            max_ratio,
-            max_lengths,
-        )
+            _tokens, _scores = _apply_ins_masks(
+                output_tokens[can_ins_mask],
+                output_scores[can_ins_mask],
+                mask_ins_pred,
+                self.pad,
+                self.unk,
+                self.eos,
+            )
+            output_tokens = _fill(output_tokens, can_ins_mask, _tokens, self.pad)
+            output_scores = _fill(output_scores, can_ins_mask, _scores, 0)
 
+        # insert words
         can_ins_word = output_tokens.eq(self.unk).sum(1) > 0
-        word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
-            script_skip_tensor(output_tokens, can_ins_word),
-            script_skip_tensor_list(encoder_out, can_ins_word),
-        )
-
+        if can_ins_word.sum() != 0:
+            word_ins_out, word_ins_attn = self.decoder.forward_word_ins(
+                _skip(output_tokens, can_ins_word),
+                _skip_encoder_out(self.encoder, encoder_out, can_ins_word)
+            )
+            word_ins_score, word_ins_pred = F.log_softmax(word_ins_out, 2).max(-1)
+
+            _tokens, _scores = _apply_ins_words(
+                output_tokens[can_ins_word],
+                output_scores[can_ins_word],
+                word_ins_pred,
+                word_ins_score,
+                self.unk,
+            )
-            output_tokens, output_scores, attn = ins_words(
-                output_tokens,
-                output_scores,
-                attn,
-                word_ins_attn,
-                word_ins_out,
-                can_ins_word,
-                self.pad,
-                self.unk,
-            )
+            output_tokens = _fill(output_tokens, can_ins_word, _tokens, self.pad)
+            output_scores = _fill(output_scores, can_ins_word, _scores, 0)
+            attn = _fill(attn, can_ins_word, word_ins_attn, 0.)
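The placeholder branch leans on `_apply_ins_masks`, which never loops either: a cumulative sum over the predicted insertion counts yields each surviving token's destination column, and `scatter_` drops the old tokens into a fresh `<unk>`-filled canvas. A toy version of the same idea (single unpadded row, so the pad-to-eos shuffle and the score bookkeeping are left out):

```python
import torch

pad, unk = 1, 3
in_tokens = torch.tensor([[0, 5, 2]])    # <bos> w <eos>
mask_ins_pred = torch.tensor([[2, 0]])   # insert 2 slots after <bos>

in_masks = in_tokens.ne(pad)
out_len = int((in_masks.sum(1) + mask_ins_pred.sum(1)).max())

# destination of token i (for i >= 1) = real tokens + inserted slots before it
reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1)
out_tokens = in_tokens.new_full((in_tokens.size(0), out_len), unk)
out_tokens[:, 0] = in_tokens[:, 0]
out_tokens.scatter_(1, reordering, in_tokens[:, 1:])
# out_tokens == tensor([[0, 3, 3, 5, 2]]): two <unk> placeholders inserted
```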
        # delete some unnecessary paddings
         cut_off = output_tokens.ne(self.pad).sum(1).max()
-
-        @torch.jit.script
-        def slice_wrap(x, l):
-            return x[:, :l]
-
-        @torch.jit.script
-        def slice_wrap_attn(x, l):
-            return x if x.size()[0] == 0 else x[:, :l, :]
-
-        output_tokens = slice_wrap(output_tokens, cut_off)
-        output_scores = slice_wrap(output_scores, cut_off)
-        attn = slice_wrap(attn, cut_off)
-        return [output_tokens, output_scores, attn, 0, 0]
-
-    def initialize_output_tokens(self, encoder_out, src_tokens):
-        initial_output_tokens = torch.cat(
-            [
-                torch.zeros(src_tokens.size(0), 1).fill_(self.bos),
-                torch.zeros(src_tokens.size(0), 1).fill_(self.eos),
-            ],
-            1,
+        output_tokens = output_tokens[:, :cut_off]
+        output_scores = output_scores[:, :cut_off]
+        attn = None if attn is None else attn[:, :cut_off, :]
+
+        return decoder_out._replace(
+            output_tokens=output_tokens,
+            output_scores=output_scores,
+            attn=attn,
         )
-        initial_output_scores = torch.zeros_like(initial_output_tokens).to(
-            encoder_out[0]
+    def initialize_output_tokens(self, encoder_out, src_tokens):
+        initial_output_tokens = src_tokens.new_zeros(src_tokens.size(0), 2)
+        initial_output_tokens[:, 0] = self.bos
+        initial_output_tokens[:, 1] = self.eos
+
+        initial_output_scores = initial_output_tokens.new_zeros(
+            *initial_output_tokens.size()
+        ).type_as(encoder_out.encoder_out)
+        return DecoderOut(
+            output_tokens=initial_output_tokens,
+            output_scores=initial_output_scores,
+            attn=None,
+            step=0,
+            max_step=0,
         )
-        initial_attn = torch.empty([0])
-        if getattr(self.decoder.layers[-1], "need_attn", True):
-            initial_attn = torch.zeros([src_tokens.size(0), 2, src_tokens.size(1)]).to(
-                initial_output_tokens
-            )
-
-        return [initial_output_tokens, initial_output_scores, initial_attn, 0, 0]
-
-class LevenshteinTransformerDecoder(TracingTransformerDecoder):
+class LevenshteinTransformerDecoder(TransformerDecoder):
     def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         super().__init__(
             args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
         )
@@ -524,38 +537,32 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
         self.embed_word_del = Embedding(2, self.output_embed_dim, None)
 
         # del_word, ins_mask, ins_word
-        self.early_exit = [int(i) for i in args.early_exit.split(",")]
+        self.early_exit = [int(i) for i in args.early_exit.split(',')]
         assert len(self.early_exit) == 3
 
         # copy layers for mask-predict/deletion
         self.layers_msk = None
         if getattr(args, "no_share_maskpredictor", False):
-            self.layers_msk = nn.ModuleList(
-                [
-                    TransformerDecoderLayer(args, no_encoder_attn)
-                    for _ in range(self.early_exit[1])
-                ]
-            )
+            self.layers_msk = nn.ModuleList([
+                TransformerDecoderLayer(args, no_encoder_attn)
+                for _ in range(self.early_exit[1])
+            ])
         self.layers_del = None
         if getattr(args, "no_share_discriminator", False):
-            self.layers_del = nn.ModuleList(
-                [
-                    TransformerDecoderLayer(args, no_encoder_attn)
-                    for _ in range(self.early_exit[0])
-                ]
-            )
+            self.layers_del = nn.ModuleList([
+                TransformerDecoderLayer(args, no_encoder_attn)
+                for _ in range(self.early_exit[0])
+            ])
+
+        if getattr(args, "share_discriminator_maskpredictor", False):
+            assert getattr(args, "no_share_discriminator", False), "must set separate discriminator"
+            self.layers_msk = self.layers_del
 
     def extract_features(
-        self,
-        prev_output_tokens,
-        encoder_out=None,
-        early_exit=None,
-        layers=None,
-        **unused
+        self, prev_output_tokens, encoder_out=None, early_exit=None, layers=None, **unused
     ):
         """
        Similar to *forward* but only return 
features. - Inputs: prev_output_tokens: Tensor(B, T) encoder_out: a dictionary of hidden states and masks @@ -574,7 +581,7 @@ def extract_features( ) # embed tokens and positions - x = self.embed_scale * self.embed_tokens(prev_output_tokens.long()) + x = self.embed_scale * self.embed_tokens(prev_output_tokens) if self.project_in_dim is not None: x = self.project_in_dim(x) @@ -591,11 +598,11 @@ def extract_features( decoder_padding_mask = prev_output_tokens.eq(self.padding_idx) layers = self.layers if layers is None else layers early_exit = len(layers) if early_exit is None else early_exit - for _, layer in enumerate(layers[:early_exit]): + for _, layer in enumerate(layers[: early_exit]): x, attn = layer( x, - encoder_out[0] if encoder_out is not None else None, - encoder_out[1] if encoder_out is not None else None, + encoder_out.encoder_out if encoder_out is not None else None, + encoder_out.encoder_padding_mask if encoder_out is not None else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) @@ -610,38 +617,26 @@ def extract_features( if self.project_out_dim is not None: x = self.project_out_dim(x) - return x, attn, inner_states + return x, {"attn": attn, "inner_states": inner_states} def forward_mask_ins(self, prev_output_tokens, encoder_out=None, **unused): - features, attn, _ = self.extract_features( - prev_output_tokens, - encoder_out=encoder_out, - early_exit=self.early_exit[1], - layers=self.layers_msk, - **unused + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[1], layers=self.layers_msk, **unused ) features_cat = torch.cat([features[:, :-1, :], features[:, 1:, :]], 2) - return F.linear(features_cat, self.embed_mask_ins.weight), attn + return F.linear(features_cat, self.embed_mask_ins.weight), extra['attn'] def forward_word_ins(self, prev_output_tokens, encoder_out=None, **unused): - features, attn, _ = self.extract_features( - prev_output_tokens, - encoder_out=encoder_out, - early_exit=self.early_exit[2], - layers=self.layers, - **unused + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[2], layers=self.layers, **unused ) - return self.output_layer(features), attn + return self.output_layer(features), extra['attn'] def forward_word_del(self, prev_output_tokens, encoder_out=None, **unused): - features, attn, _ = self.extract_features( - prev_output_tokens, - encoder_out=encoder_out, - early_exit=self.early_exit[0], - layers=self.layers_del, - **unused + features, extra = self.extract_features( + prev_output_tokens, encoder_out=encoder_out, early_exit=self.early_exit[0], layers=self.layers_del, **unused ) - return F.linear(features, self.embed_word_del.weight), attn + return F.linear(features, self.embed_word_del.weight), extra['attn'] @register_model_architecture("levenshtein_transformer", "levenshtein_transformer") @@ -671,7 +666,7 @@ def base_architecture(args): args.share_decoder_input_output_embed = getattr( args, "share_decoder_input_output_embed", False ) - args.share_all_embeddings = getattr(args, "share_all_embeddings", True) + args.share_all_embeddings = getattr(args, "share_all_embeddings", False) args.no_token_positional_embeddings = getattr( args, "no_token_positional_embeddings", False ) @@ -686,6 +681,8 @@ def base_architecture(args): args.early_exit = getattr(args, "early_exit", "6,6,6") args.no_share_discriminator = getattr(args, "no_share_discriminator", False) args.no_share_maskpredictor = getattr(args, 
"no_share_maskpredictor", False) + args.share_discriminator_maskpredictor = getattr(args, "share_discriminator_maskpredictor", False) + args.no_share_last_layer = getattr(args, "no_share_last_layer", False) @register_model_architecture( diff --git a/fairseq/models/model_utils.py b/fairseq/models/model_utils.py index 432f81ea3d..46ec62f772 100644 --- a/fairseq/models/model_utils.py +++ b/fairseq/models/model_utils.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, List +from typing import List, Optional import torch from torch import Tensor @@ -33,39 +33,6 @@ def script_skip_tensor(x: Tensor, mask): return res -@torch.jit.script -def script_skip_tensor_dict(x: Dict[str, Tensor], mask): - outputs = {} - for s, t in x.items(): - outputs[s] = t[mask] if t.size(0) == mask.size(0) else t[:, mask] - return outputs - - -def skip_tensors(x, mask): - """ - Getting sliced (dim=0) tensor by mask. Supporting tensor and list/dict of tensors. - """ - if isinstance(x, int): - return x - - if x is None: - return None - - if isinstance(x, torch.Tensor): - if x.size(0) == mask.size(0): - return x[mask] - elif x.size(1) == mask.size(0): - return x[:, mask] - - if isinstance(x, list): - return [skip_tensors(x_i, mask) for x_i in x] - - if isinstance(x, dict): - return {k: skip_tensors(v, mask) for k, v in x.items()} - - raise NotImplementedError - - @torch.jit.script def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int): """ @@ -88,12 +55,17 @@ def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int): @torch.jit.script -def fill_tensors(x, mask, y, padding_idx: int): +def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor: + return x if x is not None else y + + +@torch.jit.script +def fill_tensors(x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int) -> Optional[Tensor]: """ Filling tensor x with y at masked positions (dim=0). """ - if x is None or x.size()[0] == 0: - return torch.empty([0]) + if x is None or x.size()[0] == 0 or y is None: + return x assert x.dim() == y.dim() and mask.size(0) == x.size(0) assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2)) @@ -116,88 +88,3 @@ def fill_tensors(x, mask, y, padding_idx: int): else: x[mask] = y return x - - -def _apply_ins_masks( - in_tokens, in_scores, mask_ins_pred, padding_idx, unk_idx, eos_idx -): - - in_masks = in_tokens.ne(padding_idx) - in_lengths = in_masks.sum(1) - - # HACK: hacky way to shift all the paddings to eos first. 
- in_tokens.masked_fill_(~in_masks, eos_idx) - mask_ins_pred.masked_fill_(~in_masks[:, 1:], 0) - - out_lengths = in_lengths + mask_ins_pred.sum(1) - out_max_len = out_lengths.max() - out_masks = ( - torch.arange(out_max_len, device=out_lengths.device)[None, :] - < out_lengths[:, None] - ) - - reordering = (mask_ins_pred + in_masks[:, 1:].long()).cumsum(1) - out_tokens = ( - in_tokens.new_zeros(in_tokens.size(0), out_max_len) - .fill_(padding_idx) - .masked_fill_(out_masks, unk_idx) - ) - out_tokens[:, 0] = in_tokens[:, 0] - out_tokens.scatter_(1, reordering, in_tokens[:, 1:]) - - out_scores = None - if in_scores is not None: - in_scores.masked_fill_(~in_masks, 0) - out_scores = in_scores.new_zeros(*out_tokens.size()) - out_scores[:, 0] = in_scores[:, 0] - out_scores.scatter_(1, reordering, in_scores[:, 1:]) - - return out_tokens, out_scores - - -def _apply_ins_words(in_tokens, in_scores, word_ins_pred, word_ins_scores, unk_idx): - word_ins_masks = in_tokens.eq(unk_idx) - out_tokens = in_tokens.masked_scatter(word_ins_masks, word_ins_pred[word_ins_masks]) - - if in_scores is not None: - out_scores = in_scores.masked_scatter( - word_ins_masks, word_ins_scores[word_ins_masks] - ) - else: - out_scores = None - - return out_tokens, out_scores - - -def _apply_del_words( - in_tokens, in_scores, in_attn, word_del_pred, padding_idx, bos_idx, eos_idx -): - # apply deletion to a tensor - in_masks = in_tokens.ne(padding_idx) - bos_eos_masks = in_tokens.eq(bos_idx) | in_tokens.eq(eos_idx) - - max_len = in_tokens.size(1) - word_del_pred.masked_fill_(~in_masks, 1) - word_del_pred.masked_fill_(bos_eos_masks, 0) - - reordering = ( - torch.arange(max_len, device=in_tokens.device)[None, :] - .expand_as(in_tokens) - .contiguous() - .masked_fill_(word_del_pred, max_len) - .sort(1)[1] - ) - - out_tokens = in_tokens.masked_fill(word_del_pred, padding_idx).gather(1, reordering) - - out_scores = None - if in_scores is not None: - out_scores = in_scores.masked_fill(word_del_pred, 0).gather(1, reordering) - - out_attn = None - if in_attn is not None: - _mask = word_del_pred[:, :, None].expand_as(in_attn) - _reordering = reordering[:, :, None].expand_as(in_attn) - out_attn = in_attn.masked_fill(_mask, 0.).gather(1, _reordering) - - return out_tokens, out_scores, out_attn diff --git a/fairseq/models/nonautoregressive_ensembles.py b/fairseq/models/nonautoregressive_ensembles.py index 01680b86cd..dab227cc7c 100644 --- a/fairseq/models/nonautoregressive_ensembles.py +++ b/fairseq/models/nonautoregressive_ensembles.py @@ -3,11 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+import math + import torch import torch.nn.functional as F -import math -from fairseq.models.model_utils import fill_tensors as _fill, skip_tensors as _skip -from fairseq.models.model_utils import _apply_del_words, _apply_ins_masks, _apply_ins_words + +from fairseq.models.levenshtein_transformer import ( + _skip, + _apply_ins_masks, + _apply_ins_words, + _apply_del_words, +) +from fairseq.models.model_utils import fill_tensors as _fill class BasicEnsembleModel(torch.nn.Module): diff --git a/fairseq/models/nonautoregressive_transformer.py b/fairseq/models/nonautoregressive_transformer.py index 1add0bd480..ddb2ea12e8 100644 --- a/fairseq/models/nonautoregressive_transformer.py +++ b/fairseq/models/nonautoregressive_transformer.py @@ -5,7 +5,9 @@ import torch import torch.nn.functional as F + from fairseq import utils +from fairseq.iterative_refinement_generator import DecoderOut from fairseq.models import register_model, register_model_architecture from fairseq.models.transformer import ( Embedding, @@ -45,8 +47,8 @@ def _uniform_assignment(src_lens, trg_lens): @register_model("nonautoregressive_transformer") class NATransformerModel(TransformerModel): - def __init__(self, encoder, decoder): - super().__init__(encoder, decoder) + def __init__(self, args, encoder, decoder): + super().__init__(args, encoder, decoder) self.tgt_dict = decoder.dictionary self.bos = decoder.dictionary.bos() self.eos = decoder.dictionary.eos() @@ -112,9 +114,9 @@ def forward_encoder(self, encoder_inputs): return self.encoder(*encoder_inputs) def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwargs): - step = decoder_out["step"] - output_tokens = decoder_out["output_tokens"] - output_scores = decoder_out["output_scores"] + step = decoder_out.step + output_tokens = decoder_out.output_tokens + output_scores = decoder_out.output_scores # execute the decoder output_masks = output_tokens.ne(self.pad) @@ -127,12 +129,16 @@ def forward_decoder(self, decoder_out, encoder_out, decoding_format=None, **kwar output_tokens.masked_scatter_(output_masks, _tokens[output_masks]) output_scores.masked_scatter_(output_masks, _scores[output_masks]) - return {"output_tokens": output_tokens, "output_scores": output_scores, "attn": None} + return decoder_out._replace( + output_tokens=output_tokens, + output_scores=output_scores, + attn=None, + ) def initialize_output_tokens(self, encoder_out, src_tokens): # length prediction _, length_tgt = self.decoder.forward_length_prediction(encoder_out) - max_length = length_tgt.max() + max_length = length_tgt.clamp_(min=2).max() idx_length = utils.new_arange(src_tokens, max_length) initial_output_tokens = src_tokens.new_zeros( @@ -146,13 +152,15 @@ def initialize_output_tokens(self, encoder_out, src_tokens): initial_output_scores = initial_output_tokens.new_zeros( *initial_output_tokens.size() - ).type_as(encoder_out["encoder_out"]) - - return { - "output_tokens": initial_output_tokens, - "output_scores": initial_output_scores, - "attn": None - } + ).type_as(encoder_out.encoder_out) + + return DecoderOut( + output_tokens=initial_output_tokens, + output_scores=initial_output_scores, + attn=None, + step=0, + max_step=0, + ) class NATransformerDecoder(TransformerDecoder): @@ -220,8 +228,8 @@ def extract_features( """ # embedding if embedding_copy: - src_embd = encoder_out["encoder_embedding"] - src_mask = encoder_out["encoder_padding_mask"] + src_embd = encoder_out.encoder_embedding + src_mask = encoder_out.encoder_padding_mask src_mask = ( ~src_mask if src_mask is not None @@ 
-253,10 +261,8 @@ def extract_features( x, attn = layer( x, - encoder_out["encoder_out"] if encoder_out is not None else None, - encoder_out["encoder_padding_mask"] - if encoder_out is not None - else None, + encoder_out.encoder_out if encoder_out is not None else None, + encoder_out.encoder_padding_mask if encoder_out is not None else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) @@ -311,8 +317,8 @@ def forward_copying_source(self, src_embeds, src_masks, tgt_masks): return copied_embedding def forward_length_prediction(self, encoder_out, tgt_tokens=None): - enc_feats = encoder_out["encoder_out"] # T x B x C - src_masks = encoder_out["encoder_padding_mask"] # B x T or None + enc_feats = encoder_out.encoder_out # T x B x C + src_masks = encoder_out.encoder_padding_mask # B x T or None if self.pred_length_offset: if src_masks is None: diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index dd92016af7..cffa1c118a 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -348,7 +348,7 @@ def forward(self, src_tokens, features_only=False, return_all_hiddens=False, mas - a dictionary of additional data, where 'inner_states' is a list of hidden states. """ - x, extra = self.extract_features(src_tokens, return_all_hiddens) + x, extra = self.extract_features(src_tokens, return_all_hiddens=return_all_hiddens) if not features_only: x = self.output_layer(x, masked_tokens=masked_tokens) return x, extra diff --git a/fairseq/models/tracing_compliant_transformer.py b/fairseq/models/tracing_compliant_transformer.py deleted file mode 100644 index ca3d807bed..0000000000 --- a/fairseq/models/tracing_compliant_transformer.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from fairseq import options, utils -from fairseq.models import ( - FairseqEncoder, - FairseqIncrementalDecoder, - FairseqEncoderDecoderModel, - register_model, - register_model_architecture, -) -from fairseq.models.transformer import Embedding, Linear, base_architecture -from fairseq.modules import ( - AdaptiveSoftmax, - LayerNorm, - PositionalEmbedding, - SinusoidalPositionalEmbedding, - TransformerDecoderLayer, - TransformerEncoderLayer, -) - -DEFAULT_MAX_SOURCE_POSITIONS = 1024 -DEFAULT_MAX_TARGET_POSITIONS = 1024 - - -@register_model('tracing_transformer') -class TracingTransformerModel(FairseqEncoderDecoderModel): - """ - Transformer model from `"Attention Is All You Need" (Vaswani, et al, 2017) - `_. - - Args: - encoder (TransformerEncoder): the encoder - decoder (TransformerDecoder): the decoder - - The Transformer model provides the following named architectures and - command-line arguments: - - .. 
argparse:: - :ref: fairseq.models.transformer_parser - :prog: - """ - - @classmethod - def hub_models(cls): - # fmt: off - return { - 'transformer.wmt14.en-fr': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2', - 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', - 'transformer.wmt18.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz', - 'transformer.wmt19.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz', - 'transformer.wmt19.en-ru': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz', - 'transformer.wmt19.de-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz', - 'transformer.wmt19.ru-en': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz', - 'transformer.wmt19.en-de.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.single_model.tar.gz', - 'transformer.wmt19.en-ru.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz', - 'transformer.wmt19.de-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz', - 'transformer.wmt19.ru-en.single_model': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz', - } - # fmt: on - - def __init__(self, encoder, decoder): - super().__init__(encoder, decoder) - self.supports_align_args = True - - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--activation-fn', - choices=utils.get_available_activation_fns(), - help='activation function to use') - parser.add_argument('--dropout', type=float, metavar='D', - help='dropout probability') - parser.add_argument('--attention-dropout', type=float, metavar='D', - help='dropout probability for attention weights') - parser.add_argument('--activation-dropout', '--relu-dropout', type=float, metavar='D', - help='dropout probability after activation in FFN.') - parser.add_argument('--encoder-embed-path', type=str, metavar='STR', - help='path to pre-trained encoder embedding') - parser.add_argument('--encoder-embed-dim', type=int, metavar='N', - help='encoder embedding dimension') - parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N', - help='encoder embedding dimension for FFN') - parser.add_argument('--encoder-layers', type=int, metavar='N', - help='num encoder layers') - parser.add_argument('--encoder-attention-heads', type=int, metavar='N', - help='num encoder attention heads') - parser.add_argument('--encoder-normalize-before', action='store_true', - help='apply layernorm before each encoder block') - parser.add_argument('--encoder-learned-pos', action='store_true', - help='use learned positional embeddings in the encoder') - parser.add_argument('--decoder-embed-path', type=str, metavar='STR', - help='path to pre-trained decoder embedding') - parser.add_argument('--decoder-embed-dim', type=int, metavar='N', - help='decoder embedding dimension') - parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N', - help='decoder embedding dimension for FFN') - parser.add_argument('--decoder-layers', type=int, metavar='N', - help='num decoder layers') - parser.add_argument('--decoder-attention-heads', type=int, metavar='N', - help='num decoder attention heads') - 
parser.add_argument('--decoder-learned-pos', action='store_true', - help='use learned positional embeddings in the decoder') - parser.add_argument('--decoder-normalize-before', action='store_true', - help='apply layernorm before each decoder block') - parser.add_argument('--share-decoder-input-output-embed', action='store_true', - help='share decoder input and output embeddings') - parser.add_argument('--share-all-embeddings', action='store_true', - help='share encoder, decoder and output embeddings' - ' (requires shared dictionary and embed dim)') - parser.add_argument('--no-token-positional-embeddings', default=False, action='store_true', - help='if set, disables positional embeddings (outside self attention)') - parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR', - help='comma separated list of adaptive softmax cutoff points. ' - 'Must be used with adaptive_loss criterion'), - parser.add_argument('--adaptive-softmax-dropout', type=float, metavar='D', - help='sets adaptive softmax dropout for the tail projections') - # args for "Cross+Self-Attention for Transformer Models" (Peitz et al., 2019) - parser.add_argument('--no-cross-attention', default=False, action='store_true', - help='do not perform cross-attention') - parser.add_argument('--cross-self-attention', default=False, action='store_true', - help='perform cross+self-attention') - parser.add_argument('--layer-wise-attention', default=False, action='store_true', - help='perform layer-wise attention (cross-attention or cross+self-attention)') - # fmt: on - - @classmethod - def build_model(cls, args, task): - """Build a new model instance.""" - - # make sure all arguments are present in older models - base_architecture(args) - - if not hasattr(args, 'max_source_positions'): - args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS - if not hasattr(args, 'max_target_positions'): - args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS - - src_dict, tgt_dict = task.source_dictionary, task.target_dictionary - - def build_embedding(dictionary, embed_dim, path=None): - num_embeddings = len(dictionary) - padding_idx = dictionary.pad() - emb = Embedding(num_embeddings, embed_dim, padding_idx) - # if provided, load from preloaded dictionaries - if path: - embed_dict = utils.parse_embedding(path) - utils.load_embedding(embed_dict, dictionary, emb) - return emb - - if args.share_all_embeddings: - if src_dict != tgt_dict: - raise ValueError('--share-all-embeddings requires a joined dictionary') - if args.encoder_embed_dim != args.decoder_embed_dim: - raise ValueError( - '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim') - if args.decoder_embed_path and ( - args.decoder_embed_path != args.encoder_embed_path): - raise ValueError('--share-all-embeddings not compatible with --decoder-embed-path') - encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = encoder_embed_tokens - args.share_decoder_input_output_embed = True - else: - encoder_embed_tokens = build_embedding( - src_dict, args.encoder_embed_dim, args.encoder_embed_path - ) - decoder_embed_tokens = build_embedding( - tgt_dict, args.decoder_embed_dim, args.decoder_embed_path - ) - - encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens) - decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - return cls(encoder, decoder) - - @classmethod - def build_encoder(cls, args, src_dict, embed_tokens): - return TracingTransformerEncoder(args, src_dict, 
embed_tokens) - - @classmethod - def build_decoder(cls, args, tgt_dict, embed_tokens): - return TracingTransformerDecoder( - args, - tgt_dict, - embed_tokens, - no_encoder_attn=getattr(args, 'no_cross_attention', False), - ) - - -class TracingTransformerEncoder(FairseqEncoder): - """ - Transformer encoder consisting of *args.encoder_layers* layers. Each layer - is a :class:`TransformerEncoderLayer`. - - Args: - args (argparse.Namespace): parsed command-line arguments - dictionary (~fairseq.data.Dictionary): encoding dictionary - embed_tokens (torch.nn.Embedding): input embedding - """ - - def __init__(self, args, dictionary, embed_tokens): - super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([3])) - - self.dropout = args.dropout - - embed_dim = embed_tokens.embedding_dim - self.padding_idx = embed_tokens.padding_idx - self.max_source_positions = args.max_source_positions - - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(embed_dim) - self.embed_positions = PositionalEmbedding( - args.max_source_positions, embed_dim, self.padding_idx, - learned=args.encoder_learned_pos, - ) if not args.no_token_positional_embeddings else None - - self.layer_wise_attention = getattr(args, 'layer_wise_attention', False) - - self.layers = nn.ModuleList([]) - self.layers.extend([ - TransformerEncoderLayer(args) - for i in range(args.encoder_layers) - ]) - - if args.encoder_normalize_before: - self.layer_norm = LayerNorm(embed_dim) - else: - self.layer_norm = None - - def forward_embedding(self, src_tokens): - # embed tokens and positions - embed = self.embed_scale * self.embed_tokens(src_tokens) - if self.embed_positions is not None: - x = embed + self.embed_positions(src_tokens) - x = F.dropout(x, p=self.dropout, training=self.training) - return x, embed - - def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=False): - """ - Args: - src_tokens (LongTensor): tokens in the source language of shape - `(batch, src_len)` - src_lengths (torch.LongTensor): lengths of each source sentence of - shape `(batch)` - return_all_hiddens (bool, optional): also return all of the - intermediate hidden states (default: False). - - Returns: - dict: - - **encoder_out** (Tensor): the last encoder layer's output of - shape `(src_len, batch, embed_dim)` - - **encoder_padding_mask** (ByteTensor): the positions of - padding elements of shape `(batch, src_len)` - - **encoder_states** (List[Tensor]): all intermediate - hidden states of shape `(src_len, batch, embed_dim)`. - Only populated if *return_all_hiddens* is True. - """ - if self.layer_wise_attention: - return_all_hiddens = True - - x, encoder_embedding = self.forward_embedding(src_tokens) - - # B x T x C -> T x B x C - x = x.transpose(0, 1) - - # compute padding mask - encoder_padding_mask = src_tokens.eq(self.padding_idx) - - encoder_states = [] if return_all_hiddens else None - - # encoder layers - for layer in self.layers: - x = layer(x, encoder_padding_mask) - if return_all_hiddens: - encoder_states.append(x) - - if self.layer_norm: - x = self.layer_norm(x) - if return_all_hiddens: - encoder_states[-1] = x - if encoder_states is not None: - return x, encoder_padding_mask, encoder_embedding, encoder_states - else: - return x, encoder_padding_mask, encoder_embedding - - def reorder_encoder_out(self, encoder_out, new_order): - """ - Reorder encoder output according to *new_order*. 
- - Args: - encoder_out: output from the ``forward()`` method - new_order (LongTensor): desired order - - Returns: - *encoder_out* rearranged according to *new_order* - """ - # 0: encoder_out - # 1: encoder_padding_mask - # 2: encoder_states - if encoder_out[0] is not None: - encoder_out[0] = \ - encoder_out[0].index_select(1, new_order) - if encoder_out[1] is not None: - encoder_out[1] = \ - encoder_out[1].index_select(0, new_order) - if len(encoder_out) == 3 and encoder_out[2] is not None: - for idx, state in enumerate(encoder_out[2]): - encoder_out[2][idx] = state.index_select(1, new_order) - return encoder_out - - def max_positions(self): - """Maximum input length supported by the encoder.""" - if self.embed_positions is None: - return self.max_source_positions - return min(self.max_source_positions, self.embed_positions.max_positions()) - - def buffered_future_mask(self, tensor): - dim = tensor.size(0) - if not hasattr(self, '_future_mask') or self._future_mask is None or self._future_mask.device != tensor.device: - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) - if self._future_mask.size(0) < dim: - self._future_mask = torch.triu(utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1) - return self._future_mask[:dim, :dim] - - def upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = '{}.embed_positions.weights'.format(name) - if weights_key in state_dict: - del state_dict[weights_key] - state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) - for i in range(len(self.layers)): - # update layer norms - self.layers[i].upgrade_state_dict_named(state_dict, "{}.layers.{}".format(name, i)) - - version_key = '{}.version'.format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - return state_dict - - -class TracingTransformerDecoder(FairseqIncrementalDecoder): - """ - Transformer decoder consisting of *args.decoder_layers* layers. Each layer - is a :class:`TransformerDecoderLayer`. - - Args: - args (argparse.Namespace): parsed command-line arguments - dictionary (~fairseq.data.Dictionary): decoding dictionary - embed_tokens (torch.nn.Embedding): output embedding - no_encoder_attn (bool, optional): whether to attend to encoder outputs - (default: False). 
- """ - - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): - super().__init__(dictionary) - self.register_buffer('version', torch.Tensor([3])) - - self.dropout = args.dropout - self.share_input_output_embed = args.share_decoder_input_output_embed - - input_embed_dim = embed_tokens.embedding_dim - embed_dim = args.decoder_embed_dim - self.output_embed_dim = args.decoder_output_dim - - self.padding_idx = embed_tokens.padding_idx - self.max_target_positions = args.max_target_positions - - self.embed_tokens = embed_tokens - self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim - - self.project_in_dim = Linear(input_embed_dim, embed_dim, bias=False) if embed_dim != input_embed_dim else None - - self.embed_positions = PositionalEmbedding( - args.max_target_positions, embed_dim, self.padding_idx, - learned=args.decoder_learned_pos, - ) if not args.no_token_positional_embeddings else None - - self.cross_self_attention = getattr(args, 'cross_self_attention', False) - self.layer_wise_attention = getattr(args, 'layer_wise_attention', False) - - self.layers = nn.ModuleList([]) - self.layers.extend([ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) - ]) - - self.adaptive_softmax = None - - self.project_out_dim = Linear(embed_dim, self.output_embed_dim, bias=False) \ - if embed_dim != self.output_embed_dim and not args.tie_adaptive_weights else None - - if args.adaptive_softmax_cutoff is not None: - self.adaptive_softmax = AdaptiveSoftmax( - len(dictionary), - self.output_embed_dim, - options.eval_str_list(args.adaptive_softmax_cutoff, type=int), - dropout=args.adaptive_softmax_dropout, - adaptive_inputs=embed_tokens if args.tie_adaptive_weights else None, - factor=args.adaptive_softmax_factor, - tie_proj=args.tie_adaptive_proj, - ) - elif not self.share_input_output_embed: - self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), self.output_embed_dim)) - nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) - - if args.decoder_normalize_before and not getattr(args, 'no_decoder_final_norm', False): - self.layer_norm = LayerNorm(embed_dim) - else: - self.layer_norm = None - - def forward( - self, - prev_output_tokens, - encoder_out=None, - incremental_state=None, - features_only=False, - **extra_args, - ): - """ - Args: - prev_output_tokens (LongTensor): previous decoder outputs of shape - `(batch, tgt_len)`, for teacher forcing - encoder_out (Tensor, optional): output from the encoder, used for - encoder-side attention - incremental_state (dict): dictionary used for storing state during - :ref:`Incremental decoding` - features_only (bool, optional): only return features without - applying output layer (default: False). - - Returns: - tuple: - - the decoder's output of shape `(batch, tgt_len, vocab)` - - a dictionary with any model-specific outputs - """ - x, extra = self.extract_features( - prev_output_tokens, encoder_out, incremental_state, **extra_args, - ) - if not features_only: - x = self.output_layer(x) - return x, extra - - def extract_features( - self, - prev_output_tokens, - encoder_out=None, - incremental_state=None, - full_context_alignment=False, - alignment_layer=None, - alignment_heads=None, - **unused, - ): - """ - Similar to *forward* but only return features. - - Includes several features from "Jointly Learning to Align and - Translate with Transformer Models" (Garg et al., EMNLP 2019). 
- - Args: - full_context_alignment (bool, optional): don't apply - auto-regressive mask to self-attention (default: False). - alignment_layer (int, optional): return mean alignment over - heads at this layer (default: last layer). - alignment_heads (int, optional): only average alignment over - this many heads (default: all heads). - - Returns: - tuple: - - the decoder's features of shape `(batch, tgt_len, embed_dim)` - - a dictionary with any model-specific outputs - """ - if alignment_layer is None: - alignment_layer = len(self.layers) - 1 - - # embed positions - positions = self.embed_positions( - prev_output_tokens, - incremental_state=incremental_state, - ) if self.embed_positions is not None else None - - if incremental_state is not None: - prev_output_tokens = prev_output_tokens[:, -1:] - if positions is not None: - positions = positions[:, -1:] - - # embed tokens and positions - x = self.embed_scale * self.embed_tokens(prev_output_tokens) - - if self.project_in_dim is not None: - x = self.project_in_dim(x) - - if positions is not None: - x += positions - x = F.dropout(x, p=self.dropout, training=self.training) - - # B x T x C -> T x B x C - x = x.transpose(0, 1) - - self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx) - if not self_attn_padding_mask.any() and not self.cross_self_attention: - self_attn_padding_mask = None - - # decoder layers - attn = None - inner_states = [x] - for idx, layer in enumerate(self.layers): - encoder_state = None - if encoder_out is not None: - if self.layer_wise_attention: - encoder_state = encoder_out[3][idx] - else: - encoder_state = encoder_out[0] - - if incremental_state is None and not full_context_alignment: - self_attn_mask = self.buffered_future_mask(x) - else: - self_attn_mask = None - - x, layer_attn = layer( - x, - encoder_state - if encoder_state is not None else None, - encoder_out[1] - if encoder_out is not None else None, - incremental_state, - self_attn_mask=self_attn_mask, - self_attn_padding_mask=self_attn_padding_mask, - need_attn=(idx == alignment_layer), - need_head_weights=(idx == alignment_layer), - ) - - inner_states.append(x) - if layer_attn is not None and idx == alignment_layer: - attn = layer_attn.float() - - if attn is not None: - if alignment_heads is not None: - attn = attn[:alignment_heads] - - # average probabilities over heads - attn = attn.mean(dim=0) - - if self.layer_norm: - x = self.layer_norm(x) - - # T x B x C -> B x T x C - x = x.transpose(0, 1) - - if self.project_out_dim is not None: - x = self.project_out_dim(x) - - return x, {'attn': attn, 'inner_states': inner_states} - - def output_layer(self, features, **kwargs): - """Project features to the vocabulary size.""" - if self.adaptive_softmax is None: - # project back to size of vocabulary - if self.share_input_output_embed: - return F.linear(features, self.embed_tokens.weight) - else: - return F.linear(features, self.embed_out) - else: - return features - - def max_positions(self): - """Maximum output length supported by the decoder.""" - if self.embed_positions is None: - return self.max_target_positions - return min(self.max_target_positions, self.embed_positions.max_positions()) - - def buffered_future_mask(self, tensor): - dim = tensor.size(0) - if ( - not hasattr(self, '_future_mask') - or self._future_mask is None - or self._future_mask.device != tensor.device - or self._future_mask.size(0) < dim - ): - self._future_mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1) - return self._future_mask[:dim, :dim] - - def 
upgrade_state_dict_named(self, state_dict, name): - """Upgrade a (possibly old) state dict for new versions of fairseq.""" - if isinstance(self.embed_positions, SinusoidalPositionalEmbedding): - weights_key = '{}.embed_positions.weights'.format(name) - if weights_key in state_dict: - del state_dict[weights_key] - state_dict['{}.embed_positions._float_tensor'.format(name)] = torch.FloatTensor(1) - - for i in range(len(self.layers)): - # update layer norms - layer_norm_map = { - '0': 'self_attn_layer_norm', - '1': 'encoder_attn_layer_norm', - '2': 'final_layer_norm' - } - for old, new in layer_norm_map.items(): - for m in ('weight', 'bias'): - k = '{}.layers.{}.layer_norms.{}.{}'.format(name, i, old, m) - if k in state_dict: - state_dict['{}.layers.{}.{}.{}'.format(name, i, new, m)] = state_dict[k] - del state_dict[k] - - version_key = '{}.version'.format(name) - if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2: - # earlier checkpoints did not normalize after the stack of layers - self.layer_norm = None - self.normalize = False - state_dict[version_key] = torch.Tensor([1]) - - return state_dict diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 3c2b607a5e..380419aa40 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import namedtuple import math import torch @@ -279,6 +280,14 @@ def forward_decoder( return decoder_out +EncoderOut = namedtuple('TransformerEncoderOut', [ + 'encoder_out', # T x B x C + 'encoder_padding_mask', # B x T + 'encoder_embedding', # B x T x C + 'encoder_states', # List[T x B x C] +]) + + class TransformerEncoder(FairseqEncoder): """ Transformer encoder consisting of *args.encoder_layers* layers. Each layer @@ -348,11 +357,13 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa intermediate hidden states (default: False). Returns: - dict: + namedtuple: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of padding elements of shape `(batch, src_len)` + - **encoder_embedding** (Tensor): the (scaled) embedding lookup + of shape `(batch, src_len, embed_dim)` - **encoder_states** (List[Tensor]): all intermediate hidden states of shape `(src_len, batch, embed_dim)`. Only populated if *return_all_hiddens* is True. 
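[The hunks above replace the encoder's dict output with an `EncoderOut` NamedTuple. The payoff is attribute access that tracing/TorchScript can reason about, plus immutability: updates go through `_replace()`, which returns a fresh tuple instead of mutating shared state. A minimal sketch of the pattern — field names mirror the patch, the toy tensors and shapes are illustrative only:

    from collections import namedtuple

    import torch

    EncoderOut = namedtuple('TransformerEncoderOut', [
        'encoder_out',           # T x B x C
        'encoder_padding_mask',  # B x T
        'encoder_embedding',     # B x T x C
        'encoder_states',        # List[T x B x C]
    ])

    out = EncoderOut(
        encoder_out=torch.zeros(5, 2, 8),
        encoder_padding_mask=torch.zeros(2, 5, dtype=torch.bool),
        encoder_embedding=torch.zeros(2, 5, 8),
        encoder_states=None,
    )

    # Fields are read by attribute, not by string key:
    assert out.encoder_out.size(0) == 5

    # NamedTuples are immutable; _replace() builds a new tuple with the
    # given fields swapped, which is the idiom reorder_encoder_out() uses.
    out = out._replace(
        encoder_padding_mask=out.encoder_padding_mask.index_select(
            0, torch.tensor([1, 0])
        )
    )
]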
@@ -386,12 +397,12 @@ def forward(self, src_tokens, src_lengths, cls_input=None, return_all_hiddens=Fa if return_all_hiddens: encoder_states[-1] = x - return { - 'encoder_out': x, # T x B x C - 'encoder_padding_mask': encoder_padding_mask, # B x T - 'encoder_embedding': encoder_embedding, # B x T x C - 'encoder_states': encoder_states, # List[T x B x C] - } + return EncoderOut( + encoder_out=x, # T x B x C + encoder_padding_mask=encoder_padding_mask, # B x T + encoder_embedding=encoder_embedding, # B x T x C + encoder_states=encoder_states, # List[T x B x C] + ) def reorder_encoder_out(self, encoder_out, new_order): """ @@ -404,15 +415,21 @@ def reorder_encoder_out(self, encoder_out, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - if encoder_out['encoder_out'] is not None: - encoder_out['encoder_out'] = \ - encoder_out['encoder_out'].index_select(1, new_order) - if encoder_out['encoder_padding_mask'] is not None: - encoder_out['encoder_padding_mask'] = \ - encoder_out['encoder_padding_mask'].index_select(0, new_order) - if encoder_out.get('encoder_states', None) is not None: - for idx, state in enumerate(encoder_out['encoder_states']): - encoder_out['encoder_states'][idx] = state.index_select(1, new_order) + if encoder_out.encoder_out is not None: + encoder_out = encoder_out._replace( + encoder_out=encoder_out.encoder_out.index_select(1, new_order) + ) + if encoder_out.encoder_padding_mask is not None: + encoder_out = encoder_out._replace( + encoder_padding_mask=encoder_out.encoder_padding_mask.index_select(0, new_order) + ) + if encoder_out.encoder_embedding is not None: + encoder_out = encoder_out._replace( + encoder_embedding=encoder_out.encoder_embedding.index_select(0, new_order) + ) + if encoder_out.encoder_states is not None: + for idx, state in enumerate(encoder_out.encoder_states): + encoder_out.encoder_states[idx] = state.index_select(1, new_order) return encoder_out def max_positions(self): @@ -532,13 +549,13 @@ def forward( encoder_out=None, incremental_state=None, features_only=False, - **extra_args, + **extra_args ): """ Args: prev_output_tokens (LongTensor): previous decoder outputs of shape `(batch, tgt_len)`, for teacher forcing - encoder_out (Tensor, optional): output from the encoder, used for + encoder_out (optional): output from the encoder, used for encoder-side attention incremental_state (dict): dictionary used for storing state during :ref:`Incremental decoding` @@ -551,7 +568,10 @@ def forward( - a dictionary with any model-specific outputs """ x, extra = self.extract_features( - prev_output_tokens, encoder_out, incremental_state, **extra_args, + prev_output_tokens, + encoder_out=encoder_out, + incremental_state=incremental_state, + **extra_args ) if not features_only: x = self.output_layer(x) @@ -628,9 +648,9 @@ def extract_features( encoder_state = None if encoder_out is not None: if self.layer_wise_attention: - encoder_state = encoder_out['encoder_states'][idx] + encoder_state = encoder_out.encoder_states[idx] else: - encoder_state = encoder_out['encoder_out'] + encoder_state = encoder_out.encoder_out if incremental_state is None and not full_context_alignment: self_attn_mask = self.buffered_future_mask(x) @@ -643,7 +663,7 @@ def extract_features( x, layer_attn = layer( x, encoder_state, - encoder_out['encoder_padding_mask'] if encoder_out is not None else None, + encoder_out.encoder_padding_mask if encoder_out is not None else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, 
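[One subtlety in the rewritten `reorder_encoder_out` above: `encoder_out` is time-major (`T x B x C`), while the padding mask and embedding are batch-major, so the same batch reordering must index-select along different dimensions (1 vs. 0). A small self-contained sketch of that invariant — variable names and shapes here are illustrative, not from the patch:

    import torch

    T, B, C = 7, 4, 16
    enc = torch.randn(T, B, C)                  # time-major: batch on dim 1
    mask = torch.zeros(B, T, dtype=torch.bool)  # batch-major: batch on dim 0
    new_order = torch.tensor([2, 0])            # e.g. sentences still active
                                                # during iterative refinement

    enc = enc.index_select(1, new_order)
    mask = mask.index_select(0, new_order)
    assert enc.shape == (T, 2, C)
    assert mask.shape == (2, T)
]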
diff --git a/fairseq/modules/mean_pool_gating_network.py b/fairseq/modules/mean_pool_gating_network.py index a22d9bd6e4..25743b4e98 100644 --- a/fairseq/modules/mean_pool_gating_network.py +++ b/fairseq/modules/mean_pool_gating_network.py @@ -26,16 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None): def forward(self, encoder_out): if not ( - isinstance(encoder_out, dict) - and 'encoder_out' in encoder_out - and 'encoder_padding_mask' in encoder_out - and encoder_out['encoder_out'].size(2) == self.embed_dim + hasattr(encoder_out, 'encoder_out') + and hasattr(encoder_out, 'encoder_padding_mask') + and encoder_out.encoder_out.size(2) == self.embed_dim ): raise ValueError('Unexpected format for encoder_out') # mean pooling over time - encoder_padding_mask = encoder_out['encoder_padding_mask'] # B x T - encoder_out = encoder_out['encoder_out'].transpose(0, 1) # B x T x C + encoder_padding_mask = encoder_out.encoder_padding_mask # B x T + encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C if encoder_padding_mask is not None: encoder_out = encoder_out.clone() # required because of transpose above encoder_out[encoder_padding_mask] = 0 diff --git a/tests/test_binaries.py b/tests/test_binaries.py index 113901ab06..96a6611c5b 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -197,51 +197,90 @@ def test_dynamicconv(self): ]) generate_main(data_dir) + def test_cmlm_transformer(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory('test_cmlm_transformer') as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, ['--joined-dictionary']) + train_translation_model(data_dir, 'cmlm_transformer', [ + '--apply-bert-init', + '--criterion', 'nat_loss', + '--noise', 'full_mask', + '--pred-length-offset', + '--length-loss-factor', '0.1' + ], task='translation_lev') + generate_main(data_dir, [ + '--task', 'translation_lev', + '--iter-decode-max-iter', '9', + '--iter-decode-eos-penalty', '0', + '--print-step', + ]) + def test_levenshtein_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_levenshtein_transformer') as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir) + preprocess_translation_data(data_dir, ['--joined-dictionary']) train_translation_model(data_dir, 'levenshtein_transformer', [ '--apply-bert-init', '--early-exit', '6,6,6', '--criterion', 'nat_loss' ], task='translation_lev') - generate_main(data_dir) + generate_main(data_dir, [ + '--task', 'translation_lev', + '--iter-decode-max-iter', '9', + '--iter-decode-eos-penalty', '0', + '--print-step', + ]) def test_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_nonautoregressive_transformer') as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir) + preprocess_translation_data(data_dir, ['--joined-dictionary']) train_translation_model(data_dir, 'nonautoregressive_transformer', [ '--apply-bert-init', '--src-embedding-copy', '--criterion', 'nat_loss', '--noise', 'full_mask', '--pred-length-offset', '--length-loss-factor', '0.1' ], task='translation_lev') - generate_main(data_dir) + generate_main(data_dir, [ + '--task', 'translation_lev', + '--iter-decode-max-iter', '9', + '--iter-decode-eos-penalty', '0', + '--print-step', + ]) def test_iterative_nonautoregressive_transformer(self): with contextlib.redirect_stdout(StringIO()): with 
tempfile.TemporaryDirectory('test_iterative_nonautoregressive_transformer') as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir) + preprocess_translation_data(data_dir, ['--joined-dictionary']) train_translation_model(data_dir, 'iterative_nonautoregressive_transformer', [ '--apply-bert-init', '--src-embedding-copy', '--criterion', 'nat_loss', '--noise', 'full_mask', '--stochastic-approx', '--dae-ratio', '0.5', '--train-step', '3' ], task='translation_lev') - generate_main(data_dir) + generate_main(data_dir, [ + '--task', 'translation_lev', + '--iter-decode-max-iter', '9', + '--iter-decode-eos-penalty', '0', + '--print-step', + ]) def test_insertion_transformer(self): with contextlib.redirect_stdout(StringIO()): with tempfile.TemporaryDirectory('test_insertion_transformer') as data_dir: create_dummy_data(data_dir) - preprocess_translation_data(data_dir) + preprocess_translation_data(data_dir, ['--joined-dictionary']) train_translation_model(data_dir, 'insertion_transformer', [ '--apply-bert-init', '--criterion', 'nat_loss', '--noise', 'random_mask' ], task='translation_lev') - generate_main(data_dir) + generate_main(data_dir, [ + '--task', 'translation_lev', + '--iter-decode-max-iter', '9', + '--iter-decode-eos-penalty', '0', + '--print-step', + ]) def test_mixture_of_experts(self): with contextlib.redirect_stdout(StringIO()):
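[The test updates above give every NAT model the same recipe: a joined dictionary at preprocessing time, the `translation_lev` task for training, and the iterative-refinement decoding flags (`--iter-decode-max-iter`, `--iter-decode-eos-penalty`, `--print-step`) at generation time. To exercise just these cases locally, something like the following should work — a sketch, assuming a fairseq dev checkout with pytest installed:

    # run only the nonautoregressive test cases touched by this patch;
    # "nonautoregressive" also matches the iterative variant
    python -m pytest tests/test_binaries.py -v \
        -k "cmlm or levenshtein or nonautoregressive or insertion"
]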