
Greedy text generation util #154

Merged (9 commits) on May 19, 2022

Conversation

@chenmoneygithub changed the title from "initial commit for greedy text generation util" to "Greedy text generation util" on Apr 29, 2022
@mattdangerw (Member) left a comment:

Thanks! Left some initial comments, mainly on the design front.

@chenmoneygithub force-pushed the text-generation branch 2 times, most recently from e0b878a to f2b503a on May 4, 2022 at 21:15
keras_nlp/utils/text_generation.py
next_token, end_token_received, end_token_id
)
# Append the next token to current sequence.
input_ids = tf.concat([input_ids, next_token[:, tf.newaxis]], axis=-1)
Contributor:

How about testing with XLA on GPU at least, if you cannot test on TPU? :)

Contributor Author:

That's a good point.

  1. I am not sure how to add a GitHub GPU/TPU test. It also reminds me that we might want to test distributed training in the future (not for this utility, but for other modules). We will look into it.
  2. IIUC, XLA on GPU has to be turned on manually? It requires @tf.function(jit_compile=True), as in the sketch below. For this specific utility, we probably don't need XLA testing?
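As an aside, a minimal sketch of what opting in would look like (illustrative only; this toy token_probability_fn is a stand-in for a real model call, not code from this PR):

import tensorflow as tf

VOCAB_SIZE = 10

@tf.function(jit_compile=True)  # Explicitly requests XLA compilation.
def token_probability_fn(inputs):
    # Toy stand-in for a model call: a uniform next-token distribution.
    batch_size = tf.shape(inputs)[0]
    return tf.nn.softmax(tf.zeros([batch_size, VOCAB_SIZE]), axis=-1)

probs = token_probability_fn(tf.constant([[1, 2, 3]]))  # shape (1, 10)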

Contributor:

I am not sure if you realize that the tf.concat will yield a dynamic shape, and whether you know the XLA requirements.

Contributor Author:

Interesting - so tf.concat is something that cannot work on TPU or on GPU/CPU under XLA?

Our current plan is not to wrap this utility in tf.function(); instead, users can choose to wrap next_token_fn in tf.function(), since next_token_fn takes most of the computation time.
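To make the trade-off concrete, here is a rough sketch (illustrative only, not the utility's actual code) of wrapping only the model call in tf.function while a plain Python loop grows the sequence; the growing shape produced by tf.concat is exactly what blocks compiling the whole loop:

import tensorflow as tf

VOCAB_SIZE = 10

def token_probability_fn(inputs):
    # Toy model call: uniform next-token distribution.
    return tf.nn.softmax(tf.zeros([tf.shape(inputs)[0], VOCAB_SIZE]), axis=-1)

compiled_fn = tf.function(token_probability_fn)  # Only the model call is compiled.

prompt = tf.constant([[1], [1]], dtype=tf.int32)  # Shape (batch=2, length=1).
for _ in range(5):  # Plain Python loop around the compiled model call.
    probs = compiled_fn(prompt)
    next_token = tf.argmax(probs, axis=-1, output_type=tf.int32)
    # The sequence grows by one token each step, so `prompt` has a new shape
    # every iteration (and compiled_fn retraces); this dynamic shape is what
    # keeps the outer loop from being compiled as a whole.
    prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)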

Contributor:

Will the performance be bad, or do you only target this as a demo util function? Is that the reason you use a Python while loop rather than tf.while_loop for the decoding loop?

Contributor Author:

Got it - yes, this util is mainly useful for demos like a colab guide, so performance is not the focus right now.

I am actually curious: how does model garden handle the token concatenation? Are you using a fixed-size tensor and changing the values at each iteration? I am not sure how much performance difference it would make, since the bottleneck was mostly the model call when I benchmarked on colab.

Contributor:

We need to allocate a buffer of the max sequence length to decode into and use in-place updates.
The padded_decode path is dedicated to XLA: https://github.com/tensorflow/models/blob/master/official/nlp/modeling/ops/beam_search.py#L109
BTW, I read the fairseq code before, and I guess they use something similar for GPU performance optimization as well.
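A minimal sketch of that fixed-buffer idea (an illustration of the approach only, not the model-garden or KerasNLP implementation; the helper name is made up): allocate the output at max_length up front and write each new token in place, so every shape in the loop stays static:

import tensorflow as tf

def greedy_decode_padded(token_probability_fn, prompt, max_length, pad_token_id=0):
    # prompt: dense int tensor of shape (batch, prompt_length).
    prompt = tf.cast(prompt, tf.int32)
    batch_size, prompt_length = prompt.shape[0], prompt.shape[1]
    # Preallocate the whole output buffer; its shape never changes afterwards.
    padding = tf.fill((batch_size, max_length - prompt_length), pad_token_id)
    outputs = tf.concat([prompt, padding], axis=-1)

    def body(index, outputs):
        # The model always sees the fixed-length (padded) sequence.
        probs = token_probability_fn(outputs)
        next_token = tf.argmax(probs, axis=-1, output_type=tf.int32)
        # "In-place" update of column `index` via a one-hot mask; a real
        # implementation might use tf.tensor_scatter_nd_update instead.
        column = tf.one_hot(index, max_length, dtype=tf.int32)
        outputs = outputs * (1 - column) + next_token[:, tf.newaxis] * column
        return index + 1, outputs

    _, outputs = tf.while_loop(
        lambda index, outputs: index < max_length,
        body,
        (tf.constant(prompt_length), outputs),
    )
    return outputs

Because every tensor keeps a static shape, a loop like this can in principle be wrapped in tf.function(jit_compile=True).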

Contributor Author:

Thanks Hongkun! I will open an issue to track the refactoring.

Member:

Yeah, this isn't just not XLA-compilable; it is not tf.function-compilable at all right now.

We definitely should look at making it so, both for usability (using this inside a Keras model) and for performance in any sort of bulk-inference-like job.

For now, this has been more about getting the API signature how we want it. I wouldn't say this should always be only a demo util function; that's just where we are at today.

from keras_nlp.utils.text_generation import generate_text_greedy


class TextGenerationTest(tf.test.TestCase):
Contributor:

Interesting - do you need tf.test.main()?

Contributor Author:

Seems not - our tests are based on pytest, which automatically discovers the test cases.
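For context, a minimal hypothetical example: pytest discovers tf.test.TestCase subclasses through the usual test_* naming convention, so no tf.test.main() entry point is needed when the suite is run with pytest:

# Run with: pytest keras_nlp/utils/text_generation_test.py
# No `if __name__ == "__main__": tf.test.main()` block is required.
import tensorflow as tf


class TextGenerationTest(tf.test.TestCase):
    def test_placeholder(self):
        # Hypothetical check; the real tests exercise the generation utility.
        self.assertAllEqual(tf.constant([1, 2]), tf.constant([1, 2]))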

@mattdangerw (Member) left a comment:

Left some more comments. But the big one: I couldn't figure out a way to actually do text generation with a plain-text seed and our tokenizers. Take a look:

https://colab.sandbox.google.com/gist/mattdangerw/34ec3d54511d74f6a9b2ca4bdb9e22b8/text-generation-scratch.ipynb

I think if we want to support batches of input, we need to support batches of raggeds; otherwise I'm at a loss as to how this could be used for batched text generation where the seed input is plain text.

keras_nlp/utils/text_generation.py
```

"""
if 0 in input_ids.shape:
Member:

Looked at this more, and I really think we should not keep this check. It seems fully possible to support an empty input tensor, input_ids=[], or a batched shape with input_ids=tf.zeros([bs, 0]).

The former would be really useful in guides when doing something really simple. There will not always be a start-token convention; see the main TF text generation guide as an example: https://www.tensorflow.org/text/tutorials/text_generation

We should also call tf.convert_to_tensor on non-tensor input, so we can do things like input_ids=[] or input_ids=[start_id].
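A small sketch of the input normalization being suggested (illustrative only; normalize_prompt is a made-up helper, not code from this PR):

import tensorflow as tf

def normalize_prompt(input_ids):
    # Accept Python lists (including the empty list) as well as tensors.
    return tf.convert_to_tensor(input_ids, dtype=tf.int32)

print(normalize_prompt([]))                          # shape (0,): no start-token convention
print(normalize_prompt([1]))                         # shape (1,): a single start token
print(normalize_prompt(tf.zeros([4, 0], tf.int32)))  # shape (4, 0): batched, empty prompts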

Contributor Author:

One thing I am confused about: if there is no prompt at all, how shall we generate the next token? Maybe add an extra argument start_token?

Args:
token_probability_fn: a callable, which takes in input_sequence
and output the probability distribution of the next token.
input_ids: a list, the initial tokens to append generated tokens.
Collaborator:

prompt?

Contributor Author:

sg!

sequence. If None, every sequence is generated up to `max_length`.

Returns:
A 1D int Tensor, or 2D int RaggedTensor representing the generated
Collaborator:

Most likely it should operate on a single sequence, never a batch. The user could map it over a batch.

Contributor Author:

Processing at the single-sequence level would make the code simpler, but execution would be slower: there are lots of model calls inside this utility, so without batch parallelism it could take much longer.
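For illustration of that trade-off (the names and the toy decoder below are made up): a single-sequence routine could be mapped over a batch, but then every element pays for its own chain of model calls, whereas a batched implementation shares one model call per decoding step:

import tensorflow as tf

MAX_LENGTH = 6

def generate_one(sequence):
    # Hypothetical single-sequence greedy loop (toy: always appends token 7).
    while sequence.shape[0] < MAX_LENGTH:
        sequence = tf.concat([sequence, tf.constant([7])], axis=0)
    return sequence

prompts = tf.constant([[1, 2], [3, 4]])
# Mapping the single-sequence routine over the batch (tf.map_fn would be the
# graph-friendly equivalent): each element runs its own decoding loop.
outputs = tf.stack([generate_one(p) for p in tf.unstack(prompts)])
print(outputs.shape)  # (2, 6)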

return filtered_next_token, end_token_received


def generate_text_greedy(
Collaborator:

This operates in integer token space, not text space. It should not be named "generate text" (also "generate text greedy" is not a sentence, unlike "greedy text generation" or "generate text greedily").

Maybe this should just be called greedy_search (vs beam_search)

Contributor Author:

sg!

and output the probability distribution of the next token.
input_ids: a list, the initial tokens to append generated tokens.
max_length: int. The max length of generated text.
end_token_id: int, defaults to None. The token marking the end of the
Collaborator:

end_token?

@mattdangerw (Member), May 11, 2022:

We haven't really formalized this (I should add it to our API design guide). But I think if this operates strictly on ints, end_token_id is the correct name, so people don't pass "<eos>" by mistake.

If we switched this to strings, end_token would be the correct name. But it sounds like we will keep this in int space for now and just rename to beam_search and greedy_search, which I think is a good call.

@chenmoneygithub force-pushed the text-generation branch 2 times, most recently from 29dc810 to 9cc5a60 on May 11, 2022 at 19:54
@chenmoneygithub (Contributor Author):

Colab example for machine translation:
https://colab.research.google.com/gist/chenmoneygithub/b1301804f8cfcb5347ade8c7b6cb124f/kerasnlp-machine-translation-model-util.ipynb

dataset: http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip

The accuracy is bad, but the main goal of this colab is to show how to use the text generation util.

@mattdangerw (Member):

Thanks @chenmoneygithub!

  • In the first example, is manually importing compute_causal_mask and passing it to the encoder the best way to do a decoder-only architecture? If so, we should fix that; it does not feel like a clear and clean example. Not for this PR, but let's discuss this.
  • In the second example, is there a reason you aren't using the end token id? When it is supplied, what will the output look like? The end token id followed by pad token ids (see the sketch below)? We need to make this clear in the documentation. Also, should we take pad_token_id as an arg? It looks like you are making some implicit assumptions about the zero index.
  • The second example shows why I don't think the prompt should be required to have non-zero length on the last dimension. You could easily just not follow the convention of starting with "[START]", and it would be totally valid to allow sampling starting from the first token. This would also be true of a model that pads with a start token within the call graph. Here are modifications on your colab. Allowing zero shape on the last dim would be more flexible without breaking any of the cases today.

I'll play around with compiling this into functions and models. Curious where we currently stand there.
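To make the end-token/padding question concrete, a hypothetical example of the proposed semantics (the token values are made up):

import tensorflow as tf

# With end_token_id=2 and pad_token_id=0, a raw decode like this...
raw = tf.constant([[1, 5, 8, 2, 9, 4]])
# ...would be returned with the end token kept and everything after it padded:
expected = tf.constant([[1, 5, 8, 2, 0, 0]])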

@chenmoneygithub (Contributor Author):

  1. Yes, created "TransformerDecoder should support single inputs" (#182) for the fix.
  2. Updated the PR to include a pad_token_id.
  3. Removed the requirement that a prompt must be given.

@mattdangerw (Member) left a comment:

Left a few more comments. This is looking closer. We can simplify the end-of-sequence token logic, and we need to do something user-friendly in the compiled-function case.

keras_nlp/utils/text_generation.py
prompt = tf.concat([prompt, next_token[:, tf.newaxis]], axis=-1)
return get_subsequent_tokens(prompt, end_token_id_received)

generated_sequence = get_subsequent_tokens(prompt, end_token_id_received)
@mattdangerw (Member), May 17, 2022:

I think the end-token logic could be much more readable (and probably more efficient) as a post-processing step. tf.sequence_mask could help.

Something like:

if end_token_id is not None:
    # Find the index of the first end_token_id (argmax needs a numeric dtype, so cast the bool mask).
    end_indices = tf.math.argmax(tf.cast(outputs == end_token_id, tf.int32), -1)
    # Use max_length if no end token was found.
    end_indices = tf.where(end_indices == 0, tf.cast(max_length, end_indices.dtype), end_indices)
    # Build a mask including end_token and replace the overflow with pad_token_id.
    valid_indices = tf.sequence_mask(end_indices + 1, maxlen=max_length)
    outputs = tf.where(valid_indices, outputs, pad_token_id)

Contributor Author:

good point! done

@mattdangerw (Member) left a comment:

Thanks! Just minor edits now.

prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)

# Print the generated sequence (token ids).
keras_nlp.greedy_search(
Member:

keras_nlp.utils.greedy_search, right?

Contributor Author:

good catch!
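Putting the thread together, a hedged sketch of how the corrected call might look (the argument names and their order are inferred from the review discussion, so treat the exact signature as an assumption; the toy token_probability_fn stands in for a real model):

import tensorflow as tf
import keras_nlp

VOCAB_SIZE = 10

def token_probability_fn(inputs):
    # Toy stand-in for a real model call.
    return tf.nn.softmax(tf.zeros([tf.shape(inputs)[0], VOCAB_SIZE]), axis=-1)

prompt = tf.random.uniform(shape=[5, 5], maxval=VOCAB_SIZE, dtype=tf.int64)

# Print the generated sequences (token ids).
print(
    keras_nlp.utils.greedy_search(
        token_probability_fn,
        prompt,
        max_length=10,
        end_token_id=2,
        pad_token_id=0,
    )
)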

@saberkun (Contributor):

Can prompt be a prefix list of tokens? Context: prefix-LM.

@mattdangerw (Member) left a comment:

Approving; just a few last comments.

end_token_id=2,
pad_token_id=0,
)
self.assertAllEqual(outputs[0, 2:], tf.repeat(3, max_length - 2))
Member:

Just test the whole output here; that will be much more readable.

Contributor Author:

sg!
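For illustration, a whole-output assertion might read like this (the values are hypothetical, not the test's real fixtures):

import tensorflow as tf


class TextGenerationTest(tf.test.TestCase):
    def test_whole_output(self):
        # Hypothetical result; the real test builds `outputs` with greedy_search.
        outputs = tf.constant([[1, 2, 3, 3, 3]])
        # Asserting on the full tensor documents the expected sequence at a glance.
        self.assertAllEqual(outputs, tf.constant([[1, 2, 3, 3, 3]]))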
