Standalone functions for generate pre/post processing
This decomposes generate in the way we discussed last week, with the goal of leaving the top-level functionality untouched while allowing a more granular way to access the preprocessing, postprocessing, and inner dense generation function.

Colab [HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

Other than moving things around in the refactor, there is one major change we need to make here: the inner, compiled generate function must also return a padding mask of the token ids that were updated. Without this padding mask, the postprocessor would not know where to truncate output before detokenization.

To accommodate this, I made the `generate_function` inputs and outputs a dict with keys "token_ids" and "padding_mask". I actually find this fairly intuitive. With this change, `generate_function` has the same inputs and outputs as directly calling the model!

```python
generate_function = causal_lm.make_generate_function()

# Without an end token, generation fills the whole sequence.
generate_function({
    "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
})
# Returns:
# {
#     "token_ids": [[1, 2, 3, 4, 5, 6, 7, 8]],
#     "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
# }

# With an end token, positions after it stay padded and masked out.
generate_function({
    "token_ids": [[1, 2, 3, 4, 0, 0, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}, end_token_id=6)
# Returns:
# {
#     "token_ids": [[1, 2, 3, 4, 5, 6, 0, 0]],
#     "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
# }
```
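To illustrate why the returned padding mask matters, here is a minimal sketch of how a postprocessing step could use it to drop padded positions before detokenization. The helper name `truncate_with_mask` is hypothetical and is not part of this commit. It uses plain Python lists rather than tensors, and the real postprocessor may differ (for example, in how it treats the end token).

```python
def truncate_with_mask(token_ids, padding_mask):
    """Hypothetical helper: keep only tokens whose mask entry is truthy."""
    return [
        [tok for tok, keep in zip(row_ids, row_mask) if keep]
        for row_ids, row_mask in zip(token_ids, padding_mask)
    ]


# Output dict shaped like the `end_token_id=6` example in the commit message.
outputs = {
    "token_ids": [[1, 2, 3, 4, 5, 6, 0, 0]],
    "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}

trimmed = truncate_with_mask(outputs["token_ids"], outputs["padding_mask"])
print(trimmed)  # [[1, 2, 3, 4, 5, 6]]
```

Without the mask, a trailing `0` would be ambiguous: it could be padding, or it could be a real token that happens to have id 0.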
1 parent c9b8934, commit 4c0690c
Showing 6 changed files with 193 additions and 132 deletions.