From 675ca2de21b68dc62e4909c80a2cd57d8ee8b601 Mon Sep 17 00:00:00 2001
From: Bin Du
Date: Fri, 10 May 2024 14:56:23 -0700
Subject: [PATCH] Add a helper function to the Keras LM wrapper to clean up
 special tokens from the GPT2 and SentencePiece tokenizers.

PiperOrigin-RevId: 632611251
---
 .../examples/models/instrumented_keras_lms.py | 33 +++++++++++++++----
 1 file changed, 26 insertions(+), 7 deletions(-)

diff --git a/lit_nlp/examples/models/instrumented_keras_lms.py b/lit_nlp/examples/models/instrumented_keras_lms.py
index f9b4cbcd..b3bdb41c 100644
--- a/lit_nlp/examples/models/instrumented_keras_lms.py
+++ b/lit_nlp/examples/models/instrumented_keras_lms.py
@@ -10,7 +10,6 @@
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
 from lit_nlp.lib import utils as lit_utils
-import numpy as np
 
 
 # pylint: disable=g-import-not-at-top
@@ -87,10 +86,6 @@ def __init__(
     self.max_length = max_length
     self.dynamic_sequence_length = dynamic_sequence_length
 
-    self.ids_to_tokens = np.vectorize(
-        self.model.preprocessor.tokenizer.id_to_token
-    )
-
     # map ids: [batch_size, num_tokens]
     # to embs: [batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
@@ -141,6 +136,30 @@ def encode_inputs(self, texts: Sequence[str]):
     # Actually trim the input tensors.
     return {k: v[:, :longest_sequence] for k, v in encoded_inputs.items()}
 
+  def clean_subword_token(self, tok: str) -> str:
+    """Clean up special subword tokens from the tokenizers if necessary.
+
+    Args:
+      tok: the token to clean up.
+    Returns:
+      The replaced token if the provided token matches a special subword token
+      below; otherwise, the original token is returned.
+    """
+    # For the GPT2 tokenizer.
+    tok = tok.replace("Ċ", "\n")  # newlines
+    tok = tok.replace("Ġ", "▁")  # start of word -> magic underscore
+    # For the SentencePiece tokenizer.
+    tok = tok.replace("<0x0A>", "\n")  # newlines
+    return tok
+
+  def ids_to_clean_tokens(self, ids: Sequence[int]) -> Sequence[str]:
+    return [
+        self.clean_subword_token(
+            self.model.preprocessor.tokenizer.id_to_token(id)
+        )
+        for id in ids
+    ]
+
   @classmethod
   def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw):
     """Share weights and underlying Keras model with another instance."""
@@ -419,7 +438,7 @@ def _postprocess(self, preds):
     """Post-process single-example preds. Operates on numpy arrays."""
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
+    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     for key in lit_utils.find_spec_keys(
         self.output_spec(), lit_types.TokenScores
     ):
@@ -479,7 +498,7 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("token_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
+    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     return preds
 
   def predict_minibatch(self, inputs):
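
Note: for reference, a minimal standalone sketch of the substitution rules introduced by this patch; no Keras model is needed to run it, and the sample tokens are illustrative only, not taken from any specific vocabulary.

def clean_subword_token(tok: str) -> str:
  # GPT-2 byte-level BPE markers.
  tok = tok.replace("Ċ", "\n")  # encoded newline
  tok = tok.replace("Ġ", "▁")   # start-of-word marker -> magic underscore
  # SentencePiece byte-fallback token for a newline.
  tok = tok.replace("<0x0A>", "\n")
  return tok

if __name__ == "__main__":
  for raw in ["Ġhello", "Ċ", "<0x0A>", "plain"]:
    print(repr(raw), "->", repr(clean_subword_token(raw)))
  # 'Ġhello' -> '▁hello'
  # 'Ċ' -> '\n'
  # '<0x0A>' -> '\n'
  # 'plain' -> 'plain'

The magic underscore ("▁", U+2581) is the marker SentencePiece already uses for word starts, so mapping the GPT-2 "Ġ" prefix onto it presumably gives both tokenizer families a consistent rendering of word boundaries.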