Update past_key_values in GPT-2 #9596

Merged
16 changes: 4 additions & 12 deletions src/transformers/generation_utils.py
@@ -503,18 +503,10 @@ def _update_model_kwargs_for_generation(

return model_kwargs

@staticmethod
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
"""
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.

For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
subclasses of :class:`~transformers.PreTrainedModel`.
"""
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
def _reorder_cache(self, past, beam_idx):
raise NotImplementedError(
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
)

def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
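The old shared implementation above assumed a flat cache: one tensor per layer with the batch axis on dim 1 (GPT-2's former stacked (2, batch, heads, seq, head_dim) layout). With per-model cache layouts the mixin now raises, and each model gathers its own cache. A minimal sketch of the difference, with toy tensors and illustrative shapes (not the library's exact code):

import torch

beam_idx = torch.tensor([1, 0])  # new beam 0 descends from old beam 1, and vice versa

# Old GPT-2 layout: one stacked tensor per layer, batch on dim 1
old_layer_past = torch.randn(2, 2, 12, 5, 64)  # (2, batch, heads, seq, head_dim)
old_reordered = old_layer_past.index_select(1, beam_idx)

# New GPT-2 layout: a (key, value) tuple per layer, batch on dim 0 of each tensor
new_layer_past = (torch.randn(2, 12, 5, 64), torch.randn(2, 12, 5, 64))
new_reordered = tuple(state.index_select(0, beam_idx) for state in new_layer_past)

assert torch.equal(old_reordered[:, 0], old_layer_past[:, 1])
assert torch.equal(new_reordered[0][0], new_layer_past[0][1])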
18 changes: 16 additions & 2 deletions src/transformers/models/ctrl/modeling_ctrl.py
@@ -15,6 +15,8 @@
# limitations under the License.
""" PyTorch CTRL model."""

from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
@@ -262,7 +264,7 @@ def _init_weights(self, module):
details.

`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as input ids as they have already been computed.
@@ -389,7 +391,7 @@ def forward(

if past_key_values is None:
past_length = 0
past_key_values = [None] * len(self.h)
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
@@ -575,6 +577,18 @@ def forward(
attentions=transformer_outputs.attentions,
)

@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
"""
54 changes: 41 additions & 13 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -17,7 +17,7 @@

import os
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple

import torch
import torch.nn as nn
@@ -232,7 +232,7 @@ def forward(
value = torch.cat((past_value, value), dim=-2)

if use_cache is True:
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
present = (key.transpose(-2, -1), value) # transpose to have same shapes
Review comment (Contributor):
This is the reason for the recent failure of the slow test:

RUN_SLOW=1 pytest tests/test_onnx.py::OnnxExportTestCase::test_export_pytorch

Can you fix the onnx part easily? @mfuntowicz @Narsil

else:
present = None
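Returning a plain tuple instead of a stacked tensor changes what callers receive for each layer's present states. A toy illustration of the before/after shapes (illustrative sizes, not the model's actual attention code; consumers that relied on the single stacked tensor, such as the ONNX export path flagged in the review comment above, need to adapt):

import torch

key = torch.randn(1, 12, 5, 64)    # (batch, heads, seq, head_dim), already transposed
value = torch.randn(1, 12, 5, 64)

present_old = torch.stack((key, value))  # one tensor, shape (2, 1, 12, 5, 64)
present_new = (key, value)               # a plain 2-tuple of tensors

print(present_old.shape)                       # torch.Size([2, 1, 12, 5, 64])
print(len(present_new), present_new[0].shape)  # 2 torch.Size([1, 12, 5, 64])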

@@ -369,9 +369,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
batch_size, num_heads, sequence_length, embed_size_per_head)`).
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
sequence_length, embed_size_per_head)`).

Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
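The documented structure can be checked directly from a forward pass. A hedged usage sketch (assumes the public transformers API; the printed sizes are for the small `gpt2` checkpoint and a two-token prompt):

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, use_cache=True)

past = outputs.past_key_values
# After this PR: a tuple of config.n_layer entries, each a (key, value) pair
print(len(past), len(past[0]), past[0][0].shape)
# Expected roughly: 12 2 torch.Size([1, 12, 2, 64])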
@@ -392,7 +392,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
mc_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None

@@ -418,7 +418,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
sequence tokens in the vocabulary.

If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
@@ -429,7 +429,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
details.

`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as ``input_ids`` as they have already been
@@ -639,7 +639,7 @@ def forward(

if past_key_values is None:
past_length = 0
past_key_values = [None] * len(self.h)
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
@@ -707,7 +707,7 @@ def forward(
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = layer_past.to(hidden_states.device)
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
@@ -717,18 +717,22 @@
all_hidden_states = all_hidden_states + (hidden_states,)

if getattr(self.config, "gradient_checkpointing", False):
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
)

def create_custom_forward(module):
def custom_forward(*inputs):
# checkpointing only works with tuple returns, not with lists
return tuple(output for output in module(*inputs, use_cache, output_attentions))
# None for past_key_value
return module(*inputs, use_cache, output_attentions)

return custom_forward

outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
layer_past,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
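With the new guard, gradient checkpointing and the generation cache are mutually exclusive, and the checkpointed path now passes None for the layer's past. A hedged usage sketch of how a training setup would be configured after this change (config fields as in the library; values illustrative):

from transformers import GPT2Config, GPT2LMHeadModel

# Cache must be off when activations are recomputed via checkpointing,
# otherwise the ValueError added above is raised on the first forward pass.
config = GPT2Config(gradient_checkpointing=True, use_cache=False)
model = GPT2LMHeadModel(config)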
@@ -931,6 +935,18 @@ def forward(
cross_attentions=transformer_outputs.cross_attentions,
)

@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
"""
@@ -1094,6 +1110,18 @@ def forward(
attentions=transformer_outputs.attentions,
)

@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)


@add_start_docstrings(
"""
9 changes: 9 additions & 0 deletions src/transformers/models/transfo_xl/modeling_transfo_xl.py
@@ -1137,6 +1137,15 @@ def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, la
self.crit.cutoff_ends = [0] + new_cutoffs
self.crit.n_token = new_num_tokens

@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.
"""
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
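Unlike the GPT-2 cache above, Transfo-XL's mems keep the batch on dim 1 (roughly (mem_len, batch_size, d_model)), so this override gathers along dim 1 rather than dim 0. A toy check with illustrative shapes:

import torch

mems = [torch.randn(16, 3, 1024) for _ in range(2)]  # 2 layers, mem_len 16, batch 3
beam_idx = torch.tensor([2, 0, 0])                    # new beam 0 descends from old beam 2

reordered = [layer_mem.index_select(1, beam_idx) for layer_mem in mems]
assert torch.equal(reordered[0][:, 0], mems[0][:, 2])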


@add_start_docstrings(
"""
9 changes: 9 additions & 0 deletions src/transformers/models/xlnet/modeling_xlnet.py
@@ -1462,6 +1462,15 @@ def forward(
attentions=transformer_outputs.attentions,
)

@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.
"""
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]


@add_start_docstrings(
"""
1 change: 1 addition & 0 deletions tests/test_modeling_gpt2.py
@@ -131,6 +131,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False):
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,