Add sampling strategies to beam search #4768
Conversation
…d tests for those samplers and stochastic_beam_search
Almost there 😬! My suggestions are mostly cosmetic. I also think I found a way to avoid the loop within top-p.
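(For context on the snippets below: they all concern the top-p, i.e. nucleus, filtering step of the new sampler. The sketch here is illustrative only and not the PR's code; the helper name `top_p_filter` and its body are assumptions, and it presumes `probabilities_summed` is the cumulative sum of probabilities sorted in descending order.)

```python
import torch

def top_p_filter(log_probs: torch.Tensor, p: float, per_node_beam_size: int):
    """Illustrative nucleus filtering: keep the smallest set of tokens whose
    cumulative probability reaches ``p``, renormalize, and sample."""
    # Sort descending so the cumulative sum runs over the most likely tokens first.
    sorted_log_probs, sorted_indices = log_probs.sort(dim=-1, descending=True)
    probabilities = sorted_log_probs.exp()
    # shape: (batch_size, num_classes)
    probabilities_summed = probabilities.cumsum(dim=-1)

    # The mask discussed in this thread: drop everything outside the top-p nucleus.
    exclusion_mask = probabilities_summed >= p
    exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
    exclusion_mask[..., 0] = False

    # Renormalize over the kept tokens and sample `per_node_beam_size` of them.
    filtered_log_probs = sorted_log_probs.masked_fill(exclusion_mask, float("-inf"))
    filtered_probs = torch.nn.functional.softmax(filtered_log_probs, dim=-1)
    sampled_positions = torch.multinomial(filtered_probs, per_node_beam_size)
    # Map sampled positions back to the original vocabulary indices.
    return (filtered_log_probs.gather(-1, sampled_positions),
            sorted_indices.gather(-1, sampled_positions))
```

The review below is about how `exclusion_mask` is built.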
…up documentation and sampler code.
…ochastic-search-sampler
allennlp/nn/beam_search.py (Outdated)
# Create a mask for filtering out probabilities that don't make the top `p`.
# shape: (batch_size, num_classes)
exclusion_mask = probabilities_summed >= self.p

# We want to include the firt index where probabilities_summes >= p, so we shift over one.
exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
exclusion_mask[..., 0] = False
How about this:

# Create a mask for filtering out probabilities that don't make the top `p`.
# shape: (batch_size, num_classes)
exclusion_mask = probabilities_summed > self.p
# Make sure there's at least `per_node_beam_size` options.
exclusion_mask[..., :per_node_beam_size] = False
I think we actually need both. If we don't shift the mask, then only the indices where the cumulative sum is strictly less than p survive, so the kept mass can fall short of p, e.g.

probs = [0.5, 0.31, 0.19], p = 0.8
cumsum = [0.5, 0.81, 1.0]
mask: [False, True, True]
softmax: [1.0, 0.0, 0.0]
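A standalone check of that example with the shift applied (illustrative snippet; it renormalizes the raw probabilities rather than using the masked softmax in the actual code):

```python
import torch

probs = torch.tensor([[0.5, 0.31, 0.19]])
cumsum = probs.cumsum(dim=-1)           # [[0.50, 0.81, 1.00]]

unshifted = cumsum >= 0.8               # [[False, True, True]] -> keeps only 0.5
shifted = unshifted.clone()
shifted[..., 1:] = unshifted[..., :-1]  # include the first index crossing p
shifted[..., 0] = False                 # [[False, False, True]]

kept = probs.masked_fill(shifted, 0.0)
print(kept / kept.sum(dim=-1, keepdim=True))  # tensor([[0.6173, 0.3827, 0.0000]])
```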
Good point. How about this then:

# Create a mask for filtering out probabilities that don't make the top `p`.
# shape: (batch_size, num_classes)
exclusion_mask = probabilities_summed >= self.p
# We want to include the first index where probabilities_summed >= p, so we shift over one.
exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
# We also want to make sure we have at least `per_node_beam_size` options.
exclusion_mask[..., :per_node_beam_size] = False

(only change is the last line)
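A quick sanity check of the combined version with made-up numbers (standalone snippet, not the PR's code): even when the nucleus is a single token, the new last line keeps `per_node_beam_size` candidates available.

```python
import torch

p, per_node_beam_size = 0.8, 3
probs = torch.tensor([[0.9, 0.05, 0.03, 0.02]])
probabilities_summed = probs.cumsum(dim=-1)   # [[0.90, 0.95, 0.98, 1.00]]

exclusion_mask = probabilities_summed >= p    # [[True, True, True, True]]
exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
# The earlier version only guaranteed index 0; the floor below keeps the first
# `per_node_beam_size` tokens available even though the nucleus here is one token.
exclusion_mask[..., :per_node_beam_size] = False
print(exclusion_mask)                         # [[False, False, False, True]]
```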
Do we want to allow the client to sample with replacement from < per_node_beam_size options?
i.e. maybe they sample the 10 examples from the top 8 cumulative probabilities
Possibly. So maybe only fill up to `per_node_beam_size` as `False` if `with_replacement` is `False`, otherwise only fill the first as `False`.
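A sketch of that conditional, assuming the sampler exposes a `with_replacement` flag (the helper and its signature below are hypothetical, not the PR's code):

```python
import torch

def build_exclusion_mask(probabilities_summed: torch.Tensor,
                         p: float,
                         per_node_beam_size: int,
                         with_replacement: bool) -> torch.Tensor:
    # Hypothetical helper illustrating the conditional being discussed.
    exclusion_mask = probabilities_summed >= p
    # Keep the first index where the cumulative sum crosses p.
    exclusion_mask[..., 1:] = exclusion_mask[..., :-1].clone()
    if with_replacement:
        # Sampling with replacement can draw `per_node_beam_size` items from a
        # nucleus of any size, so only the top token must be guaranteed.
        exclusion_mask[..., 0] = False
    else:
        # Without replacement we need at least `per_node_beam_size` distinct options.
        exclusion_mask[..., :per_node_beam_size] = False
    return exclusion_mask
```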
If you don't think we should build that functionality in, though, I can clean it up!
I think it's okay to have that.
Oops, sorry, I forgot to refresh and missed your comment! And yes, that's the change I ended up making, so it should be all good.
…chastic-search-sampler
…allennlp into stochastic-search-sampler
TODO (@jvstokes)