
Implement TopP, TopK and Beam samplers #652

Merged

Conversation

chenmoneygithub (Contributor):

A few things covered:

  • Implement TopP, TopK and Beam samplers
  • Add "Sampler" suffix to our sampler class names.
  • Add run_eagerly option to our sampler to align with model.compile().

One special thing to note:

  • In the Beam sampler, there is a strange error in the batch_size=1 case: after the first iteration, the shape of beam_probs changes from [1, None] to [None, None], so we add shape_invariants specifically for the beam sampler (a minimal illustration of the workaround is shown below).
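
To illustrate the workaround, here is a minimal, hypothetical sketch (not the PR's actual decoding loop) of how `shape_invariants` relaxes `tf.while_loop`'s shape check when a loop variable changes shape between iterations:

```python
import tensorflow as tf

# Hypothetical stand-in for the decoding loop: `beam_probs` starts with a
# leading dimension of 1 but may change size across iterations, so its shape
# invariant is declared as [None, None] instead of the initial [1, None].
beam_probs = tf.zeros([1, 5])
i = tf.constant(0)

i, beam_probs = tf.while_loop(
    cond=lambda i, beam_probs: i < 3,
    body=lambda i, beam_probs: [i + 1, tf.concat([beam_probs, beam_probs], 0)],
    loop_vars=[i, beam_probs],
    shape_invariants=[i.get_shape(), tf.TensorShape([None, None])],
)
print(beam_probs.shape)  # the leading dimension grew from 1 to 8
```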

@jbischof (Contributor) left a comment:

Thanks @chenmoneygithub, looking great! The biggest thing is unifying our docstring style.

keras_nlp/samplers/beam_sampler.py

Examples:
```python
BATCH_SIZE = 8
# ...
```
Contributor:

@chenmoneygithub I really don't think these arg blocks match the rest of the library's style. I don't think there's a perfect answer but I'd prefer it not to be obvious who wrote what 🥗

Contributor Author:

Thanks! I checked our code base; it appears we have a few pieces of code using this style:

  • MaskedLMHead: link
  • PositionEmbedding: link
  • SinePositionEmbedding: link
  • Generation utils: link

I feel like giving these numbers an explicit meaning makes the code longer but kinda easier to understand.

Contributor:

It's a totally valid opinion @chenmoneygithub but it's also important for the code to have a unified style rather than each contributor producing different looking code. If there's an example that's unclear without named params I'm open to trying something different but otherwise I'm hoping we can compromise!

Member:

Taking steps towards a more unified style (and then reflecting that in our style guide) sgtm. What are the main places this differs, besides the constants at the top?

Contributor:

Basically what I asked for in #658 was the current thought. A bit less script-like and a little more "drop this line in colab and see what we're talking about".

Member:

I could see us replacing the "model" with something like tf.random.uniform(shape, minval=-1, maxval=1). It is kind of weird to me that we show a whole model that is trainable but randomly initialized (so results will be random anyway), and not even sequence aware, so it would never really perform even if you trained it. For a new user this seems a bit of a red herring.

Would be more concise to do something like:

    def token_probability_fn(inputs, mask):
        return tf.random.uniform(...) # Replace with a real model!
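
For reference, a fuller hypothetical version of such a dummy function could look like the sketch below; it assumes the sampler expects per-position probabilities of shape [batch_size, seq_length, vocab_size], and VOCAB_SIZE is just an illustrative constant:

```python
import tensorflow as tf

VOCAB_SIZE = 10  # assumed value for this illustration

def token_probability_fn(inputs, mask):
    # Dummy stand-in for a real language model: a uniform random
    # distribution over the vocabulary for every position in the prompt.
    batch_size = tf.shape(inputs)[0]
    seq_length = tf.shape(inputs)[1]
    logits = tf.random.uniform(
        (batch_size, seq_length, VOCAB_SIZE), minval=-1, maxval=1
    )
    return tf.nn.softmax(logits, axis=-1)
```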

Contributor Author:

I would like to keep the model part so that the example is closer to real use cases.

I will unify the docstring to move those hypers inline.

Member:

To me the model falls into an "uncanny valley" of code examples. It's not something that will actually work, yet also not clearly random dummy data. As a newbie I worry I would not understand, first, that results will be random, and second, that this is a "bad model" for the task.

Fine to merge as is, but I hope we can play around with some improvements down the road.

):
if run_eagerly and jit_compile:
Contributor:

Do we need two flags or just one then? What happens if they are both False? What is a "non-XLA" graph? I'm confused!

Contributor Author:

This is following the style of model.compile(): code link

A non-XLA graph is quite common: anything annotated with tf.function without jit_compile=True runs as a non-XLA graph.
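
As a quick illustration of the three modes (plain TensorFlow, not code from this PR):

```python
import tensorflow as tf

# Eager execution: runs op by op, easiest to debug.
def eager_fn(x):
    return x * 2

# Non-XLA graph: traced into a TensorFlow graph, but not XLA-compiled.
@tf.function
def graph_fn(x):
    return x * 2

# XLA graph: traced and compiled with XLA.
@tf.function(jit_compile=True)
def xla_fn(x):
    return x * 2

x = tf.constant([1.0, 2.0])
print(eager_fn(x), graph_fn(x), xla_fn(x))
```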

Contributor:

Yes, agree, was just trying to understand the difference

@mattdangerw (Member) left a comment:

Thanks Chen! Left some comments and a few questions.

keras_nlp/samplers/__init__.py (outdated)
beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
)

prompt = tf.squeeze(max_beams, axis=1)
Member:

return immediately

keras_nlp/samplers/top_p_sampler.py
keras_nlp/samplers/top_p_sampler.py (outdated)
keras_nlp/samplers/top_p_sampler.py (outdated)
)
if from_logits:
    pred = keras.activations.softmax(pred, axis=-1)
# Sort preds in descending order.
Member:

The main change I could see us making, if we want to, is splitting the overload points into sample and sample_step (I think there was some discussion of this on another PR?).

Overload sample if you want to control the whole process (e.g. beam search). Overload sample_step if you simply want to control going from one probability distribution to one sample. E.g. the body of this class could look like:

    def sample_step(self, preds):
        sorted_preds, sorted_indices = tf.math.top_k(
            preds, k=tf.shape(preds)[1], sorted=True
        )
        # Calculate cumulative probability distribution.
        cumulative_probs = tf.math.cumsum(sorted_preds, axis=-1)
        # Create a mask for the tokens to keep.
        keep_mask = cumulative_probs <= self.p
        # Shift to include the last token that exceeds p.
        shifted_keep_mask = tf.concat(
            [tf.ones_like(keep_mask[:, :1]), keep_mask[:, :-1]], axis=-1
        )
        # Filter out tokens outside the keep mask and sample from the
        # filtered distribution.
        probs = tf.where(
            shifted_keep_mask,
            sorted_preds,
            tf.zeros(tf.shape(preds), dtype=sorted_preds.dtype),
        )
        sorted_next_token = tf.random.categorical(
            tf.math.log(probs), 1, seed=self.seed
        )
        return tf.gather_nd(
            sorted_indices, sorted_next_token, batch_dims=1
        )

This might improve the readability of our simple samplers, while still keeping full extensibility.

Contributor Author:

As long as we can make BeamSampler an outlier, I am down for this refactoring!

Actually we can move one step further and only leave get_next_token open, which takes in a probability distribution and returns the next token; the rest of the updating logic can be shared across those samplers.

Reflected this change in the PR.
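
As a rough, hypothetical sketch of that overload point (a simplified toy class, not the actual keras_nlp base sampler), a subclass would only decide how to pick a token from a distribution:

```python
import tensorflow as tf

class GreedySamplerSketch:
    """Toy stand-in for a sampler that only overrides `get_next_token`."""

    def get_next_token(self, next_token_probs):
        # next_token_probs: [batch_size, vocab_size]; pick the argmax token.
        return tf.argmax(next_token_probs, axis=-1, output_type=tf.int32)

probs = tf.constant([[0.1, 0.7, 0.2]])
print(GreedySamplerSketch().get_next_token(probs))  # -> [1]
```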

Contributor:

Yes, this is what I was hoping for! 🚀

keras_nlp/samplers/top_p_sampler.py (outdated)

Examples:
```python
BATCH_SIZE = 8
# ...
```
Member:

I could see us replacing the "model" with something like tf.random.uniform(shape, minval=-1, maxval=1). It is kind of weird to me that we show a whole model that is trainable but randomly initialized (so results will be random anyway), and not even sequence aware, so it would never really perform even if you trained it. For a new user this seems a bit of a red herring.

Would be more concise to do something like:

    def token_probability_fn(inputs, mask):
        return tf.random.uniform(...) # Replace with a real model!

@jbischof (Contributor) left a comment:

Just some nits. Looking good!


Examples:
```python
VOCAB_SIZE = 10
# ...
```
Contributor:

Let's inline this arg as well. Or am I missing something?

Contributor Author:

This one is used twice (embedding and dense), and the arg name does not suggest "vocab_size", so I am keeping this one for clarity.
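
For context, a minimal hypothetical sketch of the pattern in question, where the same constant feeds both the embedding input dimension and the final projection width:

```python
from tensorflow import keras

VOCAB_SIZE = 10  # assumed value for illustration

model = keras.Sequential([
    keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=16),
    keras.layers.Dense(VOCAB_SIZE),
])
```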

@@ -252,20 +247,103 @@ def __call__(

return tf.squeeze(prompt, axis=0) if input_is_1d else prompt

@format_docstring(sample_args=sample_args_docstring)
def get_next_token(self, next_token_probs):
Contributor:

Do you think sliding window can fit into this paradigm?

Contributor Author:

yea, it should work


sampler = keras_nlp.samplers.GreedySampler()
# Print the generated sequence (token ids).
print(sampler(prompt, token_probability_fn, 10))
Contributor:

For clarity you could use named args:

print(sampler(prompt, token_probability_fn, max_length=10))

This matches or perhaps enhances the readability of using globals.

Contributor Author:

good call, done

@jbischof jbischof requested a review from fchollet January 18, 2023 19:54
@mattdangerw (Member) left a comment:

Left a few more comments re get_config. +approval, as I am fine to merge after changes to get something landed. But hope we will stay open to changes here as we dig more into the generative case!

keras_nlp/samplers/top_k_sampler.py
keras_nlp/samplers/top_p_sampler.py
keras_nlp/samplers/top_k_sampler.py
keras_nlp/samplers/beam_sampler.py
@jbischof (Contributor):

Wait for @fchollet, he wants to take a look and can make the final call on class names.

@jbischof (Contributor) commented Jan 18, 2023:

Short class names seem to fit in the Keras ecosystem:

    sampler=keras_nlp.samplers.Beam(num_beams=3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=keras.metrics.SparseCategoricalAccuracy(),
    kernel_initializer=initializers.RandomNormal(stddev=0.01),
    activation=activations.relu,

@chenmoneygithub (Contributor Author):

Let's merge this one to unblock the TFLite sprint; I will sync with Francois offline regarding the string identifier.

@chenmoneygithub chenmoneygithub merged commit 612bbf5 into keras-team:master Jan 18, 2023
@chenmoneygithub chenmoneygithub removed the request for review from fchollet January 18, 2023 20:49

def get_next_token(self, next_token_probs):
    # Beam search overrides the whole `sample` method.
    pass
Collaborator:

This should raise an error, I suppose?
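
A minimal sketch of what that suggestion could look like (hypothetical, not the merged code):

```python
def get_next_token(self, next_token_probs):
    # BeamSampler overrides `sample` directly, so this hook should never run;
    # fail loudly instead of silently doing nothing.
    raise NotImplementedError(
        "`BeamSampler` overrides `sample` and does not use `get_next_token`."
    )
```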

@chenmoneygithub chenmoneygithub deleted the text-generation-extend branch February 2, 2023 00:29