Add the helper function in Keras setup to clean up special tokens in GPT2 and SentencePiece tokenizers.

PiperOrigin-RevId: 632611251
bdu91 authored and LIT team committed May 10, 2024
1 parent 4fb3bde commit 675ca2d
Showing 1 changed file with 26 additions and 7 deletions.
lit_nlp/examples/models/instrumented_keras_lms.py (26 additions, 7 deletions)
@@ -10,7 +10,6 @@
 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types
 from lit_nlp.lib import utils as lit_utils
-import numpy as np


 # pylint: disable=g-import-not-at-top
@@ -87,10 +86,6 @@ def __init__(
     self.max_length = max_length
     self.dynamic_sequence_length = dynamic_sequence_length

-    self.ids_to_tokens = np.vectorize(
-        self.model.preprocessor.tokenizer.id_to_token
-    )
-
     # map ids: <tf.int>[batch_size, num_tokens]
     # to embs: <tf.float>[batch_size, num_tokens, emb_dim]
     self.embedder = self.model.backbone.token_embedding
@@ -141,6 +136,30 @@ def encode_inputs(self, texts: Sequence[str]):
     # Actually trim the input tensors.
     return {k: v[:, :longest_sequence] for k, v in encoded_inputs.items()}

+  def clean_subword_token(self, tok: str) -> str:
+    """Clean up special subword tokens from the tokenizers if necessary.
+
+    Args:
+      tok: the token to clean up.
+
+    Returns:
+      The replaced token if the provided token matches one of the special
+      subword tokens below; otherwise, the original token is returned.
+    """
+    # For GPT2 tokenizer.
+    tok = tok.replace("Ċ", "\n")  # newlines
+    tok = tok.replace("Ġ", "▁")  # start of word -> magic underscore
+    # For SentencePiece tokenizer.
+    tok = tok.replace("<0x0A>", "\n")  # newlines
+    return tok
+
+  def ids_to_clean_tokens(self, ids: Sequence[int]) -> Sequence[str]:
+    return [
+        self.clean_subword_token(
+            self.model.preprocessor.tokenizer.id_to_token(id)
+        )
+        for id in ids
+    ]
+
   @classmethod
   def from_loaded(cls, existing: "_KerasBaseModel", *args, **kw):
     """Share weights and underlying Keras model with another instance."""
@@ -419,7 +438,7 @@ def _postprocess(self, preds):
     """Post-process single-example preds. Operates on numpy arrays."""
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("input_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
+    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     for key in lit_utils.find_spec_keys(
         self.output_spec(), lit_types.TokenScores
     ):
@@ -479,7 +498,7 @@ def _postprocess(self, preds):
     # rather than acting as a boolean mask.
     mask = preds.pop("padding_mask").astype(bool)
     ids = preds.pop("token_ids")[mask]
-    preds[FieldNames.TOKENS] = self.ids_to_tokens(ids)
+    preds[FieldNames.TOKENS] = self.ids_to_clean_tokens(ids)
     return preds

   def predict_minibatch(self, inputs):
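
For reference, the cleaning rule added above can be exercised standalone. The sketch below is not part of the commit: the function body is copied from clean_subword_token in the diff, and the sample token lists are invented for illustration. GPT-2 byte-level tokenizers emit "Ġ" for a leading space and "Ċ" for a newline, while SentencePiece uses "▁" to mark word starts and "<0x0A>" as the byte fallback for a newline.

def clean_subword_token(tok: str) -> str:
  # Same replacements as _KerasBaseModel.clean_subword_token above.
  tok = tok.replace("Ċ", "\n")  # GPT-2 newline token
  tok = tok.replace("Ġ", "▁")  # GPT-2 start of word -> magic underscore
  tok = tok.replace("<0x0A>", "\n")  # SentencePiece byte-fallback newline
  return tok

gpt2_tokens = ["Hello", "Ġworld", "Ċ", "ĠGoodbye"]  # illustrative only
sp_tokens = ["▁Hello", "▁world", "<0x0A>", "▁Goodbye"]  # illustrative only

print([clean_subword_token(t) for t in gpt2_tokens])
# -> ['Hello', '▁world', '\n', '▁Goodbye']
print([clean_subword_token(t) for t in sp_tokens])
# -> ['▁Hello', '▁world', '\n', '▁Goodbye']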

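And a sketch of how the new ids_to_clean_tokens path compares with the removed np.vectorize one. The FakeTokenizer stub and its three-entry vocabulary are hypothetical, invented for illustration (not a real KerasNLP class); clean_subword_token is the helper from the sketch above.

import numpy as np

class FakeTokenizer:
  # Hypothetical stand-in for model.preprocessor.tokenizer.
  _vocab = {0: "Hello", 1: "Ġworld", 2: "Ċ"}

  def id_to_token(self, token_id: int) -> str:
    return self._vocab[token_id]

tokenizer = FakeTokenizer()

# Old path (removed in this commit): raw tokens, special characters intact.
ids_to_tokens = np.vectorize(tokenizer.id_to_token)
print(ids_to_tokens([0, 1, 2]).tolist())  # -> ['Hello', 'Ġworld', 'Ċ']

# New path: same lookup, with each token cleaned afterwards.
print([clean_subword_token(tokenizer.id_to_token(i)) for i in [0, 1, 2]])
# -> ['Hello', '▁world', '\n']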