
BytePair Tokenizer Implementation #303

Closed

Conversation

jessechancy
Contributor

An implementation of OpenAI's BytePair encoder in TF-compatible graph mode, which would allow for the end-to-end development of certain pretrained models that use this tokenizer (RoBERTa, GPT, etc.).

Currently a rough version, and suggestions are welcome!

Member

@mattdangerw mattdangerw left a comment

Just some quick initial comments.

Have we validated this on a number of different languages (and weirder Unicode like control characters, emojis, etc.)? We should make sure this is not just equivalent in ASCII land.


class BytePairTokenizerCache:
    def __init__(self):
        self.key2id = tf.lookup.experimental.DenseHashTable(
Member

Why DenseHashTable for one and MutableHashTable for the other? Does this make a performance difference?

Contributor Author

This is mainly due to a limitation of these experimental hash tables: DenseHashTable is more efficient, but it can only map string to int, not vice versa. Similarly, the reason we needed two hash tables is that we cannot have a string-to-string mapping.
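For reference, a minimal sketch of the two-table workaround being described here (the class name and argument values are illustrative, not the PR's exact code):

```python
import tensorflow as tf


class TwoTableCache:
    """Sketch: a string -> int table paired with an int -> string table."""

    def __init__(self):
        # DenseHashTable covers the string -> int direction; it requires
        # reserved sentinel keys for empty and deleted slots.
        self.key2id = tf.lookup.experimental.DenseHashTable(
            key_dtype=tf.string,
            value_dtype=tf.int64,
            default_value=-1,
            empty_key="",
            deleted_key="$deleted$",
        )
        # MutableHashTable covers int -> string, since tf offers no
        # string -> string lookup.
        self.id2value = tf.lookup.experimental.MutableHashTable(
            key_dtype=tf.int64,
            value_dtype=tf.string,
            default_value="",
        )
```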

Member

Would it significantly degrade performance to use MutableHashTable in all cases? That would reduce our exposure to experimental features, and IIRC there are open bugs to remove DenseHashTable from tf in the future.

return bs, cs # int to string mapping


class BytePairTokenizerCache:
Member

It feels weird to have a cache with no size limit. That's not actually a cache.

Would this become a complete memory hog on a sufficiently large vocabulary? Does that come up in practice?

I think the lru_cache gpt2 uses does have a max size by default.

Contributor Author

The lru_cache, at least in the OpenAI implementation, is only used for the byte2unicode mapping, which is a fixed size. The Python dictionary used as the cache is unbounded. We could add a limit to the cache if that is better, but supporting a true LRU cache would require reimplementing something like MutableHashTable or DenseHashTable.
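For context, the relevant pattern in OpenAI's encoder.py looks roughly like this (a sketch from memory, not a verbatim copy): lru_cache only memoizes the fixed 256-entry byte-to-unicode table, while the per-word cache is a plain, unbounded dict.

```python
from functools import lru_cache


@lru_cache()
def bytes_to_unicode():
    # Fixed-size map from the 256 byte values to printable unicode
    # characters; lru_cache simply memoizes this single result.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))


word_cache = {}  # word -> merged tokens; grows without bound
```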

Member

Ohh got it, thanks! Yeah if the original implementation also has an unbounded cache, that makes me more comfortable.


from keras_nlp.tokenizers.byte_pair_tokenizer import BytePairTokenizer


class BytePairTokenizerTest(tf.test.TestCase):
Member

Let's add a little more coverage for some key use cases (a rough sketch follows this list):

  • A batched tf dataset
  • An unbatched tf dataset
  • A function annotated with @tf.function
  • Some more complex unicode character cases (maybe we can validate these first with the original tokenizer impl to make sure we have it right)
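Something along these lines would work (a hypothetical sketch; the constructor arguments, inputs, and fixture names are assumptions, not validated values):

```python
def test_dataset_and_tf_function_usage(self):
    # `vocab` and `merges` stand in for whatever fixture the other tests
    # in this file use; the exact constructor arguments are assumed here.
    tokenizer = BytePairTokenizer(vocabulary=vocab, merges=merges)
    inputs = ["brown fox.", "lorem ipsum"]

    # Unbatched tf.data pipeline.
    ds = tf.data.Dataset.from_tensor_slices(inputs).map(tokenizer)

    # Batched tf.data pipeline.
    batched_ds = tf.data.Dataset.from_tensor_slices(inputs).batch(2).map(tokenizer)

    # Graph mode via tf.function.
    @tf.function
    def tokenize(x):
        return tokenizer(x)

    outputs = tokenize(tf.constant(inputs))
```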



@mattdangerw
Member

Forgot to leave a comment here, but this is from conversations with @jessechancy:

For the cache, we want a way to go from a string word input to a token list output. We are currently doing that with two different hash tables (string -> int and int -> string) to work around the fact that tf does not offer a string -> string lookup.

I think we could have a slightly simpler workaround, where we hash the input string and then do an int -> string lookup to get the tokenized form. That should save us one of these hash tables, which should make things simpler, faster, and lower memory usage.
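A rough sketch of that idea (the hash function and bucket count are assumptions, just to show the shape of it):

```python
import tensorflow as tf

NUM_BUCKETS = 2**30  # illustrative; collisions become likelier as this shrinks

# Single int -> string table; the string key is replaced by its hash.
id2value = tf.lookup.experimental.MutableHashTable(
    key_dtype=tf.int64,
    value_dtype=tf.string,
    default_value="",
)


def cache_insert(words, tokenized):
    keys = tf.strings.to_hash_bucket_fast(words, NUM_BUCKETS)
    id2value.insert(keys, tokenized)


def cache_lookup(words):
    keys = tf.strings.to_hash_bucket_fast(words, NUM_BUCKETS)
    return id2value.lookup(keys)
```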

@mattdangerw
Member

Another issue we need to work through is that Python regex and tf regex appear to handle certain whitespace characters (non-breaking spaces) differently. We need to fix this, probably with some regex hacking.
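A minimal repro of the kind of mismatch I mean, assuming the culprit is Unicode whitespace such as U+00A0 (Python's re treats \s as Unicode-aware on str inputs, while tf.strings uses RE2, whose \s is ASCII-only):

```python
import re

import tensorflow as tf

text = "hello\u00a0world"  # contains a non-breaking space

# Python regex: \s matches the non-breaking space.
print(re.sub(r"\s", "|", text))  # hello|world

# TF (RE2) regex: \s does not match it, so the string is left unchanged.
print(tf.strings.regex_replace(text, r"\s", "|").numpy())  # b'hello\xc2\xa0world'
```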

Contributor

@jbischof jbischof left a comment

Very impressive! Some minor stylistic comments.

Right now this code is a bit dense and lightly documented. Since we're doing a lot of heavy lifting here it would be nice to lay out the steps or organize the code in a way that someone could understand the gist without reading every line.



def create_static_hashtable(keys, values, default):
    hashtable = tf.lookup.StaticHashTable(
Contributor

Return directly on this line (no need for the intermediate variable).

merges,
sequence_length: int = None,
**kwargs,
) -> None:
Contributor

Type hints seem to be unpopular in Keras, so drop them for consistency.

    return hashtable


class BytePairTokenizer(tokenizer.Tokenizer):
Contributor

Need a docstring


@tf.function
def _byte_pair_merge_loop_body(self, words, mask):
    """Iterative merging process for byte pair encoding algorithm."""
Contributor

Empty line after docstring

@chenmoneygithub
Contributor

@mattdangerw Do we still have unresolved issues with the functionality of this implementation? I played around with it a bit more and the functionality looks correct to me (compared with RoBERTa's tokenizer). I remember you mentioned there are some special tokens handled differently?

One concern I have is that the implementation is quite complex, but after going through the code I don't know where we can really simplify it. Compared to the fairseq implementation, the additional complexity mainly comes from the tensor manipulation when we do merge/check/etc., and from how hashing and the while loop are handled. We have two approaches here:

  1. Just use the Python implementation and disregard graph support, which seems to go against our design pattern.
  2. Use Jesse's implementation, which follows our design pattern, but the code is quite heavy.

@mattdangerw
Member

@chenmoneygithub left a few brief comments above in that regard.

The issue with different output is apparently with non-breaking space characters.

And there are some things I would like to try re simplifying, in particular removing one of the hash tables by using a hash function (#303 (comment)).

Overall I think we should probably go with this and not the python implementation. IIUC the py_function escape hatch to tf.data will be prohibitively slow. But we do need some digging and cleanup on this PR before we can land it.
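For reference, the escape hatch in question looks roughly like this (a sketch; python_bpe_encode is a placeholder for a pure-Python BPE encoder), and every element drops out of the graph into Python:

```python
import tensorflow as tf


def python_bpe_encode(text):
    # Placeholder for a fairseq/OpenAI-style pure-Python BPE implementation.
    return [1, 2, 3]


def tokenize_with_py_function(text):
    def _encode(t):
        return tf.constant(python_bpe_encode(t.numpy().decode("utf-8")), tf.int32)

    tokens = tf.py_function(func=_encode, inp=[text], Tout=tf.int32)
    tokens.set_shape([None])
    return tokens


ds = tf.data.Dataset.from_tensor_slices(["a quick fox"]).map(tokenize_with_py_function)
```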

@chenmoneygithub
Contributor

@mattdangerw Simply wrapping with py_function hits tons of runtime errors. The alternative to this PR is not supporting the tf.data pipeline (e.g., like the HuggingFace RoBERTa TF model), which actually wouldn't cause a performance loss based on what Jesse told me earlier, but the downside is that it breaks our contract of supporting tf.data. Just to clarify: my thought is the same though, we should check in this implementation.

Do you remember what output diff you saw earlier? Jesse found one minor diff before his presentation, but he said he had fixed it.

@abheesht17
Collaborator

@mattdangerw, @chenmoneygithub - a minor comment here. The merges.txt file in the official repo (the HF repo has the same file as the official repo) has #version: ... on the first line: https://huggingface.co/gpt2/blob/main/merges.txt. I have copied the same merges.txt file over to the GCP folder.

So, we should probably ignore the first line of the text file after https://github.com/keras-team/keras-nlp/pull/303/files#diff-9f7f9b8a01fa1e5d050c27b1dcfdb801ec24d46f724a2eef011c8f86c9bed53aR142, right?

@chenmoneygithub
Contributor

@abheesht17 Thanks for raising it! It's actually fine: #version: 0.2 won't be applied as a merge rule, because it would require seeing a token #version:0.2, which would already be split apart in the pre-split phase.
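To illustrate (illustrative parsing code, not the PR's loader): each line of merges.txt parses into a pair, and a pair only merges when both halves sit side by side inside a single pre-split token, so the header is effectively a dead rule.

```python
with open("merges.txt", encoding="utf-8") as f:
    merge_rules = [tuple(line.split()) for line in f if line.strip()]

# The header parses as the pair ('#version:', '0.2'). It could only ever
# apply inside a single token '#version:0.2', which the whitespace/regex
# pre-split step never produces, so it is a no-op.
print(merge_rules[0])
```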

@mattdangerw
Member

Closing, this was landed with some edits on #389 by @chenmoneygithub. But huge props on writing this, this is a big deal for the library!!
