This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Update simple_seq2seq.py #90

Merged
merged 27 commits on Aug 18, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Made the SST reader a little more strict in the kinds of input it accepts.



## [v1.1.0rc2](https://github.com/allenai/allennlp-models/releases/tag/v1.1.0rc2) - 2020-07-31

### Changed
@@ -37,10 +38,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `ModelCard` and related classes. Added model cards for all the pretrained models.
- Added a field `registered_predictor_name` to `ModelCard`.
- Added a method `load_predictor` to `allennlp_models.pretrained`.
- Added support for a multi-layer decoder in the simple seq2seq model.


## [v1.1.0rc1](https://github.com/allenai/allennlp-models/releases/tag/v1.1.0rc1) - 2020-07-14


### Fixed

- Updated the BERT SRL model to be compatible with the new huggingface tokenizers.
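For reference, the new constructor options this PR adds can be exercised the same way the updated tests do, through a JSON override of the `model` section. Below is a minimal sketch, not code from this PR: the experiment file name is hypothetical, and only `target_decoder_layers` and `target_pretrain_file` come from the diff.

```python
import json

from allennlp.common.params import Params

# Override only the new SimpleSeq2Seq options; everything else comes from the base config.
overrides = json.dumps(
    {
        "model": {
            "target_decoder_layers": 2,  # stack two LSTM layers in the decoder
            # "target_pretrain_file": "/path/to/target_embeddings.txt",  # optional pretrained target embeddings
        }
    }
)

params = Params.from_file("experiment.jsonnet", params_overrides=overrides)
print(params["model"]["target_decoder_layers"])  # -> 2
```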
150 changes: 93 additions & 57 deletions allennlp_models/generation/models/simple_seq2seq.py
@@ -1,11 +1,11 @@
from typing import Dict, List, Tuple, Iterable
from typing import Dict, List, Tuple, Iterable, Any

import numpy
from overrides import overrides
import torch
import torch.nn.functional as F
from torch.nn.modules.linear import Linear
from torch.nn.modules.rnn import LSTMCell
from torch.nn.modules.rnn import LSTMCell, LSTM

from allennlp.common.util import START_SYMBOL, END_SYMBOL
from allennlp.data import TextFieldTensors, Vocabulary
@@ -25,9 +25,7 @@ class SimpleSeq2Seq(Model):
a neural machine translation system, an abstractive summarization system, or any other common
seq2seq problem. The model here is simple, but should be a decent starting place for
implementing recent models for these tasks.

# Parameters

vocab : `Vocabulary`, required
Vocabulary containing source and target vocabularies. They may be under the same namespace
(`tokens`) or the target tokens can have a different namespace, in which case it needs to
Expand All @@ -45,6 +43,10 @@ class SimpleSeq2Seq(Model):
target_embedding_dim : `int`, optional (default = `'source_embedding_dim'`)
You can specify an embedding dimensionality for the target side. If not, we'll use the same
value as the source embedder's.
target_pretrain_file : `str`, optional (default = `None`)
Path to a pretrained embedding file for the target vocabulary.
target_decoder_layers : `int`, optional (default = `1`)
Number of stacked layers in the decoder.
attention : `Attention`, optional (default = `None`)
If you want to use attention to get a dynamic summary of the encoder outputs at each step
of decoding, this is the function used to compute similarity between the decoder hidden
@@ -78,9 +80,12 @@ def __init__(
scheduled_sampling_ratio: float = 0.0,
use_bleu: bool = True,
bleu_ngram_weights: Iterable[float] = (0.25, 0.25, 0.25, 0.25),
target_pretrain_file: str = None,
target_decoder_layers: int = 1,
) -> None:
super().__init__(vocab)
self._target_namespace = target_namespace
self._target_decoder_layers = target_decoder_layers
self._scheduled_sampling_ratio = scheduled_sampling_ratio

# We need the start symbol to provide as the input at the first timestep of decoding, and
@@ -93,7 +98,7 @@ def __init__(
self.vocab._padding_token, self._target_namespace
)
self._bleu = BLEU(
bleu_ngram_weights, exclude_indices={pad_index, self._end_index, self._start_index}
bleu_ngram_weights, exclude_indices={pad_index, self._end_index, self._start_index},
)
else:
self._bleu = None
@@ -118,9 +123,17 @@ def __init__(

# Dense embedding of vocab words in the target space.
target_embedding_dim = target_embedding_dim or source_embedder.get_output_dim()
self._target_embedder = Embedding(
num_embeddings=num_classes, embedding_dim=target_embedding_dim
)
if not target_pretrain_file:
self._target_embedder = Embedding(
num_embeddings=num_classes, embedding_dim=target_embedding_dim
)
else:
self._target_embedder = Embedding(
embedding_dim=target_embedding_dim,
pretrained_file=target_pretrain_file,
vocab_namespace=self._target_namespace,
vocab=self.vocab,
)

# Decoder output dim needs to be the same as the encoder output dim since we initialize the
# hidden state of the decoder with the final hidden state of the encoder.
@@ -139,7 +152,12 @@ def __init__(
# We'll use an LSTM cell as the recurrent cell that produces a hidden state
# for the decoder at each time step.
# TODO (pradeep): Do not hardcode decoder cell type.
self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)
if self._target_decoder_layers > 1:
self._decoder_cell = LSTM(
self._decoder_input_dim, self._decoder_output_dim, self._target_decoder_layers,
)
else:
self._decoder_cell = LSTMCell(self._decoder_input_dim, self._decoder_output_dim)

# We project the hidden state from the decoder into the output vocabulary space
# in order to get log probabilities of each target token, at each time step.
@@ -150,9 +168,7 @@ def take_step(
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take a decoding step. This is called by the beam search class.

# Parameters

last_predictions : `torch.Tensor`
A tensor of shape `(group_size,)`, which gives the indices of the predictions
during the last time step.
@@ -166,14 +182,12 @@ def take_step(
The time step in beam search decoding.

# Returns

Tuple[torch.Tensor, Dict[str, torch.Tensor]]
A tuple of `(log_probabilities, updated_state)`, where `log_probabilities`
is a tensor of shape `(group_size, num_classes)` containing the predicted
log probability of each class for the next step, for each item in the group,
while `updated_state` is a dictionary of tensors containing the encoder outputs,
source mask, and updated decoder hidden state and context.

Notes
-----
We treat the inputs as a batch, even though `group_size` is not necessarily
@@ -197,19 +211,17 @@ def forward(

"""
Make forward pass with decoder logic for producing the entire target sequence.

# Parameters

source_tokens : `TextFieldTensors`
The output of `TextField.as_array()` applied on the source `TextField`. This will be
passed through a `TextFieldEmbedder` and then through an encoder.
target_tokens : `TextFieldTensors`, optional (default = `None`)
Output of `Textfield.as_array()` applied on target `TextField`. We assume that the
target tokens are also represented as a `TextField`.

# Returns

`Dict[str, torch.Tensor]`

"""
state = self._encode(source_tokens)

@@ -235,37 +247,38 @@ def forward(
return output_dict

@overrides
def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
def make_output_human_readable(self, output_dict: Dict[str, Any]) -> Dict[str, Any]:
"""
Finalize predictions.

This method overrides `Model.make_output_human_readable`, which gets called after `Model.forward`, at test
time, to finalize predictions. The logic for the decoder part of the encoder-decoder lives
within the `forward` method.

This method trims the output predictions to the first end symbol, replaces indices with
corresponding tokens, and adds a field called `predicted_tokens` to the `output_dict`.
"""
predicted_indices = output_dict["predictions"]
if not isinstance(predicted_indices, numpy.ndarray):
predicted_indices = predicted_indices.detach().cpu().numpy()
all_predicted_tokens = []
for indices in predicted_indices:
for top_k_predictions in predicted_indices:
# Beam search gives us the top k results for each source sentence in the batch
# but we just want the single best.
if len(indices.shape) > 1:
indices = indices[0]
indices = list(indices)
# Collect indices till the first end_symbol
if self._end_index in indices:
indices = indices[: indices.index(self._end_index)]
predicted_tokens = [
self.vocab.get_token_from_index(x, namespace=self._target_namespace)
for x in indices
]
all_predicted_tokens.append(predicted_tokens)
# we want top-k results.
if len(top_k_predictions.shape) == 1:
top_k_predictions = [top_k_predictions]

batch_predicted_tokens = []
for indices in top_k_predictions:
Member: Why is the extra loop necessary now?

Contributor Author: The original code only returned the top-1 result from beam search. That's not convenient if we want to evaluate a top-5 score (or choose a result with some hand-crafted algorithm), so I used the code segment from CopyNet to get all the results.

indices = list(indices)
# Collect indices till the first end_symbol
if self._end_index in indices:
indices = indices[: indices.index(self._end_index)]
predicted_tokens = [
self.vocab.get_token_from_index(x, namespace=self._target_namespace)
for x in indices
]
batch_predicted_tokens.append(predicted_tokens)

all_predicted_tokens.append(batch_predicted_tokens)
output_dict["predicted_tokens"] = all_predicted_tokens
return output_dict
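
As the review thread above explains, `make_output_human_readable` now keeps every beam rather than only the best one, so `predicted_tokens` is a list over the batch, each entry a list over beams, each beam a list of token strings. A small illustrative sketch of consuming that structure (the tokens below are made up, not from the PR):

```python
# output_dict as produced by make_output_human_readable after this change (illustrative values).
output_dict = {
    "predicted_tokens": [
        [["the", "cat", "sat"], ["a", "cat", "sat"]],  # instance 0, beam_size = 2
    ]
}

for beams in output_dict["predicted_tokens"]:
    best = beams[0]   # beam search returns hypotheses best-first, so index 0 matches the old top-1 behaviour
    top_k = beams     # the full list makes top-k evaluation or custom reranking possible
    print(" ".join(best), "out of", len(top_k), "beams")
```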

@@ -282,7 +295,7 @@ def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch
batch_size = state["source_mask"].size(0)
# shape: (batch_size, encoder_output_dim)
final_encoder_output = util.get_final_encoder_states(
state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional()
state["encoder_outputs"], state["source_mask"], self._encoder.is_bidirectional(),
)
# Initialize the decoder hidden state with the final output of the encoder.
# shape: (batch_size, decoder_output_dim)
Expand All @@ -291,14 +304,24 @@ def _init_decoder_state(self, state: Dict[str, torch.Tensor]) -> Dict[str, torch
state["decoder_context"] = state["encoder_outputs"].new_zeros(
batch_size, self._decoder_output_dim
)
if self._target_decoder_layers > 1:
# shape: (num_layers, batch_size, decoder_output_dim)
state["decoder_hidden"] = (
state["decoder_hidden"].unsqueeze(0).repeat(self._target_decoder_layers, 1, 1)
)

# shape: (num_layers, batch_size, decoder_output_dim)
state["decoder_context"] = (
state["decoder_context"].unsqueeze(0).repeat(self._target_decoder_layers, 1, 1)
)

return state

def _forward_loop(
self, state: Dict[str, torch.Tensor], target_tokens: TextFieldTensors = None
) -> Dict[str, torch.Tensor]:
"""
Make forward pass during training or do greedy search during prediction.

Notes
-----
We really only use the predictions from the method to test that beam search
@@ -401,7 +424,6 @@ def _prepare_output_projections(
Decode current state and last prediction to produce projections
into the target space, which can then be used to get probabilities of
each target token for the next step.

Inputs are the same as for `take_step()`.
"""
# shape: (group_size, max_input_sequence_length, encoder_output_dim)
@@ -410,42 +432,59 @@ def _prepare_output_projections(
# shape: (group_size, max_input_sequence_length)
source_mask = state["source_mask"]

# shape: (group_size, decoder_output_dim)
# shape: (num_layers, group_size, decoder_output_dim)
decoder_hidden = state["decoder_hidden"]

# shape: (group_size, decoder_output_dim)
# shape: (num_layers, group_size, decoder_output_dim)
decoder_context = state["decoder_context"]

# shape: (group_size, target_embedding_dim)
embedded_input = self._target_embedder(last_predictions)

if self._attention:
# shape: (group_size, encoder_output_dim)
attended_input = self._prepare_attended_input(
decoder_hidden, encoder_outputs, source_mask
)

if self._target_decoder_layers > 1:
attended_input = self._prepare_attended_input(
decoder_hidden[0], encoder_outputs, source_mask
)
else:
attended_input = self._prepare_attended_input(
decoder_hidden, encoder_outputs, source_mask
)
# shape: (group_size, decoder_output_dim + target_embedding_dim)
decoder_input = torch.cat((attended_input, embedded_input), -1)
else:
# shape: (group_size, target_embedding_dim)
decoder_input = embedded_input

# shape (decoder_hidden): (batch_size, decoder_output_dim)
# shape (decoder_context): (batch_size, decoder_output_dim)

# TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
with torch.cuda.amp.autocast(False):
decoder_hidden, decoder_context = self._decoder_cell(
decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
)
if self._target_decoder_layers > 1:
# shape: (1, batch_size, target_embedding_dim)
decoder_input = decoder_input.unsqueeze(0)

# shape (decoder_hidden): (num_layers, batch_size, decoder_output_dim)
# shape (decoder_context): (num_layers, batch_size, decoder_output_dim)
# TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
with torch.cuda.amp.autocast(False):
_, (decoder_hidden, decoder_context) = self._decoder_cell(
decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
)
else:
# shape (decoder_hidden): (batch_size, decoder_output_dim)
# shape (decoder_context): (batch_size, decoder_output_dim)
# TODO (epwalsh): remove the autocast(False) once torch's AMP is working for LSTMCells.
with torch.cuda.amp.autocast(False):
decoder_hidden, decoder_context = self._decoder_cell(
decoder_input.float(), (decoder_hidden.float(), decoder_context.float())
)

state["decoder_hidden"] = decoder_hidden
state["decoder_context"] = decoder_context

# shape: (group_size, num_classes)
output_projections = self._output_projection_layer(decoder_hidden)

if self._target_decoder_layers > 1:
output_projections = self._output_projection_layer(decoder_hidden[-1])
else:
output_projections = self._output_projection_layer(decoder_hidden)
return output_projections, state

def _prepare_attended_input(
@@ -465,20 +504,17 @@ def _prepare_attended_input(

@staticmethod
def _get_loss(
logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.BoolTensor
logits: torch.LongTensor, targets: torch.LongTensor, target_mask: torch.BoolTensor,
) -> torch.Tensor:
"""
Compute loss.

Takes logits (unnormalized outputs from the decoder) of size (batch_size,
num_decoding_steps, num_classes), target indices of size (batch_size, num_decoding_steps+1)
and corresponding masks of size (batch_size, num_decoding_steps+1) steps and computes cross
entropy loss while taking the mask into account.

The length of `targets` is expected to be greater than that of `logits` because the
decoder does not need to compute the output corresponding to the last timestep of
`targets`. This method aligns the inputs appropriately to compute the loss.

During training, we want the logit corresponding to timestep i to be similar to the target
token from timestep i + 1. That is, the targets should be shifted by one timestep for
appropriate comparison. Consider a single example where the target has 3 words, and
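The branching on `self._target_decoder_layers` above follows from the different calling conventions of `LSTMCell` and `LSTM` in PyTorch. A small standalone sketch of that difference (shapes only, not the PR's code):

```python
import torch
from torch.nn.modules.rnn import LSTM, LSTMCell

batch_size, input_dim, hidden_dim, num_layers = 4, 6, 8, 2

# Single-layer case: LSTMCell takes input of shape (batch, input_dim) plus (h, c) of shape
# (batch, hidden_dim) and returns the new (h, c) directly.
cell = LSTMCell(input_dim, hidden_dim)
h = torch.zeros(batch_size, hidden_dim)
c = torch.zeros(batch_size, hidden_dim)
h, c = cell(torch.randn(batch_size, input_dim), (h, c))

# Multi-layer case: LSTM expects a sequence dimension on the input and (h, c) of shape
# (num_layers, batch, hidden_dim), and returns (output, (h, c)) -- hence the unsqueeze(0),
# the repeat over num_layers, and the `_, (decoder_hidden, decoder_context) = ...`
# unpacking in the diff above.
lstm = LSTM(input_dim, hidden_dim, num_layers)
h = torch.zeros(num_layers, batch_size, hidden_dim)
c = torch.zeros(num_layers, batch_size, hidden_dim)
step_input = torch.randn(batch_size, input_dim).unsqueeze(0)  # shape: (1, batch, input_dim)
_, (h, c) = lstm(step_input, (h, c))
print(h.shape)  # torch.Size([2, 4, 8])
```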
8 changes: 6 additions & 2 deletions tests/generation/models/simple_seq2seq_test.py
@@ -36,6 +36,12 @@ def test_bidirectional_model_can_train_save_and_load(self):
self.param_file, tolerance=1e-2, overrides=param_overrides
)

def test_multi_layer_decoder_model_can_train_save_and_load(self):
param_overrides = json.dumps({"model": {"target_decoder_layers": 2}})
self.ensure_model_can_train_save_and_load(
self.param_file, tolerance=1e-2, overrides=param_overrides
)

def test_no_attention_model_can_train_save_and_load(self):
param_overrides = json.dumps({"model": {"attention": None}})
self.ensure_model_can_train_save_and_load(
@@ -49,7 +55,6 @@ def test_greedy_model_can_train_save_and_load(self):
)

def test_loss_is_computed_correctly(self):

batch_size = 5
num_decoding_steps = 5
num_classes = 10
@@ -85,7 +90,6 @@ def test_decode_runs_correctly(self):
assert "predicted_tokens" in decode_output_dict

def test_greedy_decode_matches_beam_search(self):

beam_search = BeamSearch(
self.model._end_index, max_steps=self.model._max_decoding_steps, beam_size=1
)
3 changes: 2 additions & 1 deletion tests/generation/predictors/seq2seq_test.py
@@ -19,7 +19,8 @@ def test_uses_named_inputs_with_simple_seq2seq(self):
predicted_tokens = result.get("predicted_tokens")
assert predicted_tokens is not None
assert isinstance(predicted_tokens, list)
assert all(isinstance(x, str) for x in predicted_tokens)
for predicted_token in predicted_tokens:
assert all(isinstance(x, str) for x in predicted_token)
Comment on lines +22 to +23

Member: predicted_tokens is now a list of lists?

Contributor Author: Yes, it contains the top-n sequences, as discussed above.


def test_uses_named_inputs_with_composed_seq2seq(self):
inputs = {"source": "What kind of test succeeded on its first attempt?"}
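Putting the predictor-level change together, here is a rough sketch of what the updated assertion implies for downstream users; the archive path is hypothetical, and the input follows the test above.

```python
from allennlp.predictors import Predictor

# Hypothetical archive path; any SimpleSeq2Seq model trained with the updated code would do.
predictor = Predictor.from_path("/path/to/simple_seq2seq_model.tar.gz", predictor_name="seq2seq")
result = predictor.predict_json({"source": "What kind of test succeeded on its first attempt?"})

# predicted_tokens is now a list of beams, each beam being a list of token strings.
predicted_tokens = result["predicted_tokens"]
for beam in predicted_tokens:
    assert all(isinstance(token, str) for token in beam)

best_tokens = predicted_tokens[0]  # equivalent to the old single-best output
```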