Removed max_length from being mandatory within generate. (#11314)
* Removed `max_length` from being mandatory within `generate`.

- Moving on to fully using `StoppingCriteria` for `greedy` and `sample`
modes.
- `max_length` is still used for `beam_search` and `group_beam_search`
(follow-up PR).
- Fixes a bug with `MaxLengthCriteria` (we should stop as soon as we hit
`max_length`; the comparison needs to be greater than or equal, which
affects the tests).
- Added options to use `logits_processor` and `stopping_criteria`
directly within `generate` function (so some users can define their own
`logits_processor` and `stopping_criteria`).
- Modified the backward compat tests to make sure we issue a warning.

* Fix `max_length` argument in `generate`.

* Made `validate_stopping_criteria` functional: it now returns a new list
instead of mutating its argument.

- Renamed `smax_length` to `stopping_max_length`.

* Removing `logits_processor` and `stopping_criteria` from `generate`
arguments.

* Deepcopy the stopping criteria in `validate_stopping_criteria` so the
caller's list is not mutated.

* Fix global variable name.
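
For orientation, here is a minimal sketch of how the criteria touched by this commit compose once it lands. The import path mirrors the changed file below; the tensor shapes and vocabulary size are illustrative assumptions, and `scores` is ignored by both criteria used here.

import torch

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    MaxTimeCriteria,
    StoppingCriteriaList,
)

# Generation stops as soon as any criterion in the list returns True.
criteria = StoppingCriteriaList(
    [MaxLengthCriteria(max_length=20), MaxTimeCriteria(max_time=5.0)]
)

input_ids = torch.zeros((1, 20), dtype=torch.long)  # (batch_size, sequence_length)
scores = torch.zeros((1, 50257))  # vocab-sized logits, unused by these criteria

# With the `>=` fix below, reaching exactly max_length now stops generation.
assert criteria(input_ids, scores)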
Narsil authored and Rocketknight1 committed Apr 21, 2021
1 parent f6bbd3a commit 5c96445
Showing 4 changed files with 123 additions and 81 deletions.
41 changes: 23 additions & 18 deletions src/transformers/generation_stopping_criteria.py
@@ -1,14 +1,15 @@
 import time
 import warnings
 from abc import ABC
+from copy import deepcopy
 from typing import Optional
 
 import torch
 
 from .file_utils import add_start_docstrings
 
 
-LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
+STOPPING_CRITERIA_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary.
@@ -33,7 +34,7 @@
 class StoppingCriteria(ABC):
     """Abstract base class for all stopping criteria that can be applied during generation."""
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, score: torch.FloatTensor, **kwargs) -> bool:
         raise NotImplementedError("StoppingCriteria needs to be subclassed")
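
Since `__call__` here must be overridden, custom criteria are plain subclasses. A hypothetical sketch (the `EosTokenCriteria` name and logic are ours, not part of this commit):

import torch

from transformers.generation_stopping_criteria import StoppingCriteria


class EosTokenCriteria(StoppingCriteria):
    """Hypothetical criterion: stop once every sequence has emitted an EOS token."""

    def __init__(self, eos_token_id: int):
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # True only when each row of the batch contains eos_token_id somewhere.
        return bool((input_ids == self.eos_token_id).any(dim=-1).all())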

@@ -51,9 +52,9 @@ class MaxLengthCriteria(StoppingCriteria):
     def __init__(self, max_length: int):
         self.max_length = max_length
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        return input_ids.shape[-1] > self.max_length
+        return input_ids.shape[-1] >= self.max_length
 
 
 class MaxTimeCriteria(StoppingCriteria):
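
The `>` to `>=` change above is the off-by-one fix from the commit message: stopping should trigger as soon as the sequence reaches `max_length`, not one token later. A boundary check, with the same illustrative shapes as before:

import torch

from transformers.generation_stopping_criteria import MaxLengthCriteria

criterion = MaxLengthCriteria(max_length=10)
scores = torch.zeros((1, 50257))  # ignored by this criterion

# 9 tokens: keep generating; 10 tokens: stop now (previously `>` let an
# 11th token through before stopping).
assert not criterion(torch.zeros((1, 9), dtype=torch.long), scores)
assert criterion(torch.zeros((1, 10), dtype=torch.long), scores)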
@@ -73,25 +74,29 @@ def __init__(self, max_time: float, initial_timestamp: Optional[float] = None):
         self.max_time = max_time
         self.initial_timestamp = time.time() if initial_timestamp is None else initial_timestamp
 
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return time.time() - self.initial_timestamp > self.max_time
 
 
 class StoppingCriteriaList(list):
-    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
+    @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
         return any(criteria(input_ids, scores) for criteria in self)
 
-
-def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int):
-    found = False
-    for stopping_criterium in stopping_criteria:
-        if isinstance(stopping_criterium, MaxLengthCriteria):
-            found = True
-            if stopping_criterium.max_length != max_length:
-                warnings.warn(
-                    "You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning
-                )
-    if not found:
-        stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+    @property
+    def max_length(self) -> Optional[int]:
+        for stopping_criterium in self:
+            if isinstance(stopping_criterium, MaxLengthCriteria):
+                return stopping_criterium.max_length
+        return None
+
+
+def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
+    stopping_max_length = stopping_criteria.max_length
+    new_stopping_criteria = deepcopy(stopping_criteria)
+    if stopping_max_length is not None and stopping_max_length != max_length:
+        warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+    elif stopping_max_length is None:
+        new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+    return new_stopping_criteria
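
Taken together, the new `max_length` property and the now-functional `validate_stopping_criteria` handle three cases: no length criterion (one is appended to a deep copy), a matching one (the copy is returned as-is), and a conflicting one (a `UserWarning` is issued and the criterion's own value wins). A sketch under the same import-path assumption:

import warnings

from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
    validate_stopping_criteria,
)

# No MaxLengthCriteria: one is appended to a copy; the input is untouched.
criteria = StoppingCriteriaList()
validated = validate_stopping_criteria(criteria, max_length=20)
assert len(criteria) == 0 and validated.max_length == 20

# Conflicting max_length: a UserWarning fires, the criterion's value stays.
criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=30)])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validated = validate_stopping_criteria(criteria, max_length=20)
assert validated.max_length == 30 and len(caught) == 1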