
Standalone functions for generate pre/post processing for GPT-2 #998

Merged

Conversation

@mattdangerw (Member) commented Apr 19, 2023

This decomposes generate in the way we discussed last week, with the goal of leaving the top-level functionality untouched, but allowing a more granular way to access the preprocessing, postprocessing, and compiled generation function. Colab [HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

Other than moving things around in the refactor, there is one major change we need to make here: the inner, compiled generate function must also return a padding mask indicating which token ids were updated. Without this padding mask, the postprocessor would not know where to truncate output before detokenization.

To accommodate this I made `generate_function` inputs and outputs a dict with keys "token_ids" and "padding_mask". I actually find this fairly intuitive; with this change, `generate_function` has the same inputs and outputs as directly calling the model!

```python
generate_function = causal_lm.make_generate_function()
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
})
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 7, 8]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}
# With early stopping at token 6.
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}, end_token_id=6)
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}
```
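
For context, here is roughly how the decomposed pieces fit together end to end (a sketch only; the preset name and prompt are illustrative, and the method names are the ones introduced in this PR):

```python
import keras_nlp

# Illustrative preset; any GPT-2 preset should behave the same way here.
causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
preprocessor = causal_lm.preprocessor

# Strings -> dense {"token_ids", "padding_mask"} inputs.
inputs = preprocessor.generate_preprocess(["a quick brown fox"])
# Compiled generation: dict in, dict out, same keys and shapes.
outputs = causal_lm.make_generate_function()(inputs)
# {"token_ids", "padding_mask"} -> strings, truncated before detokenization.
print(preprocessor.generate_postprocess(outputs))
```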

@mattdangerw (Member Author)

/gcbrun

@mattdangerw force-pushed the split-generate-preprocessing branch from 786ad50 to 86fc49d on April 21, 2023 at 20:24
@mattdangerw (Member Author)

/gcbrun

@mattdangerw force-pushed the split-generate-preprocessing branch from 86fc49d to ef2464a on April 21, 2023 at 20:31
@mattdangerw (Member Author)

/gcbrun

@chenmoneygithub (Contributor)

Thanks for the PR! Overall looks good!

Took a pass over the code, and I have a few high-level comments:

  • I feel users will be confused about the difference between `GPT2CausalLMPreprocessor.__call__()`, `GPT2CausalLMPreprocessor.generate_preprocess()`, and `GPT2CausalLMPreprocessor.generate_postprocess()`. While `__call__` suggests a generic usage, it actually represents the use case of generating causal LM training data for `GPT2CausalLM`, which is arguably a more general usage than calling it for generation.
  • The name `generate_preprocess` confused me for a while. The reason is that `generate` is usually a verb, but here we are referring to the `generate()` method in `GPT2CausalLM`. I don't know what could be a better name though...
  • I think we can move the final padding mask computation (end_token truncation) into postprocess.

Generally I think we need to make it easy for developers/users to understand our breakdown, and why we expose 3 public methods in `GPT2CausalLMPreprocessor`.
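
For concreteness, the three public methods in question look roughly like this (a sketch only; the preset name is illustrative and the exact signatures may differ):

```python
import keras_nlp

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset("gpt2_base_en")

# __call__: builds causal LM training data for GPT2CausalLM, i.e. an
# (x, y, sample_weight) tuple where y is the next-token labels for x.
x, y, sample_weight = preprocessor(["a quick brown fox"])

# generate_preprocess: strings -> {"token_ids", "padding_mask"} for generation.
inputs = preprocessor.generate_preprocess(["a quick brown fox"])

# generate_postprocess: {"token_ids", "padding_mask"} -> strings, dropping
# padding (and anything after an end token) before detokenization.
print(preprocessor.generate_postprocess(inputs))
```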

@mattdangerw (Member Author) commented Apr 24, 2023

> Generally I think we need to make it easy for developers/users to understand our breakdown, and why we expose 3 public methods in GPT2CausalLMPreprocessor.

I can work on the documentation. In my mind this is an advanced workflow. We do not advertise this at all in the first go around. `generate()` is just a magic black box! Much like many people using Keras might not realize that both `predict_step` and `make_predict_function` are overridable in core Keras. But if you did want to preprocess a bunch of strings for generate separately, or build an exported `tf.function` for serving, or override just preprocessing but not postprocessing for generation, suddenly you have this more modular view of things to pick and choose from.
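
For example, the "override just preprocessing but not postprocessing" case might look something like this (a sketch only; the subclass name and the whitespace tweak are purely illustrative, and the exact `generate_preprocess` signature may differ from what is shown):

```python
import tensorflow as tf
import keras_nlp

class MyGPT2Preprocessor(keras_nlp.models.GPT2CausalLMPreprocessor):
    def generate_preprocess(self, x, *args, **kwargs):
        # Illustrative tweak: normalize prompt whitespace before the usual
        # tokenization and padding; postprocessing is inherited unchanged.
        x = tf.strings.strip(x)
        return super().generate_preprocess(x, *args, **kwargs)
```

Since `generate()` is decomposed into these pieces, a custom preprocessor like this plugs into the normal workflow without touching postprocessing or the compiled generation function.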

> I think we can move the final padding mask computation (end_token truncation) into postprocess.

You can only do this if you pass the input padding mask to the output postprocess function, right? Otherwise you lose track of which end tokens came from the original sequence. Let's think in terms of the "export flow". We cannot export `generate()` as a traceable function, as we have decided to return pythonic objects, similar to `predict()`. Here's the rough export I am proposing...

```python
def export_fn(x):
    x = gpt2_preprocessor.generate_preprocess(x)
    x = gpt2_lm.make_generate_function()(x)
    return gpt2_preprocessor.generate_postprocess(x)
```

With the change you are proposing, it needs to look like this...

```python
def export_fn(x):
    preprocessed = gpt2_preprocessor.generate_preprocess(x)
    input_padding_mask = preprocessed["padding_mask"]
    output_token_ids = gpt2_lm.make_generate_function()(preprocessed)
    return gpt2_preprocessor.generate_postprocess(
        output_token_ids,
        padding_mask=input_padding_mask,
    )
```

The state gets a little more complex right? Postprocessing depends on both the partial output of preprocessing and the output of the generation function.

I also think there is a UX argument for returning a padding mask in the "no preprocess" workflow. When you do, you pass in two tensors in a dict with shape (batch_size, max_length), and you get back two tensors in a dict with shape (batch_size, max_length). You can use the padding mask as a reference to see where your token ids were updated. I would find it easier to work with if I were building the preprocessing and postprocessing 100% myself.

@jbischof (Contributor) left a comment

New abstractions look very clean, thank you!

Some of the submethods could benefit from more documentation or (ideally) be more self-documenting so that the users can understand what these abstractions do and when they would use them.

`keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence. For use with generation,
pass `return_labels=False`, in which case the output will simply be the
encoded string features. The layer also exposes two methods,
`generate_preprocess()` and `generate_postprocess()`.
Contributor:
Should we tell users when/why to use them?

@mattdangerw (Member Author):
Added a bit more color here.

@chenmoneygithub (Contributor) left a comment

Generally LGTM! Per our offline discussion, some more comments on `normalize_input` and the like would help users understand.

@mattdangerw (Member Author)

Thanks for the review! Added some more docs.

@mattdangerw force-pushed the split-generate-preprocessing branch from ef2464a to 4c0690c on May 2, 2023 at 15:23
@mattdangerw (Member Author)

/gcbrun

@mattdangerw force-pushed the split-generate-preprocessing branch from 4c0690c to 418d1db on May 2, 2023 at 15:38
@mattdangerw (Member Author)

/gcbrun

@jbischof (Contributor) left a comment

Very impressive work @mattdangerw!

@jbischof changed the title from "Standalone functions for generate pre/post processing" to "Standalone functions for generate pre/post processing for GPT-2" on May 3, 2023
@mattdangerw (Member Author)

/gcbrun

@mattdangerw force-pushed the split-generate-preprocessing branch from d6e5d97 to 79d5bab on May 3, 2023 at 15:17
@mattdangerw (Member Author)

/gcbrun

@mattdangerw (Member Author)

/gcbrun

@mattdangerw merged commit 4a9c758 into keras-team:master on May 3, 2023
chenmoneygithub pushed a commit that referenced this pull request May 4, 2023
* Standalone functions for generate pre/post processing

* More docstring updates

* Fix merge conflict
@tianhaoz95

Hi @mattdangerw, the TF Lite conversion issue I ran into in #1090 seems related to the `x.numpy()` call in `normalize()` when `is_string` is false.

Do you think they are related?

If so, do you have any suggestions on what tensor spec to use when converting it to tf.function to avoid this error?
