diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 6fcf9e5babe8..3f933940c69d 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -503,18 +503,10 @@ def _update_model_kwargs_for_generation(
 
         return model_kwargs
 
-    @staticmethod
-    def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
-        """
-        This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
-        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
-        called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
-        generation step.
-
-        For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
-        subclasses of :class:`~transformers.PreTrainedModel`.
-        """
-        return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
+    def _reorder_cache(self, past, beam_idx):
+        raise NotImplementedError(
+            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
+        )
 
     def _get_logits_warper(
         self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 3c8b56e69e83..84dc51d97f75 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -721,7 +721,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -913,11 +913,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 590363a124d0..72c795e62195 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -539,7 +539,14 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 4a79aa86a4ba..5d393d493421 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -680,7 +680,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -875,11 +875,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 36cff4c41713..60cb4835707f 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -682,7 +682,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -877,11 +877,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py
index 76f0402d77e8..da5009540885 100644
--- a/src/transformers/models/ctrl/modeling_ctrl.py
+++ b/src/transformers/models/ctrl/modeling_ctrl.py
@@ -15,6 +15,8 @@
 # limitations under the License.
 """ PyTorch CTRL model."""
 
+from typing import Tuple
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -262,7 +264,7 @@ def _init_weights(self, module):
             details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
             Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
             (see :obj:`past_key_values` output below). Can be used to speed up sequential decoding.
            The ``input_ids`` which have their past given to this model should not be passed as input ids as they
             have already been computed.
@@ -389,7 +391,7 @@ def forward(
 
         if past_key_values is None:
             past_length = 0
-            past_key_values = [None] * len(self.h)
+            past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
@@ -575,6 +577,18 @@ def forward(
             attentions=transformer_outputs.attentions,
         )
 
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index fcab923cd918..e91220f29b1f 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -535,7 +535,14 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index cc2f55709fc0..9bb40c0a8764 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -17,7 +17,7 @@
 
 import os
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -232,7 +232,7 @@ def forward(
             value = torch.cat((past_value, value), dim=-2)
 
         if use_cache is True:
-            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
         else:
             present = None
 
@@ -369,9 +369,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
-        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
-            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
-            batch_size, num_heads, sequence_length, embed_size_per_head)`).
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
 
             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
             :obj:`past_key_values` input) to speed up sequential decoding.
@@ -392,7 +392,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
     mc_loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     mc_logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -418,7 +418,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
             :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
-            ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
+            ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
             sequence tokens in the vocabulary.
 
             If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
@@ -429,7 +429,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             details.
 
             `What are input IDs? <../glossary.html#input-ids>`__
-        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
             :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
             have their past given to this model should not be passed as ``input_ids`` as they have already been
             computed.
@@ -639,7 +639,7 @@ def forward(
 
         if past_key_values is None:
             past_length = 0
-            past_key_values = [None] * len(self.h)
+            past_key_values = tuple([None] * len(self.h))
        else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
@@ -707,7 +707,7 @@ def forward(
                 torch.cuda.set_device(hidden_states.device)
                 # Ensure layer_past is on same device as hidden_states (might not be correct)
                 if layer_past is not None:
-                    layer_past = layer_past.to(hidden_states.device)
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                 # Ensure that attention_mask is always on the same device as hidden_states
                 if attention_mask is not None:
                     attention_mask = attention_mask.to(hidden_states.device)
@@ -716,19 +716,25 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
-                        # checkpointing only works with tuple returns, not with lists
-                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)
 
                     return custom_forward
 
                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
-                    layer_past,
+                    None,
                     attention_mask,
                     head_mask[i],
                     encoder_hidden_states,
@@ -931,6 +937,18 @@ def forward(
             cross_attentions=transformer_outputs.cross_attentions,
         )
 
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+
 
 @add_start_docstrings(
     """
@@ -1094,6 +1112,18 @@ def forward(
             attentions=transformer_outputs.attentions,
         )
 
+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index d611336fbfce..ef12bc6fa7c7 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -465,7 +465,14 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py
index 098a54818cd3..05961787044b 100755
--- a/src/transformers/models/led/modeling_led.py
+++ b/src/transformers/models/led/modeling_led.py
@@ -1694,7 +1694,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -1919,11 +1919,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py
index 8ffaab5d794b..297b9a29a6b1 100755
--- a/src/transformers/models/longformer/modeling_longformer.py
+++ b/src/transformers/models/longformer/modeling_longformer.py
@@ -1225,7 +1225,7 @@ def forward(
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 6db24068f88b..dd630c2e39a4 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -689,7 +689,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -878,11 +878,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 147fd5d7c077..3c744106befb 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -727,7 +727,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -922,11 +922,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 583cec288e48..78c3d639c49c 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -693,7 +693,7 @@ def forward(
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -886,11 +886,13 @@ def forward(
 
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index 2f648d3dc081..f7a9c97964ab 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -478,7 +478,14 @@ def forward(
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False):
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
diff --git a/src/transformers/models/transfo_xl/modeling_transfo_xl.py b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
index 3e579566f6bd..01d5cf0454fe 100644
--- a/src/transformers/models/transfo_xl/modeling_transfo_xl.py
+++ b/src/transformers/models/transfo_xl/modeling_transfo_xl.py
@@ -1137,6 +1137,15 @@ def _resize_cutoffs(self, new_num_tokens, new_emb_size, new_embedding_shapes, la
         self.crit.cutoff_ends = [0] + new_cutoffs
         self.crit.n_token = new_num_tokens
 
+    @staticmethod
+    def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
+        """
+        This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
+        :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
+        correct beam_idx at every generation step.
+        """
+        return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
+
 
 @add_start_docstrings(
     """
diff --git a/src/transformers/models/xlnet/modeling_xlnet.py b/src/transformers/models/xlnet/modeling_xlnet.py
index cf8d67695cb2..e873996d36ab 100755
--- a/src/transformers/models/xlnet/modeling_xlnet.py
+++ b/src/transformers/models/xlnet/modeling_xlnet.py
@@ -1462,6 +1462,15 @@ def forward(
             attentions=transformer_outputs.attentions,
         )
 
+    @staticmethod
+    def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
+        """
+        This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
+        :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
+        correct beam_idx at every generation step.
+ """ + return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems] + @add_start_docstrings( """ diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 5a12c9b79588..9148157b7d83 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -526,7 +526,7 @@ def forward( layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2182,7 +2182,7 @@ def forward( if self.training and (dropout_probability < self.layerdrop): # skip the layer layer_outputs = (None, None) else: - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: def create_custom_forward(module): def custom_forward(*inputs): @@ -2374,11 +2374,11 @@ def forward( past_key_value = past_key_values[idx] if past_key_values is not None else None - if getattr(self.config, "gradient_checkpointing", False): + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: - raise ValueError( - "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`." - ) + logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...") + use_cache = False def create_custom_forward(module): def custom_forward(*inputs): diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index f60fa45bcbba..bf7049296721 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -131,6 +131,7 @@ def prepare_config_and_inputs(self, gradient_checkpointing=False): n_ctx=self.max_position_embeddings, # type_vocab_size=self.type_vocab_size, # initializer_range=self.initializer_range, + use_cache=not gradient_checkpointing, bos_token_id=self.bos_token_id, eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id,