
Adding a min_steps parameter to BeamSearch #5207

Merged: 9 commits, May 17, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
You can do this by setting the parameter `load_weights` to `False`.
See [PR #5172](https://github.com/allenai/allennlp/pull/5172) for more details.
- Added `SpanExtractorWithSpanWidthEmbedding`, putting specific span embedding computations into the `_embed_spans` method and leaving the common code in `SpanExtractorWithSpanWidthEmbedding` to unify the arguments, and modified `BidirectionalEndpointSpanExtractor`, `EndpointSpanExtractor` and `SelfAttentiveSpanExtractor` accordingly. Now, `SelfAttentiveSpanExtractor` can also embed span widths.
- Added a `min_steps` parameter to `BeamSearch` to set a minimum length for the predicted sequences.

### Fixed

21 changes: 21 additions & 0 deletions allennlp/nn/beam_search.py
@@ -462,6 +462,10 @@ class BeamSearch(FromParams):

Using the [`GumbelSampler`](#gumbelsampler), on the other hand, will give you
[Stochastic Beam Search](https://api.semanticscholar.org/CorpusID:76662039).

min_steps : `int`, optional (default = `0`)
The minimum number of decoding steps to take, i.e. the minimum length of
the predicted sequences. This does not include the start or end tokens.
"""

def __init__(
@@ -471,19 +475,25 @@ def __init__(
beam_size: int = 10,
per_node_beam_size: int = None,
sampler: Sampler = None,
min_steps: int = 0,
Review comment (Member):

I think making this an Optional[int] is more clear.

Suggested change:
-    min_steps: int = 0,
+    min_steps: Optional[int] = None,

Reply (Contributor, author):

Ok, I pushed this change. I had to change the code to make sure the value is valid, too.
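
The follow-up commit is not shown in this view; as a rough sketch (abbreviated, with the other parameters omitted), the Optional[int] signature and its validation could look like this:

from typing import Optional

def __init__(
    self,
    end_index: int,
    max_steps: int = 50,
    beam_size: int = 10,
    min_steps: Optional[int] = None,
) -> None:
    if max_steps <= 0:
        raise ValueError("max_steps must be positive")
    if beam_size <= 0:
        raise ValueError("beam_size must be positive")
    # Only validate min_steps when the caller actually supplied a value.
    if min_steps is not None:
        if min_steps < 0:
            raise ValueError("min_steps must be non-negative")
        if min_steps > max_steps:
            raise ValueError("min_steps must be less than or equal to max_steps")
    self._end_index = end_index
    self.max_steps = max_steps
    self.beam_size = beam_size
    # Fall back to 0 internally so the masking logic can compare plain ints.
    self.min_steps = min_steps or 0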

) -> None:
if not max_steps > 0:
raise ValueError("max_steps must be positive")
if not beam_size > 0:
raise ValueError("beam_size must be positive")
if per_node_beam_size is not None and not per_node_beam_size > 0:
raise ValueError("per_node_beam_size must be positive")
if not min_steps >= 0:
raise ValueError("min_steps must be non-negative")
if not min_steps <= max_steps:
raise ValueError("min_steps must be less than or equal to max_steps")

self._end_index = end_index
self.max_steps = max_steps
self.beam_size = beam_size
self.per_node_beam_size = per_node_beam_size or beam_size
self.sampler = sampler or DeterministicSampler()
self.min_steps = min_steps

@staticmethod
def _reconstruct_sequences(predictions, backpointers):
@@ -629,6 +639,10 @@ def _search(
start_class_log_probabilities, batch_size, num_classes
)

# Prevent selecting the end symbol if there is a min_steps constraint
if self.min_steps >= 1:
start_class_log_probabilities[:, self._end_index] = float("-inf")

# Get the initial predicted classes and their log probabilities.
# shape: (batch_size, beam_size), (batch_size, beam_size)
(
@@ -675,6 +689,13 @@
# shape: (batch_size * beam_size, num_classes)
class_log_probabilities, state = step(last_predictions, state, timestep + 1)

# The `timestep`-th iteration of the for loop is generating the `timestep + 2`-th token
# of the sequence (because `timestep` is 0-indexed and we generated the first token
# before the for loop). Here we block the end index if the search is not allowed to
# terminate on this iteration.
if timestep + 2 <= self.min_steps:
class_log_probabilities[:, self._end_index] = float("-inf")

# shape: (batch_size * beam_size, num_classes)
last_predictions_expanded = last_predictions.unsqueeze(-1).expand(
batch_size * self.beam_size, num_classes
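To make the new parameter concrete, here is a minimal usage sketch. The vocabulary, end index, and toy step function below are hypothetical illustrations, not part of this PR:

import torch

from allennlp.nn.beam_search import BeamSearch

# Hypothetical 6-token vocabulary in which index 5 is the end token.
end_index = 5
beam_search = BeamSearch(end_index, max_steps=10, beam_size=3, min_steps=2)

def step(last_predictions, state, timestep):
    # Toy step function that heavily favors the end token; without
    # min_steps, the search would terminate almost immediately.
    probs = torch.tensor([0.02, 0.02, 0.02, 0.02, 0.02, 0.9])
    return probs.log().repeat(last_predictions.size(0), 1), state

start_predictions = torch.zeros(2, dtype=torch.long)  # batch of 2
top_k, log_probs = beam_search.search(start_predictions, {}, step)
# top_k has shape (batch_size, beam_size, num_decoding_steps); each beam
# entry contains at least two non-end tokens, because the end index is
# masked until min_steps decoding steps have been taken.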
95 changes: 94 additions & 1 deletion tests/nn/beam_search_test.py
@@ -27,6 +27,18 @@
] # end token -> jth token
)

# A transition matrix that favors shorter sequences over longer ones
short_sequence_transition_probabilities = torch.tensor(
[
[0.0, 0.1, 0.0, 0.0, 0.0, 0.9], # start token -> jth token
[0.0, 0.0, 0.1, 0.0, 0.0, 0.9], # 1st token -> jth token
[0.0, 0.0, 0.0, 0.1, 0.0, 0.9], # 2nd token -> jth token
[0.0, 0.0, 0.0, 0.0, 0.1, 0.9], # ...
[0.0, 0.0, 0.0, 0.0, 0.0, 1.0], # ...
[0.2, 0.1, 0.2, 0.2, 0.2, 0.3],
] # end token -> jth token
)
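# Under this matrix, the highest-probability sequences end as early as possible:
# P([5]) = 0.9, P([1, 5]) = 0.1 * 0.9 = 0.09, P([1, 2, 5]) = 0.1 * 0.1 * 0.9 = 0.009,
# P([1, 2, 3, 5]) = 0.1 ** 3 * 0.9 = 0.0009, and P([1, 2, 3, 4, 5]) = 0.1 ** 4 * 1.0 = 0.0001.
# These products are exactly the expected log-prob values asserted in the tests below.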

log_probabilities = torch.log(
torch.tensor([[0.1, 0.3, 0.3, 0.3, 0.0, 0.0], [0.0, 0.0, 0.4, 0.3, 0.2, 0.1]])
)
@@ -62,6 +74,25 @@ def take_step_with_timestep(
return take_step_no_timestep(last_predictions, state)


def take_short_sequence_step(
last_predictions: torch.Tensor,
state: Dict[str, torch.Tensor],
timestep: int,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Take a decoding step.

This method is the same as `take_step_no_timestep` except that it uses the
`short_sequence_transition_probabilities` transitions instead of `transition_probabilities`.
"""
log_probs_list = []
for last_token in last_predictions:
log_probs = torch.log(short_sequence_transition_probabilities[last_token.item()])
log_probs_list.append(log_probs)

return torch.stack(log_probs_list), state


class BeamSearchTest(AllenNlpTestCase):
def setup_method(self):
super().setup_method()
@@ -101,7 +132,7 @@ def _check_results(

# log_probs should be shape `(batch_size, beam_size)`.
assert list(log_probs.size()) == [batch_size, beam_size]
- np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs)
+ np.testing.assert_allclose(log_probs[0].numpy(), expected_log_probs, rtol=1e-6)

@pytest.mark.parametrize("step_function", [take_step_with_timestep, take_step_no_timestep])
def test_search(self, step_function):
@@ -211,6 +242,68 @@ def test_early_stopping(self):
beam_search=beam_search,
)

def test_take_short_sequence_step(self):
"""
Tests to ensure the top-k results from the `short_sequence_transition_probabilities`
transition matrix are as expected.
"""
self.beam_search.beam_size = 5
expected_top_k = np.array(
[[5, 5, 5, 5, 5], [1, 5, 5, 5, 5], [1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]]
)
expected_log_probs = np.log(np.array([0.9, 0.09, 0.009, 0.0009, 0.0001]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_short_sequence_step,
)

def test_min_steps(self):
"""
Tests to ensure all output sequences are no shorter than a specified minimum length.
It uses the `take_short_sequence_step` step function, which favors shorter sequences.
See `test_take_short_sequence_step`.
"""
self.beam_search.beam_size = 1

# An empty sequence is allowed under this step function
self.beam_search.min_steps = 0
expected_top_k = np.array([[5]])
expected_log_probs = np.log(np.array([0.9]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_short_sequence_step,
)

self.beam_search.min_steps = 1
expected_top_k = np.array([[1, 5]])
expected_log_probs = np.log(np.array([0.09]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_short_sequence_step,
)

self.beam_search.min_steps = 2
expected_top_k = np.array([[1, 2, 5]])
expected_log_probs = np.log(np.array([0.009]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_short_sequence_step,
)

self.beam_search.beam_size = 3
self.beam_search.min_steps = 2
expected_top_k = np.array([[1, 2, 5, 5, 5], [1, 2, 3, 5, 5], [1, 2, 3, 4, 5]])
expected_log_probs = np.log(np.array([0.009, 0.0009, 0.0001]))
self._check_results(
expected_top_k=expected_top_k,
expected_log_probs=expected_log_probs,
take_step=take_short_sequence_step,
)

def test_different_per_node_beam_size(self):
# per_node_beam_size = 1
beam_search = BeamSearch(self.end_index, beam_size=3, per_node_beam_size=1)