Standalone functions for generate pre/post processing
This decomposes `generate()` in the way we discussed last week, with the
goal of leaving the top-level functionality untouched while allowing
more granular access to the preprocessing, postprocessing, and inner
dense generation functions. Colab demo
[HERE](https://colab.research.google.com/gist/mattdangerw/bb1ef01c1b67255def4a6ad9429de2df/split-preprocessing-demo.ipynb)

Other than moving things around in the refactor, there is one major
change we need to make here: the inner, compiled generate function must
also return a padding mask marking which token ids were updated.
Without this padding mask, the postprocessor would not know where to
truncate the output before detokenization.
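
As a rough illustration (hypothetical values; this mirrors the
`tf.ragged.boolean_mask` call used by the new postprocessor), the returned
mask is what lets postprocessing drop padding and anything after a
generated end token:

```python
import tensorflow as tf

# Hypothetical generated ids and the padding mask returned alongside them.
token_ids = tf.constant([[1, 2, 3, 4, 5, 6, 0, 0]])
padding_mask = tf.cast([[1, 1, 1, 1, 1, 1, 0, 0]], "bool")

# Truncate to ragged before detokenization, keeping only unmasked tokens.
print(tf.ragged.boolean_mask(token_ids, padding_mask))
# <tf.RaggedTensor [[1, 2, 3, 4, 5, 6]]>
```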

To accommodate this, `generate_function` now takes and returns a dict
with keys `"token_ids"` and `"padding_mask"`. I actually find this fairly
intuitive; with this change, `generate_function` has the same inputs and
outputs as directly calling the model!

```python
generate_function = causal_lm.make_generate_function()
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
})
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 7, 8]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 1, 1]],
}
generate_function({
   "token_ids":    [[1, 2, 3, 4, 0, 0, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 0, 0, 0, 0]],
}, end_token_id=6)
>>> {
   "token_ids":    [[1, 2, 3, 4, 5, 6, 0, 0]],
   "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
}
```
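
For reference, a minimal sketch of calling the three passes separately
(the preset name `"gpt2_base_en"` is just an example; the method names
match the ones added in this change):

```python
import keras_nlp

causal_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
preprocessor = causal_lm.preprocessor

# 1. Preprocess strings to a dense {"token_ids", "padding_mask"} dict.
inputs = preprocessor.generate_preprocess(["a quick brown fox"])
# 2. Generate new tokens with the inner, compiled function.
generate_function = causal_lm.make_generate_function()
outputs = generate_function(
    inputs, end_token_id=preprocessor.tokenizer.end_token_id
)
# 3. Postprocess: truncate at the padding mask and detokenize.
print(preprocessor.generate_postprocess(outputs))
```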
mattdangerw committed Apr 21, 2023
1 parent edd6b6b commit 86fc49d
Showing 6 changed files with 177 additions and 119 deletions.
157 changes: 101 additions & 56 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
Expand Up @@ -29,7 +29,6 @@
from keras_nlp.utils.keras_utils import is_xla_compatible
from keras_nlp.utils.python_utils import classproperty
from keras_nlp.utils.tf_utils import tensor_to_string_list
from keras_nlp.utils.tf_utils import truncate_at_token


@keras_nlp_export("keras_nlp.models.GPT2CausalLM")
Expand All @@ -48,7 +47,7 @@ class GPT2CausalLM(Task):
default, `"top_k"` sampling will be used.
This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
which case it will automatically apply preprocessing to string inputs during
`fit()`, `predict()`, `evaluate()` and `generate()`. This is done by default
when creating the model with `from_preset()`.
Expand Down Expand Up @@ -305,28 +304,23 @@ def make_generate_function(self):

def generate_step(
self,
token_ids,
padding_mask,
inputs,
end_token_id=None,
):
"""A compilable generation function for a single batch of inputs.
This function represents the inner, XLA-compilable, generation function
for a single batch of inputs. It takes in a dense `tf.Tensor` of token
ids, and return a dense `tf.Tensor` of token ids, and includes no
preprocessing. This function is wrapped by the `generate()` method.
for a single batch of inputs. Inputs should have the same structure as
model inputs, a dictionary with keys `"token_ids"` and `"padding_mask"`.
Args:
token_ids: A dense int Tensor, with shape
`(batch_size, max_length)`. The user provided token ids
padded to `max_length`.
padding_mask: A dense boolean Tensor, with the same shape as
`token_ids`. Positions that are True in the `padding_mask`
are assumed to be user input and never updated.
inputs: A dictionary with two batched tensor keys `"token_ids"`
and `"padding_mask"`.
end_token_id: The id of the end token to stop on. If all
sequences have produced a new `end_token_id`, generation
will stop.
"""
token_ids, padding_mask = inputs["token_ids"], inputs["padding_mask"]
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(token_ids)
# Compute the lengths of all user inputted tokens ids.
Expand All @@ -351,7 +345,7 @@ def next(prompt, cache, index):
cache,
)

return self._sampler(
token_ids = self._sampler(
next=next,
prompt=token_ids,
cache=cache,
Expand All @@ -361,6 +355,57 @@ def next(prompt, cache, index):
hidden_states=hidden_states,
)

# Compute an output padding mask with the token ids we updated.
if end_token_id is not None:
end_locations = (token_ids == end_token_id) & (~padding_mask)
end_locations = tf.cast(end_locations, tf.int32)
overflow = tf.math.cumsum(end_locations, axis=-1, exclusive=True)
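# With `exclusive=True`, `overflow` is nonzero only strictly after the
# first generated end token, so inverting it below keeps that end token
# and masks everything generated after it.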
padding_mask = ~tf.cast(overflow, tf.bool)
else:
padding_mask = tf.ones_like(token_ids, dtype=tf.bool)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def _normalize_inputs(
self,
inputs,
):
input_is_scalar = False

if isinstance(inputs, tf.data.Dataset):
return inputs, input_is_scalar

if isinstance(inputs, str) or isinstance(inputs, list):
inputs = tf.convert_to_tensor(inputs)

if isinstance(inputs, tf.Tensor) and inputs.shape.rank == 0:
input_is_scalar = True
inputs = inputs[tf.newaxis]

return [inputs], input_is_scalar

def _normalize_outputs(
self,
outputs,
input_is_scalar,
):
def normalize(x):
x = tf.concat(x, axis=0)
x = tf.squeeze(x, 0) if input_is_scalar else x
is_string = x.dtype == tf.string
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
return tensor_to_string_list(x) if is_string else x.numpy()

if isinstance(outputs[0], dict):
return {
"token_ids": normalize([x["token_ids"] for x in outputs]),
"padding_mask": normalize([x["padding_mask"] for x in outputs]),
}
return normalize([x for x in outputs])

def generate(
self,
inputs,
Expand Down Expand Up @@ -396,47 +441,39 @@ def generate(
A string or string list if `preprocessor` is set, and an integer
tensor of token IDs if `preprocessor` is `None`.
"""
input_is_scalar = False

# Setup our three main passes.
# 1. Optionally preprocessing strings to dense integer tensors.
# 2. Generate new tokens via a compiled function on dense tensors.
# 3. Optionally postprocess dense integer tensors back to string.
generate_function = self.make_generate_function()
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id

def preprocess(x):
return self.preprocessor(
x,
sequence_length=max_length,
return_labels=False,
# We do not append an end token by default during
# generation, as generating directly in the same sequence is
# the most common workflow. If an end token directly after
# a prompt is desired, it can be added to the prompt string.
add_end_token=False,
)

if not isinstance(inputs, tf.data.Dataset):
inputs = tf.convert_to_tensor(inputs)
input_is_scalar = inputs.shape.rank == 0
inputs = inputs[tf.newaxis] if input_is_scalar else inputs
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [preprocess(inputs)]
else:
def preprocess(x):
return self.preprocessor.generate_preprocess(
x, sequence_length=max_length
)

def generate(x):
return generate_function(x, end_token_id=end_token_id)

def postprocess(x):
return self.preprocessor.generate_postprocess(x)

# Normalize inputs, apply our three passes, and normalize outputs.
inputs, input_is_scalar = self._normalize_inputs(inputs)

if self.preprocessor is not None:
if isinstance(inputs, tf.data.Dataset):
inputs = inputs.map(preprocess, tf.data.AUTOTUNE)
inputs = inputs.prefetch(tf.data.AUTOTUNE)
else:
if not isinstance(inputs, tf.data.Dataset):
# Wrap a list to avoid the overhead of converting to dataset.
inputs = [inputs]
else:
inputs = [preprocess(x) for x in inputs]

generate_function = self.make_generate_function()
outputs = []
for batch in inputs:
token_ids, padding_mask = batch["token_ids"], batch["padding_mask"]
# If `preprocessor` is attached, we can stop after end_token_id.
end_token_id = None
if self.preprocessor is not None:
end_token_id = self.preprocessor.tokenizer.end_token_id
# Run the compiled generate function.
output = generate_function(token_ids, padding_mask, end_token_id)
outputs = [generate(x) for x in inputs]

if self.preprocessor is not None:
# Truncate to ragged by removing tokens after a new end token.
output = truncate_at_token(output, end_token_id, padding_mask)
# Strip start token if added.
if self.preprocessor.add_start_token:
output = output[:, 1:]
Expand All @@ -447,11 +484,19 @@ def preprocess(x):
# Detokenize.
output = self.preprocessor.tokenizer.detokenize(output)
outputs.append(output)
if self.preprocessor is not None:
outputs = [postprocess(x) for x in outputs]

outputs = tf.concat(outputs, axis=0)
outputs = tf.squeeze(outputs, 0) if input_is_scalar else outputs
# Convert outputs to a friendly pythonic type. For numerical outputs
# that is numpy, for string outputs that is `list` and `str`.
if outputs.dtype == tf.string:
return tensor_to_string_list(outputs)
return outputs.numpy()
return self._normalize_outputs(outputs, input_is_scalar)
82 changes: 50 additions & 32 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
Expand Up @@ -14,23 +14,27 @@

"""GPT2 Causal LM preprocessor layer."""

import tensorflow as tf
from absl import logging

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import (
convert_inputs_to_list_of_tensor_segments,
)
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras_nlp_export("keras_nlp.models.GPT2CausalLMPreprocessor")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.
This preprocessing layer is primarily meant to be used with
This preprocessing layer is meant for use with
`keras_nlp.models.GPT2CausalLM`. By default, it will take in batches of
strings, and return outputs in a `(x, y, sample_weight)` format, where the
`y` label is the next token id in the `x` sequence. For use with generation,
pass `return_labels=False` in which case the output will simply be the
encoded string features.
the layer also exposes two methods `generate_preprocess()` and
`generate_postprocess()`.
Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
Expand All @@ -47,12 +51,6 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
generates label weights.
sequence_length: Pass to override the configured `sequence_length` of
the layer.
add_start_token: Pass to override the configured value of
`add_start_token` on the layer.
add_end_token: Pass to override the configured value of
`add_end_token` on the layer.
return_labels: If `True`, the output `"token_ids"` will be offset by one
and returned as labels. If `False` only features will be returned.
Examples:
```python
Expand Down Expand Up @@ -95,9 +93,6 @@ def call(
y=None,
sample_weight=None,
sequence_length=None,
add_start_token=None,
add_end_token=None,
return_labels=True,
):
if y is not None or sample_weight is not None:
logging.warning(
Expand All @@ -106,25 +101,48 @@ def call(
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)
if return_labels:
# Tokenize with one extra token to account for the truncation below.
sequence_length = (sequence_length or self.sequence_length) + 1
x = super().call(
sequence_length = sequence_length or self.sequence_length

x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
token_ids, padding_mask = self.packer(
x,
sequence_length=sequence_length,
add_start_token=add_start_token,
add_end_token=add_end_token,
sequence_length=sequence_length + 1,
add_start_value=self.add_start_token,
add_end_value=self.add_end_token,
)
if return_labels:
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
else:
return x
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y, sample_weight = token_ids[..., 1:], padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)

def generate_preprocess(
self,
x,
sequence_length=None,
):
x = convert_inputs_to_list_of_tensor_segments(x)[0]
x = self.tokenizer(x)
token_ids, padding_mask = self.packer(
x, sequence_length=sequence_length, add_end_value=False
)
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
}

def generate_postprocess(
self,
x,
):
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Strip any special tokens during detokenization (e.g. the start and
# end markers). In the future we could make this configurable.
padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
token_ids = tf.ragged.boolean_mask(token_ids, padding_mask)
return self.tokenizer.detokenize(token_ids)
19 changes: 11 additions & 8 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor_test.py
Expand Up @@ -107,16 +107,19 @@ def test_dataset(self):
self.assertAllEqual(y, [[1, 3, 4, 2, 5, 6, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)

def test_call_overrides(self):
def test_generate_preprocess(self):
input_data = "airplane at airport"
x, _, _ = self.preprocessor(input_data, add_start_token=False)
self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 6, 0, 0])
x, _, _ = self.preprocessor(input_data, add_end_token=False)
x = self.preprocessor.generate_preprocess(input_data)
self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
x, _, _ = self.preprocessor(input_data, sequence_length=4)
self.assertAllEqual(x["token_ids"], [6, 1, 3, 4])
x = self.preprocessor(input_data, return_labels=False)
self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 6, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])

def test_generate_postprocess(self):
input_data = {
"token_ids": tf.constant([6, 1, 3, 4, 2, 5, 0, 0]),
"padding_mask": tf.cast([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"),
}
x = self.preprocessor.generate_postprocess(input_data)
self.assertAllEqual(x, "airplane at airport")

def test_serialization(self):
config = keras.utils.serialize_keras_object(self.preprocessor)