Standalone functions for generate pre/post processing for GPT-2 (#998)
* Standalone functions for generate pre/post processing

This decomposes generate in the way we discussed last week, with the
goal of leaving the top-level functionality untouched, but allowing
a more granular way to access the preprocessing, postprocessing,
and inner dense generation function. Colab
[HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

Other than moving things around in the refactor, there is one major
change we need to make here: the inner, compiled generate function must
also return a padding mask marking which token ids were updated. Without
this padding mask, the postprocessor would not know where to truncate
the output before detokenization.

To accommodate this, I made `generate_function` take and return a dict
with keys `"token_ids"` and `"padding_mask"`. I actually find this fairly
intuitive; with this change, `generate_function` has the same inputs and
outputs as directly calling the model!

```python
generate_function = causal_lm.make_generate_function()
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
})
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 7, 8]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}, end_token_id=6)
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}
```
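
With the pieces split out, the same flow can also be run pass by pass. A minimal sketch, assuming a `from_preset()` model (the `"gpt2_base_en"` preset and prompt below are illustrative):

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
preprocessor = causal_lm.preprocessor

# 1. Preprocess: raw string prompt -> dense "token_ids" / "padding_mask".
inputs = preprocessor.generate_preprocess(
    ["the quick brown fox"], sequence_length=64
)
# 2. Generate: run the compiled, dense generation function.
generate_function = causal_lm.make_generate_function()
outputs = generate_function(
    inputs, end_token_id=preprocessor.tokenizer.end_token_id
)
# 3. Postprocess: dense tensors -> detokenized strings.
print(preprocessor.generate_postprocess(outputs))
```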

* More docstring updates

* Fix merge conflict
mattdangerw authored May 3, 2023
1 parent 7cbcbed commit 4a9c758
Showing 6 changed files with 211 additions and 132 deletions.
180 changes: 112 additions & 68 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -29,7 +29,6 @@
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import tensor_to_string_list
from keras_nlp.utils.tf_utils import truncate_at_token


@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
@@ -49,7 +48,7 @@ class GPT2CausalLM(Task):
default, `"top_k"` sampling will be used.
This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
which case it will automatically apply preprocessing to string inputs during
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
when creating the model with `from_preset()`.
@@ -306,28 +305,23 @@ def make_generate_function(self):

def generate_step(
self,
token_ids,
padding_mask,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. It takes in a dense `tf.Tensor` of token
ids, and return a dense `tf.Tensor` of token ids, and includes no
preprocessing. This function is wrapped by the `generate()` method.
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
token_ids: A dense int Tensor, with shape
`(batch_size, max_length)`. The user provided token ids
padded to `max_length`.
padding_mask: A dense boolean Tensor, with the same shape as
`token_ids`. Positions that are True in the `padding_mask`
are assumed to be user input and never updated.
inputs: A dictionary with two keys `"token_ids"` and
`"padding_mask"` and batched tensor values.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user-provided token ids.
@@ -352,7 +346,7 @@ def next(prompt, cache, index):
cache,
)

return self._sampler(
token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
@@ -362,6 +356,78 @@ def next(prompt, cache, index):
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
# Build a mask of `end_token_id` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = (token_ids == end_token_id) & (~padding_mask)
end_locations = tf.cast(end_locations, tf.int32)
# Use cumsum to get ones in all locations after end_locations.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
# Our padding mask is the inverse of these overflow locations.
padding_mask = ~tf.cast(overflow, tf.bool)
else:
# Without early stopping, all locations will have been updated.
padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}
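
For intuition, the end-token masking above can be checked on a toy batch. A standalone sketch with illustrative values, mirroring the example in the commit message:

```python
import tensorflow as tf

end_token_id = 6
# One sequence: tokens 1-4 are the prompt, 5-8 were generated.
token_ids = tf.constant([[1, 2, 3, 4, 5, 6, 7, 8]])
padding_mask = tf.constant([[True, True, True, True, False, False, False, False]])

# `end_token_id` locations that were generated (not part of the prompt).
end_locations = tf.cast((token_ids == end_token_id) & (~padding_mask), tf.int32)
# Exclusive cumsum flags every position strictly after the first end token.
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
# Invert to keep the prompt, the generated tokens, and the end token itself.
new_padding_mask = ~tf.cast(overflow, tf.bool)
print(new_padding_mask.numpy())
# [[ True  True  True  True  True  True False False]]
```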

def _normalize_generate_inputs(
self,
inputs,
):
"""Normalize user input to the generate function.
This function converts all inputs to tensors, adds a batch dimension if
necessary, and returns an iterable "dataset like" object (either an
actual `tf.data.Dataset` or a list with a single batch element).
"""
input_is_scalar = False

if isinstance(inputs, tf.data.Dataset):
return inputs, input_is_scalar

if isinstance(inputs, str) or isinstance(inputs, list):
inputs = tf.convert_to_tensor(inputs)

if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
input_is_scalar = True
inputs = inputs[tf.newaxis]

# We avoid converting to a dataset purely for speed; for a single batch
# of input, creating a dataset would add significant overhead.
return [inputs], input_is_scalar
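
In practice the contract looks like this; a sketch with illustrative inputs, where `causal_lm` is any `GPT2CausalLM` instance:

```python
# A scalar string is batched, and flagged so the batch dimension can be
# stripped again on the way out.
batches, is_scalar = causal_lm._normalize_generate_inputs("a single prompt")
# batches -> [tf.Tensor of shape (1,), dtype=string], is_scalar -> True

# A list of strings becomes one batched tensor wrapped in a list, while a
# `tf.data.Dataset` is returned unchanged and iterated batch by batch.
batches, is_scalar = causal_lm._normalize_generate_inputs(["a", "b"])
# batches -> [tf.Tensor of shape (2,), dtype=string], is_scalar -> False
```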

def _normalize_generate_outputs(
self,
outputs,
input_is_scalar,
):
"""Normalize user output from the generate function.
This function converts all output to numpy (for integer output) or
Python strings (for string output). If a batch dimension was added to
the input, it is removed from the output (so generate can be string in,
string out).
"""

def normalize(x):
x = tf.concat(x, axis=0)
x = tf.squeeze(x, 0) if input_is_scalar else x
is_string = x.dtype == tf.string
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
return tensor_to_string_list(x) if is_string else x.numpy()

if isinstance(outputs[0], dict):
return {
"token_ids": normalize([x["token_ids"] for x in outputs]),
"padding_mask": normalize([x["padding_mask"] for x in outputs]),
}
return normalize([x for x in outputs])

def generate(
self,
inputs,
@@ -397,65 +463,43 @@ def generate(
A string or string list if `preprocessor` is set, and an integer
tensor of token IDs if `preprocessor is None`.
"""
input_is_scalar = False

# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
# 2. Generate new tokens via a compiled function on dense tensors.
# 3. Optionally postprocess dense integer tensors back to string.
generate_function = self.make_generate_function()
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id

def preprocess(x):
return self.preprocessor(
x,
sequence_length=max_length,
return_labels=False,
# We do not append an end token by default during
# generation, as generating directly in the same sequence is
# the most common workflow. If an end token directly after
# a prompt is desired, it can be added to the prompt string.
add_end_token=False,
)

if not isinstance(inputs, tf.data.Dataset):
inputs = tf.convert_to_tensor(inputs)
input_is_scalar = inputs.shape.rank == 0
inputs = inputs[tf.newaxis] if input_is_scalar else inputs
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [preprocess(inputs)]
else:
def preprocess(x):
return self.preprocessor.generate_preprocess(
x, sequence_length=max_length
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)

# Normalize inputs, apply our three passes, and normalize outputs.
inputs, input_is_scalar = self._normalize_generate_inputs(inputs)

if self.preprocessor is not None:
if isinstance(inputs, tf.data.Dataset):
inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
inputs = inputs.prefetch(tf.data.AUTOTUNE)
else:
if not isinstance(inputs, tf.data.Dataset):
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [inputs]
else:
# Fast path for non-dataset, single-batch input.
inputs = [preprocess(x) for x in inputs]

generate_function = self.make_generate_function()
outputs = []
for batch in inputs:
token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
# If `preprocessor` is attached, we can stop after end_token_id.
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id
# Run the compiled generate function.
output = generate_function(token_ids, padding_mask, end_token_id)

if self.preprocessor is not None:
# Truncate to ragged by removing tokens after the first
# generated `end_token_id`.
output = truncate_at_token(output, end_token_id, padding_mask)
# Strip start token if added.
if self.preprocessor.add_start_token:
output = output[:, 1:]
# Detokenize.
output = self.preprocessor.tokenizer.detokenize(output)
outputs.append(output)

outputs = tf.concat(outputs, axis=0)
outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
if outputs.dtype == tf.string:
return tensor_to_string_list(outputs)
return outputs.numpy()
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]

return self._normalize_generate_outputs(outputs, input_is_scalar)
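
End to end, the top-level entry point is unchanged; a sketch of the usual call (the preset name and prompt are illustrative):

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Strings in, strings out; preprocessing and postprocessing run implicitly.
causal_lm.generate("the quick brown fox", max_length=40)
```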

@classmethod
def create_layout_map(cls, mesh):
106 changes: 73 additions & 33 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -14,23 +14,32 @@

"""GPT2 Causal LM preprocessor layer."""

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.GPT2CausalLMPreprocessor")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.
This preprocessing layer is primarily meant to be used with
This preprocessing layer is meant for use with
`keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence. For use with generation,
pass `return_labels=False`, in which case the output will simply be the
encoded string features.
`y` label is the next token id in the `x` sequence.
For use with generation, the layer also exposes two methods
`generate_preprocess()` and `generate_postprocess()`. When this preprocessor
is attached to a `keras_nlp.models.GPT2CausalLM` instance, these methods
will be called implicitly in `generate()`. They can also be called
standalone (e.g. to precompute preprocessing inputs for generation in a
separate process).
Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
@@ -47,12 +56,6 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
add_start_token: Pass to override the configured value of
`add_start_token` on the layer.
add_end_token: Pass to override the configured value of
`add_end_token` on the layer.
return_labels: If `True`, the output `"token_ids"` will be offset by one
and returned as labels. If `False` only features will be returned.
Examples:
```python
@@ -95,9 +98,6 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
add_start_token=None,
add_end_token=None,
return_labels=True,
):
if y is not None or sample_weight is not None:
logging.warning(
@@ -106,25 +106,65 @@
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
if return_labels:
# Tokenize with one extra token to account for the truncation below.
sequence_length = (sequence_length or self.sequence_length) + 1
x = super().call(
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length,
add_start_token=add_start_token,
add_end_token=add_end_token,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
if return_labels:
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
else:
return x
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
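
The shift-by-one labeling above can be traced on a toy packed batch; a sketch with illustrative values (the real packer returns a boolean mask):

```python
import tensorflow as tf

# Toy output of the tokenizer + packer; the trailing 0 is the pad position.
token_ids = tf.constant([[2, 7, 9, 4, 0]])
padding_mask = tf.constant([[1, 1, 1, 1, 0]])

x = {
    "token_ids": token_ids[..., :-1],        # [[2, 7, 9, 4]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 1]]
}
y = token_ids[..., 1:]                 # [[7, 9, 4, 0]] -- the next token at each step
sample_weight = padding_mask[..., 1:]  # [[1, 1, 1, 0]] -- no loss on the pad slot
```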

def generate_preprocess(
self,
x,
sequence_length=None,
):
"""Covert strings to integer token input for generation.
Similar to calling the layer for training, this method takes in strings
or tensor strings, tokenizes and packs the input, and computes a padding
mask masking all inputs not filled in with a padded value.
Unlike calling the the layer for training, this method does not compute
labels and will never append a `tokenizer.end_token_id` to the end of
the sequence (as generation is expected to continue at the end of the
inputted prompt).
"""
x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
"""Covert integer token output to strings for generation.
This method reverses `generate_preprocess()`, by first removing all
padding and start/end tokens, and then converting the interger sequence
back to a string.
"""
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
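
A minimal sketch of the stripping step above, with illustrative ids (assuming GPT-2's `<|endoftext|>` id, 50256, as the special token):

```python
import tensorflow as tf

end_token_id = 50256  # GPT-2's `<|endoftext|>`
token_ids = tf.constant([[50256, 464, 2068, 50256, 0, 0]])
padding_mask = tf.constant([[True, True, True, True, False, False]])

# Drop padding plus any special tokens before detokenizing.
keep = padding_mask & (token_ids != end_token_id)
ragged = tf.ragged.boolean_mask(token_ids, keep)
# ragged -> <tf.RaggedTensor [[464, 2068]]>, which the tokenizer then
# detokenizes back to a Python string.
```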
