From 5c96445ab8cc3abf65fe696c69b4f9d390362aac Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Wed, 21 Apr 2021 11:56:45 +0200
Subject: [PATCH] Removed `max_length` from being mandatory within `generate`.
 (#11314)

* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
  modes.
- `max_length` is still used for `beam_search` and `group_beam_search`
  (follow-up PR).
- Fixes a bug in `MaxLengthCriteria`: we should stop as soon as we hit
  `max_length`, so the comparison needs to be `>=` rather than `>` (this
  affects the tests).
- Added options to use `logits_processor` and `stopping_criteria` directly
  within `greedy_search`, `sample`, `beam_search`, `beam_sample`, and
  `group_beam_search`, so users can define their own `logits_processor`
  and `stopping_criteria`.
- Modified the backward-compatibility tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Moving validate to being functional.

- Renamed `smax_length` to `stopping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate` arguments.

* Deepcopy.

* Fix global variable name.
---
 .../generation_stopping_criteria.py        |  41 ++++---
 src/transformers/generation_utils.py       | 108 ++++++++++++------
 tests/test_generation_stopping_criteria.py |  11 +-
 tests/test_generation_utils.py             |  44 +++----
 4 files changed, 123 insertions(+), 81 deletions(-)

diff --git a/src/transformers/generation_stopping_criteria.py b/src/transformers/generation_stopping_criteria.py
index f90a18a56ef1..ab853985240d 100644
--- a/src/transformers/generation_stopping_criteria.py
+++ b/src/transformers/generation_stopping_criteria.py
@@ -1,6 +1,7 @@
 import time
 import warnings
 from abc import ABC
+from copy import deepcopy
 from typing import Optional
 
 import torch
@@ -8,7 +9,7 @@
 from .file_utils import add_start_docstrings
 
 
-LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
+STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
@@ -33,7 +34,7 @@
 class StoppingCriteria(ABC):
     """Abstract base class for all stopping criteria that can be applied during generation."""
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
         raise NotImplementedError("StoppingCriteria needs to be subclassed")
 
@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
     def __init__(self, max_length: int):
         self.max_length = max_length
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] > self.max_length
+        return input_ids.shape[-1] >= self.max_length
 
 
 class MaxTimeCriteria(StoppingCriteria):
@@ -73,25 +74,29 @@ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
         self.max_time = max_time
         self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return time.time() - self.initial_timestamp > self.max_time
 
 
 class StoppingCriteriaList(list):
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return any(criteria(input_ids, scores) for criteria in self)
 
-
-def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
-    found = False
-    for stopping_criterium in stopping_criteria:
-        if isinstance(stopping_criterium, MaxLengthCriteria):
-            found = True
-            if stopping_criterium.max_length != max_length:
-                warnings.warn(
-                    "You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
-                )
-    if not found:
-        stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+    @property
+    def max_length(self) -> Optional[int]:
+        for stopping_criterium in self:
+            if isinstance(stopping_criterium, MaxLengthCriteria):
+                return stopping_criterium.max_length
+        return None
+
+
+def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
+    stopping_max_length = stopping_criteria.max_length
+    new_stopping_criteria = deepcopy(stopping_criteria)
+    if stopping_max_length is not None and stopping_max_length != max_length:
+        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+    elif stopping_max_length is None:
+        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+    return new_stopping_criteria

diff --git a/src/transformers/generation_utils.py b/src/transformers/generation_utils.py
index 09f00dd88720..165fc4aa1222 100644
--- a/src/transformers/generation_utils.py
+++ b/src/transformers/generation_utils.py
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
 
@@ -564,6 +565,7 @@ def _get_logits_processor(
         This class returns a :obj:`~transformers.LogitsProcessorList` list object that contains all relevant
         :obj:`~transformers.LogitsProcessor` instances used to modify the scores of the language model head.
         """
+        processors = LogitsProcessorList()
         # init warp parameters
         repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
@@ -589,7 +591,6 @@ def _get_logits_processor(
             remove_invalid_values if remove_invalid_values is not None else self.config.remove_invalid_values
         )
         # instantiate processors list
-        processors = LogitsProcessorList()
 
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
@@ -629,7 +630,6 @@ def _get_stopping_criteria(
         max_length: Optional[int],
         max_time: Optional[float],
     ) -> StoppingCriteriaList:
-
         stopping_criteria = StoppingCriteriaList()
         if max_length is not None:
             stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
@@ -859,9 +859,9 @@ def generate(
         """
 
         # set init values
+        max_length = max_length if max_length is not None else self.config.max_length
         num_beams = num_beams if num_beams is not None else self.config.num_beams
         num_beam_groups = num_beam_groups if num_beam_groups is not None else self.config.num_beam_groups
-        max_length = max_length if max_length is not None else self.config.max_length
         do_sample = do_sample if do_sample is not None else self.config.do_sample
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
@@ -958,10 +958,13 @@ def generate(
             remove_invalid_values=remove_invalid_values,
         )
 
-        stopping_criteria = self._get_stopping_criteria(
-            max_length=max_length,
-            max_time=max_time,
-        )
+        stopping_criteria = self._get_stopping_criteria(max_length=max_length, max_time=max_time)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
 
         if is_greedy_gen_mode:
             if num_return_sequences > 1:
                 raise ValueError(
@@ -974,7 +977,6 @@ def generate(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1003,7 +1005,6 @@ def generate(
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1021,9 +1022,12 @@ def generate(
             if num_return_sequences > num_beams:
                 raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
 
+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
+
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
+                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1039,7 +1043,6 @@ def generate(
                 beam_scorer,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1056,9 +1059,11 @@ def generate(
             batch_size = input_ids.shape[0] * num_return_sequences
 
             length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
             beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
+                max_length=stopping_criteria.max_length,
                 num_beams=num_beams,
                 device=self.device,
                 length_penalty=length_penalty,
@@ -1079,7 +1084,6 @@ def generate(
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1100,10 +1104,13 @@ def generate(
             if num_beams % num_beam_groups != 0:
                 raise ValueError("`num_beams` should be divisible by `num_beam_groups` for group beam search.")
 
+            if stopping_criteria.max_length is None:
+                raise ValueError("`max_length` needs to be a stopping_criteria for now.")
+
             diverse_beam_scorer = BeamSearchScorer(
                 batch_size=batch_size,
-                max_length=max_length,
                 num_beams=num_beams,
+                max_length=stopping_criteria.max_length,
                 device=self.device,
                 length_penalty=length_penalty,
                 do_early_stopping=early_stopping,
@@ -1119,7 +1126,6 @@ def generate(
                 diverse_beam_scorer,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                max_length=max_length,
                 pad_token_id=pad_token_id,
                 eos_token_id=eos_token_id,
                 output_scores=output_scores,
@@ -1160,7 +1166,8 @@ def greedy_search(
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
                 The id of the `end-of-sequence` token.
@@ -1220,8 +1227,12 @@ def greedy_search(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1251,7 +1262,7 @@ def greedy_search(
         cur_len = input_ids.shape[-1]
 
         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
 
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1384,7 +1395,8 @@ def sample(
                 :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
                 modeling head applied before multinomial sampling at each generation step.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1452,8 +1464,12 @@ def sample(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
@@ -1485,7 +1501,7 @@ def sample(
         this_peer_finished = False  # used by synced_gpus only
 
         # auto-regressive generation
-        while cur_len < max_length:
+        while True:
 
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1620,7 +1636,8 @@ def beam_search(
                 An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1700,8 +1717,14 @@ def beam_search(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+        if len(stopping_criteria) == 0:
+            warnings.warn("You don't have defined any stopping_criteria, this will likely loop forever", UserWarning)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -1740,7 +1763,7 @@ def beam_search(
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
 
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -1770,7 +1793,7 @@ def beam_search(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=max_length
+                next_token_logits, cur_len=cur_len, max_length=None
             )
 
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
@@ -1907,7 +1930,8 @@ def beam_sample(
                 :class:`~transformers.LogitsWarper` used to warp the prediction score distribution of the language
                 modeling head applied before multinomial sampling at each generation step.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -1994,7 +2018,12 @@ def beam_sample(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2028,7 +2057,7 @@ def beam_sample(
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
 
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2058,7 +2087,7 @@ def beam_sample(
             # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
             # cannot be generated both before and after the `F.log_softmax` operation.
             next_token_logits = self.adjust_logits_during_generation(
-                next_token_logits, cur_len=cur_len, max_length=max_length
+                next_token_logits, cur_len=cur_len, max_length=None
             )
 
             next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
@@ -2195,7 +2224,8 @@ def group_beam_search(
                 An instance of :class:`~transformers.StoppingCriteriaList`. List of instances of class derived from
                 :class:`~transformers.StoppingCriteria` used to tell if the generation loop should stop.
             max_length (:obj:`int`, `optional`, defaults to 20):
-                The maximum length of the sequence to be generated.
+                **DEPRECATED**. Use :obj:`logits_processor` or :obj:`stopping_criteria` directly to cap the number of
+                generated tokens. The maximum length of the sequence to be generated.
             pad_token_id (:obj:`int`, `optional`):
                 The id of the `padding` token.
             eos_token_id (:obj:`int`, `optional`):
@@ -2279,8 +2309,12 @@ def group_beam_search(
         # init values
         logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
         stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
-        max_length = max_length if max_length is not None else self.config.max_length
-        validate_stopping_criteria(stopping_criteria, max_length)
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
         pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
         eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
         output_scores = output_scores if output_scores is not None else self.config.output_scores
@@ -2324,7 +2358,7 @@ def group_beam_search(
         beam_scores = beam_scores.view((batch_size * num_beams,))
 
         this_peer_finished = False  # used by synced_gpus only
-        while cur_len < max_length:
+        while True:
 
             if synced_gpus:
                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
@@ -2378,7 +2412,7 @@ def group_beam_search(
                 # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
                 # cannot be generated both before and after the `F.log_softmax` operation.
                 next_token_logits = self.adjust_logits_during_generation(
-                    next_token_logits, cur_len=cur_len, max_length=max_length
+                    next_token_logits, cur_len=cur_len, max_length=None
                 )
 
                 next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * group_size, vocab_size)

diff --git a/tests/test_generation_stopping_criteria.py b/tests/test_generation_stopping_criteria.py
index 7cbdbce1425a..995ea97736e0 100644
--- a/tests/test_generation_stopping_criteria.py
+++ b/tests/test_generation_stopping_criteria.py
@@ -40,10 +40,10 @@ def test_list_criteria(self):
 
         self.assertFalse(criteria(input_ids, scores))
 
-        input_ids, scores = self._get_tensors(10)
+        input_ids, scores = self._get_tensors(9)
         self.assertFalse(criteria(input_ids, scores))
 
-        input_ids, scores = self._get_tensors(11)
+        input_ids, scores = self._get_tensors(10)
         self.assertTrue(criteria(input_ids, scores))
 
     def test_max_length_criteria(self):
@@ -52,10 +52,10 @@ def test_max_length_criteria(self):
         input_ids, scores = self._get_tensors(5)
         self.assertFalse(criteria(input_ids, scores))
 
-        input_ids, scores = self._get_tensors(10)
+        input_ids, scores = self._get_tensors(9)
         self.assertFalse(criteria(input_ids, scores))
 
-        input_ids, scores = self._get_tensors(11)
+        input_ids, scores = self._get_tensors(10)
         self.assertTrue(criteria(input_ids, scores))
 
     def test_max_time_criteria(self):
@@ -73,7 +73,6 @@ def test_validate_stopping_criteria(self):
         with self.assertWarns(UserWarning):
             validate_stopping_criteria(StoppingCriteriaList([MaxLengthCriteria(10)]), 11)
 
-        stopping_criteria = StoppingCriteriaList()
-        validate_stopping_criteria(stopping_criteria, 11)
+        stopping_criteria = validate_stopping_criteria(StoppingCriteriaList(), 11)
 
         self.assertEqual(len(stopping_criteria), 1)

diff --git a/tests/test_generation_utils.py b/tests/test_generation_utils.py
index 6b84a42e07fb..42c44b8c54e8 100644
--- a/tests/test_generation_utils.py
+++ b/tests/test_generation_utils.py
@@ -1358,13 +1358,14 @@ def test_max_length_backward_compat_greedy(self):
             bos_token_id=bart_model.config.bos_token_id,
         )
 
-        bart_model.greedy_search(
-            input_ids,
-            max_length=max_length,
-            pad_token_id=bart_model.config.pad_token_id,
-            eos_token_id=bart_model.config.eos_token_id,
-            **model_kwargs,
-        )
+        with self.assertWarns(UserWarning):
+            bart_model.greedy_search(
+                input_ids,
+                max_length=max_length,
+                pad_token_id=bart_model.config.pad_token_id,
+                eos_token_id=bart_model.config.eos_token_id,
+                **model_kwargs,
+            )
 
     def test_max_length_backward_compat_sample(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1381,13 +1382,14 @@ def test_max_length_backward_compat_sample(self):
             bos_token_id=bart_model.config.bos_token_id,
         )
         with torch.no_grad():
-            bart_model.sample(
-                input_ids,
-                max_length=max_length,
-                pad_token_id=bart_model.config.pad_token_id,
-                eos_token_id=bart_model.config.eos_token_id,
-                **model_kwargs,
-            )
+            with self.assertWarns(UserWarning):
+                bart_model.sample(
+                    input_ids,
+                    max_length=max_length,
+                    pad_token_id=bart_model.config.pad_token_id,
+                    eos_token_id=bart_model.config.eos_token_id,
+                    **model_kwargs,
+                )
 
     def test_max_length_backward_compat_beam_search(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1413,9 +1415,10 @@ def test_max_length_backward_compat_beam_search(self):
             num_beams=num_beams,
             device=torch_device,
         )
-        _ = bart_model.beam_search(
-            input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
-        )
+        with self.assertWarns(UserWarning):
+            _ = bart_model.beam_search(
+                input_ids, num_beams=num_beams, max_length=max_length, beam_scorer=beam_scorer, **model_kwargs
+            )
 
     def test_max_length_backward_compat_group_beam_search(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
@@ -1445,9 +1448,10 @@ def test_max_length_backward_compat_group_beam_search(self):
             num_beam_hyps_to_keep=num_return_sequences,
             num_beam_groups=num_beam_groups,
        )
-        bart_model.group_beam_search(
-            input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
-        )
+        with self.assertWarns(UserWarning):
+            bart_model.group_beam_search(
+                input_ids, diverse_beam_scorer, num_beams=num_beams, max_length=max_length, **model_kwargs
+            )
 
     def test_max_length_warning_if_different(self):
         article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
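
For reference, a minimal sketch of the calling convention this patch enables: instead of a mandatory `max_length`, the length cap is supplied as a `MaxLengthCriteria` inside a `StoppingCriteriaList`. This sketch is not part of the patch; the "gpt2" checkpoint, the prompt, and the pad-token choice are illustrative assumptions, and any causal LM should behave similarly.

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation_stopping_criteria import MaxLengthCriteria, StoppingCriteriaList

    # Assumed setup: any causal LM checkpoint works; "gpt2" is only an example.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("Justin Timberlake and Jessica Biel", return_tensors="pt").input_ids

    # After this patch, `greedy_search` takes its length cap from the criteria
    # list; passing `max_length` directly still works but emits a UserWarning.
    # MaxLengthCriteria now stops once input_ids reaches max_length (>=, per the fix).
    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
    output_ids = model.greedy_search(
        input_ids,
        stopping_criteria=stopping_criteria,
        pad_token_id=model.config.eos_token_id,  # GPT-2 has no pad token; reuse EOS
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note that `validate_stopping_criteria` is now functional: it deep-copies the list and returns a new one rather than mutating the caller's criteria in place, which is why the tests assign its return value.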