This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Add sampling strategies to beam search #4768

Merged

epwalsh merged 20 commits into master from beam-search-sampler-try on Nov 11, 2020

Conversation

@epwalsh epwalsh (Member) commented Nov 3, 2020

TODO (@jvstokes)

  • implement multinomial sampler
  • implement top-k sampler
  • implement top-p sampler
  • test multinomial sampler
  • test top-k sampler
  • test top-p sampler
  • test Gumbel sampler, especially with edge cases like when the end token is predicted
  • update CHANGELOG
  • finalize documentation

@epwalsh epwalsh changed the title Add sampler strategies to beam search Add sampling strategies to beam search Nov 3, 2020
@epwalsh epwalsh (Member Author) left a comment

Almost there 😬! My suggestions are mostly cosmetic. I also think I found a way to avoid the loop within top-p.

@epwalsh epwalsh marked this pull request as ready for review November 10, 2020 18:01
Comment on lines 248 to 254
# Create a mask for filtering out probabilities that don't make the top `p`.
# shape: (batch_size, num_classes)
exclusion_mask = probabilities_summed >= self.p

# We want to include the firt index where probabilities_summes >= p, so we shift over one.
exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
exclusion_mask[..., 0] = False
@epwalsh (Member Author)

How about this:

Suggested change

- # Create a mask for filtering out probabilities that don't make the top `p`.
- # shape: (batch_size, num_classes)
- exclusion_mask = probabilities_summed >= self.p
- # We want to include the firt index where probabilities_summes >= p, so we shift over one.
- exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
- exclusion_mask[..., 0] = False
+ # Create a mask for filtering out probabilities that don't make the top `p`.
+ # shape: (batch_size, num_classes)
+ exclusion_mask = probabilities_summed > self.p
+ # Make sure there's at least `per_node_beam_size` options.
+ exclusion_mask[..., :per_node_beam_size] = False

Contributor

I think we actually need both. If we don't shift the mask, then we only keep tokens whose cumulative sum is less than p, i.e.

probs = [0.5, 0.31, 0.19], p = 0.8
cumsum = [0.5, 0.81, 1.0]
mask: [False, True, True]
softmax: [1.0, 0.0, 0.0]
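
Here's a minimal PyTorch sketch reproducing that trace (variable names mirror the snippet under review; the standalone setup is just for illustration):

import torch

probs = torch.tensor([0.5, 0.31, 0.19])
p = 0.8

# Cumulative sum of the (already sorted) probabilities.
probabilities_summed = torch.cumsum(probs, dim=-1)  # [0.50, 0.81, 1.00]

# Without the shift, the index that pushes the sum past `p` is excluded too.
exclusion_mask = probabilities_summed >= p  # [False, True, True]

# After masking and renormalizing, only the first token survives, even
# though top-p should keep enough tokens to cover at least `p` of the mass.
filtered = probs.masked_fill(exclusion_mask, 0.0)
print(filtered / filtered.sum())  # tensor([1., 0., 0.])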

@epwalsh epwalsh (Member Author) Nov 10, 2020

Good point. How about this then:

Suggested change

- # Create a mask for filtering out probabilities that don't make the top `p`.
- # shape: (batch_size, num_classes)
- exclusion_mask = probabilities_summed >= self.p
- # We want to include the firt index where probabilities_summes >= p, so we shift over one.
- exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
- exclusion_mask[..., 0] = False
+ # Create a mask for filtering out probabilities that don't make the top `p`.
+ # shape: (batch_size, num_classes)
+ exclusion_mask = probabilities_summed >= self.p
+ # We want to include the first index where probabilities_summed >= p, so we shift over one.
+ exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
+ # We also want to make sure we have at least `per_node_beam_size` options.
+ exclusion_mask[..., :per_node_beam_size] = False

(only change is the last line)
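
For context, here is a self-contained sketch of how that final mask might slot into a full top-p sampler. The surrounding structure (sorting, renormalizing, torch.multinomial) is an assumption for illustration, not the PR's exact code; `logits` is assumed to have shape (batch_size, num_classes):

import torch
import torch.nn.functional as F

def top_p_sample(logits: torch.Tensor, p: float, per_node_beam_size: int,
                 with_replacement: bool = False) -> torch.Tensor:
    # shape: (batch_size, num_classes)
    probabilities = F.softmax(logits, dim=-1)

    # Sort descending so the cumulative sum covers the most probable tokens first.
    sorted_probs, sorted_indices = torch.sort(probabilities, dim=-1, descending=True)
    probabilities_summed = torch.cumsum(sorted_probs, dim=-1)

    # Create a mask for filtering out probabilities that don't make the top `p`.
    exclusion_mask = probabilities_summed >= p
    # Include the first index where probabilities_summed >= p by shifting over one.
    exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
    # Make sure there are at least `per_node_beam_size` options to sample from.
    exclusion_mask[..., :per_node_beam_size] = False

    # Zero out excluded tokens, renormalize, and sample.
    filtered = sorted_probs.masked_fill(exclusion_mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(filtered, per_node_beam_size, replacement=with_replacement)

    # Map positions in the sorted tensor back to original vocabulary indices.
    return sorted_indices.gather(-1, sampled)

For example, top_p_sample(logits, p=0.9, per_node_beam_size=5) would return 5 distinct token ids per batch row, since the mask guarantees at least 5 tokens survive filtering.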

Contributor

Do we want to allow the client to sample with replacement from < per_node_beam_size options?

Contributor

i.e. maybe they sample 10 examples from the top 8 cumulative probabilities

@epwalsh (Member Author)

Possibly. So maybe fill the first per_node_beam_size entries with False only if with_replacement is False, and otherwise only fill the first entry with False.
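
For example, a minimal sketch of that conditional, assuming the sampler grows a with_replacement flag:

if with_replacement:
    # With replacement, a single valid option is enough to draw from.
    exclusion_mask[..., 0] = False
else:
    # Without replacement we need at least `per_node_beam_size` distinct options.
    exclusion_mask[..., :per_node_beam_size] = False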

Contributor

If you don't think we should build that functionality in, though, I can clean it up!

@epwalsh (Member Author)

I think it's okay to have that.

Contributor

Oops, sorry, I forgot to refresh and missed your comment! And yes, that's the change I ended up making, so it should be all good.

@epwalsh epwalsh merged commit 9f7cc24 into master Nov 11, 2020
@epwalsh epwalsh deleted the beam-search-sampler-try branch November 11, 2020 00:00