added top k search util #232

Merged: 6 commits into keras-team:master on Jun 24, 2022

Conversation

jessechancy (Contributor):

Follow up from random search

@mattdangerw (Member) left a comment:

Looks good! Just a few comments

token_probability_fn,
prompt,
max_length,
k=10,
Member:

10 is fine, but how did we come up with this number?

Contributor:

I think this number should be related to the vocab size, so maybe let's make it a required arg?

Contributor Author:

Do you mean vocab_size should be another argument, or that k should be a required arg?

Member:

Talked on chat. Let's make k required (no default). No need to take in vocab_size, that can continue to be inferred.
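For reference, a minimal sketch of what the agreed-upon signature might look like (argument names are taken from the snippets in this review; the exact ordering and defaults in the merged code may differ):

def top_k_search(
    token_probability_fn,
    prompt,
    max_length,
    k,
    seed=None,
    end_token_id=None,
    pad_token_id=0,
):
    ...

With no default, callers have to pick k explicitly, while the vocabulary size continues to be inferred from the last axis of the predictions returned by token_probability_fn.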

pad_token_id=0,
):
"""
Text generation utility based on top k sampling.
Member:

"top-k", here and elsewhere in the docstring.

Contributor Author:

edited

input_is_1d = prompt.shape.rank == 1
if input_is_1d:
    prompt = prompt[tf.newaxis, :]
i = prompt.shape[1]
Member:

I think we are doing both shape[1] and shape[-1] to read the last axis? This reads as confusing; let's choose one and be consistent. Unless there is a case where these are different?

Contributor Author:

Edited to use shape[1] in both places. The output should strictly be [batch_size, vocab_size] for the pred and [batch_size, length] for the prompt.
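A small illustration of the shape convention described above, with a hypothetical toy prompt just to make the axes concrete:

import tensorflow as tf

# prompt is [batch_size, length] and pred is [batch_size, vocab_size], so for
# these rank-2 tensors shape[1] and shape[-1] refer to the same axis.
prompt = tf.constant([[1, 5, 9], [2, 4, 6]])  # shape (2, 3)
assert prompt.shape[1] == prompt.shape[-1] == 3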


Args:
    token_probability_fn: a callable, which takes in input_sequence
        and outputs the probability distribution of the next token.
Member:

Here and elsewhere in this file, we probably should mention that this function should return the unnormalized logits and not softmax probabilities, right?

Contributor:

This is actually one confusing part where we should make a decision: currently there is no such enforcement on the return type, shall we add it?

Contributor Author:

The return type should be probabilities here for this to work; if not, I would need to add a softmax over it.

# If k is greater than the vocabulary size, use the entire vocabulary.
k = min(k, pred.shape[-1])
# Filter out top k tokens.
sorted_pred, sorted_indices = tf.math.top_k(pred, k=k, sorted=True)
Member:

Why do we need sorted here? tf.random.categorical doesn't need a sort order. We just need to make sure we gather the correct indices from the top_k call, which you are already doing.

Contributor Author:

Yep doesn't need sorted, edited
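A rough, self-contained sketch of the unsorted version (toy values; in the utility itself pred comes from token_probability_fn and k from the caller):

import tensorflow as tf

# Hypothetical probabilities for a batch of 2 over a vocabulary of 5 tokens.
pred = tf.constant([[0.1, 0.4, 0.2, 0.2, 0.1],
                    [0.5, 0.1, 0.1, 0.2, 0.1]])
k = 2
# Keep only the top-k probabilities and their vocabulary indices; no sort needed.
top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
# tf.random.categorical samples from unnormalized log-probabilities.
sampled = tf.random.categorical(tf.math.log(top_k_pred), num_samples=1)
# Map the sampled positions back to the original vocabulary ids.
next_token = tf.gather(top_k_indices, sampled, batch_dims=1)  # shape (2, 1)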

@chenmoneygithub (Contributor) left a comment:

Mainly looks good!


outputs = top_k_search(
    token_probability_fn, inputs, k=2, max_length=max_length, seed=42
)
# Random sampling result with seed 42
Contributor:

top-k search result

Contributor Author:

Edited

rtol=0.2,
)

def test_assert_top_k_generation_is_correct(self):
Contributor:

The test is to assert only top-k tokens can appear, but the name does not suggest so. Let's rename to something like test_only_choose_from_top_k_tokens

Contributor Author:

edited
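For illustration, a hedged sketch of what the renamed test could look like (the toy distribution and exact assertion are assumptions, not the merged test): with all probability mass on a known set of two tokens and k=2, every sampled token must come from that set.

def test_only_choose_from_top_k_tokens(self):
    # Toy distribution: only tokens 0 and 1 carry probability mass.
    def token_probability_fn(inputs):
        batch_size = tf.shape(inputs)[0]
        prob = tf.constant([[0.6, 0.4, 0.0, 0.0]])
        return tf.repeat(prob, batch_size, axis=0)

    prompt = tf.constant([[0], [1]])
    outputs = top_k_search(
        token_probability_fn, prompt, k=2, max_length=8
    )
    # Every generated token must be in the top-2 set {0, 1}.
    self.assertTrue(tf.reduce_all(outputs < 2))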

@mattdangerw (Member) left a comment:

Couple comments!


i = prompt.shape[1]
while i < max_length:
    # If the prompt has reached our desired length, exit while loop.
    pred = token_fn(prompt)  # changed from: pred = token_probability_fn(prompt)
Member:

Do we need to wrap in another function? This whole change is a little hard to read. I would find it simpler to just add a block here.

pred = token_probability_fn(prompt)
if from_logits:
    pred = tf.keras.activations.softmax(pred)

Contributor Author:

makes sense, edited

def token_probability_fn(inputs):
    return model(inputs)[:, -1, :]

prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)
Member:

This is kind of a weird prompt to show. Who would want to generate sequences after 5 totally random tokens?

Maybe we should do something like

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID=1
END_ID=2

...

prompt = tf.fill((BATCH_SIZE, 1), START_ID)

keras_nlp.utils.top_k_search(
        token_probability_fn,
        prompt,
        k=10,
        max_length=10,
        end_token_id=END_ID)

We may want to update other examples if they have the same problem.

Member:

Also, to be clear, the ellipsis is me being lazy, not suggesting we put that in the docstring :)

Contributor Author:

edited docstring
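Filled out, the updated docstring example might look roughly like this (the toy model definition is an assumption wrapped around the snippet suggested above):

import tensorflow as tf
import keras_nlp

BATCH_SIZE = 8
VOCAB_SIZE = 10
FEATURE_SIZE = 16
START_ID = 1
END_ID = 2

# Assumed toy model mapping token ids to a next-token probability distribution.
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(VOCAB_SIZE, FEATURE_SIZE),
    tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
])

def token_probability_fn(inputs):
    return model(inputs)[:, -1, :]

prompt = tf.fill((BATCH_SIZE, 1), START_ID)

keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    k=10,
    max_length=10,
    end_token_id=END_ID,
)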

@mattdangerw (Member) left a comment:

Thanks! Last few nits!


# Print the generated sequence (token ids).
keras_nlp.utils.greedy_search(
    token_probability_fn,
    prompt,
    max_length=10,
    end_token_id=0,)     # removed line
    end_token_id=END_ID  # added line
Member:

nit: trailing comma

token_probability_fn,
prompt,
max_length=10,
end_token_id=0,)     # removed line
end_token_id=END_ID  # added line
Member:

trailing comma

prompt,
max_length=10,
k=4,
end_token_id=END_ID
Member:

trailing comma

prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
    append generated tokens.
max_length: int. The max length of generated text.
from_logits: bool. Indicates whether `token_probability_fn` outputs
Member:

You document this but didn't actually add it? Should we just leave it off, since for greedy search it is a no-op?

Contributor Author:

Yeah, I added it at first and realised greedy search isn't affected; will remove it from the docstring.
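A quick illustration of why from_logits is a no-op for greedy search: softmax is monotonic, so the argmax over the logits and over the corresponding probabilities selects the same token.

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
probs = tf.keras.activations.softmax(logits)
# Both reduce to token 0 here; normalization never changes the argmax.
tf.debugging.assert_equal(tf.argmax(logits, axis=-1), tf.argmax(probs, axis=-1))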

@chenmoneygithub merged commit 31674a1 into keras-team:master on Jun 24, 2022