Implement TopP, TopK and Beam samplers #652
Conversation
Thanks @chenmoneygithub, looking great! The biggest thing is unifying our docstring style.
keras_nlp/samplers/beam_sampler.py (outdated):

    Examples:
    ```python
    BATCH_SIZE = 8
@chenmoneygithub I really don't think these arg blocks match the rest of the library's style. I don't think there's a perfect answer but I'd prefer it not to be obvious who wrote what 🥗
It's a totally valid opinion @chenmoneygithub, but it's also important for the code to have a unified style rather than each contributor producing different-looking code. If there's an example that's unclear without named params I'm open to trying something different, but otherwise I'm hoping we can compromise!
Taking steps towards a more unified style (and then reflecting that in our style guide) sgtm. What are the main places this differs, besides the constants at the top?
Basically, what I asked for in #658 reflects the current thinking. A bit less script-like and a little more "drop this line in colab and see what we're talking about".
I could see us replacing the "model" with something like `tf.random.uniform(shape, minval=-1, maxval=1)`. It is kind of weird to me that we show a whole model that is trainable but randomly initialized (so results will be random anyway), and not even sequence-aware, so it would never really perform even if you trained it. For a new user this seems a bit of a red herring.
Would be more concise to do something like:
```python
def token_probability_fn(inputs, mask):
    return tf.random.uniform(...)  # Replace with a real model!
```
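For a fully runnable variant of that suggestion, something like the sketch below could work (`VOCAB_SIZE` and the output shape here are assumptions for illustration, not the library's actual contract):
```python
import tensorflow as tf

VOCAB_SIZE = 10  # assumed vocabulary size, for illustration only

def token_probability_fn(inputs, mask):
    # Random values stand in for a real model's next-token probabilities.
    batch_size = tf.shape(inputs)[0]
    return tf.random.uniform([batch_size, VOCAB_SIZE], minval=0, maxval=1)
```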
I would like to keep the model part so that the example is closer to real use cases.
I will unify the docstrings and move those hyperparameters inline.
To me the model falls into an "uncanny valley" of code examples. It's not something that will actually work, yet it's also not clearly random dummy data. As a newbie I worry I would not understand, first, that results will be random, and second, that this is a "bad model" for the task.
Fine to merge as is, but I hope we can play around with some improvements down the road.
    ):
        if run_eagerly and jit_compile:
Do we need two flags or just one then? What happens if they are both False? And what is a "non-XLA" graph? I'm confused!
This is following the style of `model.compile()`: code link.
A non-XLA graph is quite common: anything annotated with `tf.function` without `jit_compile=True` is a non-XLA graph.
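As a rough illustration of the three execution modes being discussed (a sketch with a made-up `dense_step` function, not the sampler's actual dispatch logic):
```python
import tensorflow as tf

def dense_step(x):
    # Stand-in computation; any tensor-in, tensor-out function works here.
    return tf.nn.relu(tf.matmul(x, tf.ones((4, 4))))

eager_fn = dense_step                               # run_eagerly=True: plain Python
graph_fn = tf.function(dense_step)                  # non-XLA graph: traced, not compiled
xla_fn = tf.function(dense_step, jit_compile=True)  # XLA-compiled graph

x = tf.random.uniform((2, 4))
print(eager_fn(x).shape, graph_fn(x).shape, xla_fn(x).shape)
```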
Yes, agree, was just trying to understand the difference
Thanks Chen! Left some comments and a few questions.
keras_nlp/samplers/beam_sampler.py (outdated):

        beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
    )

    prompt = tf.squeeze(max_beams, axis=1)
`return` immediately here instead of assigning to `prompt` first.
keras_nlp/samplers/top_p_sampler.py (outdated):

    )
    if from_logits:
        pred = keras.activations.softmax(pred, axis=-1)
    # Sort preds in descending order.
The main change I could see us making, if we want to, is splitting the overload points into `sample` and `sample_step` (I think there was some discussion of this on the other PR?).
Overload `sample` if you want to control the whole process (e.g. beam search). Overload `sample_step` if you want to control simply going from one probability distribution to one sample. E.g. the body of this class could look like:
```python
def sample_step(self, preds):
    sorted_preds, sorted_indices = tf.math.top_k(
        preds, k=tf.shape(preds)[1], sorted=True
    )
    # Calculate the cumulative probability distribution.
    cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
    # Create a mask for the tokens to keep.
    keep_mask = cumulative_probs <= self.p
    # Shift to include the last token that exceeds p.
    shifted_keep_mask = tf.concat(
        [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
    )
    # Zero out tokens outside the keep mask and sample from the
    # filtered distribution.
    probs = tf.where(
        shifted_keep_mask,
        sorted_preds,
        tf.zeros(tf.shape(preds), dtype=sorted_preds.dtype),
    )
    sorted_next_token = tf.random.categorical(
        tf.math.log(probs), 1, seed=self.seed
    )
    return tf.gather_nd(
        sorted_indices, sorted_next_token, batch_dims=1
    )
```
This might improve the readability of our simple samplers, while still keeping full extensibility.
As long as we can make `BeamSampler` an outlier, I am down for this refactoring!
Actually, we can move one step further and only leave `get_next_token` open, which takes in a probability distribution and returns the next token; the rest of the updating logic can be shared across those samplers.
Reflected this change in the PR.
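As an illustration of what that `get_next_token` overload point might look like, here is a hedged sketch using top-k sampling (the class body and attribute names are assumptions, not the PR's actual code):
```python
import tensorflow as tf

class TopKSampler:
    """Sketch of a sampler that only overrides `get_next_token`."""

    def __init__(self, k=5, seed=None):
        self.k = k
        self.seed = seed

    def get_next_token(self, next_token_probs):
        # Keep only the k most probable tokens.
        top_k_probs, top_k_indices = tf.math.top_k(
            next_token_probs, k=self.k, sorted=False
        )
        # Sample from the renormalized top-k distribution.
        sampled = tf.random.categorical(
            tf.math.log(top_k_probs), 1, seed=self.seed
        )
        # Map the sampled positions back to vocabulary token ids.
        return tf.gather_nd(top_k_indices, sampled, batch_dims=1)
```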
Yes, this is what I was hoping for! 🚀
Force-pushed from ce07a4d to 64ff158.
Just some nits. Looking good!
    Examples:
    ```python
    VOCAB_SIZE = 10
Let's inline this arg as well. Or am I missing something?
This one is used twice (embedding and dense), and the arg name does not suggest "vocab_size", so I am keeping this one for clarity.
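For context, the two uses being referred to presumably look something like this (a sketch; the actual layer configuration in the example may differ):
```python
from tensorflow import keras

VOCAB_SIZE = 10
model = keras.Sequential([
    keras.layers.Embedding(VOCAB_SIZE, 16),  # first use: input vocabulary
    keras.layers.Dense(VOCAB_SIZE),          # second use: output logits
])
```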
    @@ -252,20 +247,103 @@ def __call__(

            return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

        @format_docstring(sample_args=sample_args_docstring)
        def get_next_token(self, next_token_probs):
Do you think sliding window can fit into this paradigm?
yea, it should work
keras_nlp/samplers/greedy_sampler.py (outdated):

    sampler = keras_nlp.samplers.GreedySampler()
    # Print the generated sequence (token ids).
    print(sampler(prompt, token_probability_fn, 10))
For clarity you could use named args:
```python
print(sampler(prompt, token_probability_fn, max_length=10))
```
This matches, or perhaps improves on, the readability of using globals.
good call, done
Force-pushed from 64ff158 to 950ad43.
Left a few more comments re: `get_config`. +Approval, as I am fine merging after changes to get something landed. But I hope we will stay open to changes here as we dig more into the generative case!
Wait for @fchollet, he wants to take a look and can make the final call on class names.
Short class names seem to fit in the Keras ecosystem:
```python
sampler=keras_nlp.samplers.Beam(num_beams=3),
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=keras.metrics.SparseCategoricalAccuracy(),
kernel_initializer=initializers.RandomNormal(stddev=0.01),
activation=activations.relu,
```
Let's merge this one to unblock the TFLite sprint; I will sync with Francois offline regarding the string identifier.
    def get_next_token(self, next_token_probs):
        # Beam search overrides the whole `sample` method.
        pass
This should raise an error, I suppose?
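One way to address that (a sketch, not necessarily the PR's final code):
```python
def get_next_token(self, next_token_probs):
    # Beam search overrides the whole `sample` method, so this hook
    # should never be reached.
    raise NotImplementedError(
        "`BeamSampler` overrides `sample()` directly and does not use "
        "`get_next_token()`."
    )
```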
A few things covered:
- Added a `run_eagerly` option to our sampler to align with `model.compile()`.

One special thing to note: in the `batch_size=1` case, after the first iteration the shape of `beam_probs` changes from `[1, None]` to `[None, None]`, so we add `shape_invariants` specifically for the beam sampler.
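As a minimal illustration of the shape issue (a sketch under assumed shapes and names, not the PR's actual loop):
```python
import tensorflow as tf

num_beams = 3
beam_probs = tf.zeros([1, num_beams])  # batch_size=1 on the first iteration

def cond(i, beam_probs):
    return i < 5

def body(i, beam_probs):
    # After the first step the leading dimension grows, so the loop
    # variable's static shape cannot stay pinned to [1, None].
    return i + 1, tf.random.uniform([num_beams, num_beams])

i, beam_probs = tf.while_loop(
    cond,
    body,
    [0, beam_probs],
    # Relax the invariant so both dimensions may vary across iterations.
    shape_invariants=[tf.TensorShape([]), tf.TensorShape([None, None])],
)
```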