GPT2 Text Generation APIs (#592)
* initial commit

* Add keras_nlp.samplers

* Change padding to left to right

* more samplers

* Add GPT2 text generation stuff

* correct top-p and beam sampler

* initial commit

* Add keras_nlp.samplers

* Change padding to left to right

* Add serialization support, and move some args from constructor to call

* Add string example

* small changes

* Address comments: fix docstring, remove multicase support

* Address comments: move token_probability_fn to the second place

* some initials

* add more sampler class, and a few changes on the base sampler class

* dummy

* add some arg defaults

* small fix

* fix docstring

* some changes

* add classes

* fix serialization

* fix docstring

* address comments

* one more

* fix docstring

* minor fix
chenmoneygithub authored Jan 27, 2023
1 parent 9c5f850 commit 5737da9
Showing 8 changed files with 931 additions and 2 deletions.
7 changes: 7 additions & 0 deletions keras_nlp/models/__init__.py
@@ -33,6 +33,13 @@
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
239 changes: 239 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -0,0 +1,239 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
"""An end-to-end GPT2 model for causal langauge modeling.
A causal language model (LM) predicts the next token based on previous
tokens the next token based on previous tokens, which is the way GPT2 gets
pretrained. You can finetune `GPT2CausalLM` to generate text similar to
the custom dataset. `GPT2CausalLM` also has a method `generate()`, which
generates text based on given prompt.
This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
`fit()`, `predict()`, and `evaluate()`. This is done by default when
creating the model with `from_preset()`.
Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/openai/gpt-2).
Args:
backbone: A `keras_nlp.models.GPT2Backbone` instance.
preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.
Examples:
Use the `generate()` method to do text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.generate("I want to say", max_length=30)
# Generate with batched prompts.
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
```
Use a custom sampler for text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Use string identifier to set sampler.
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")
# Construct a sampler instance.
sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
```
Map raw strings to language model logit predictions.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.predict(["You know this is just a test string"])
```
Load a pretrained GPT2 and fit on a string dataset.
```python
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]
ds = tf.data.Dataset.from_tensor_slices(features)
# Create a `GPT2CausalLM` and fit your data.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(ds, batch_size=2)
```
Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict on
string inputs.
```python
# Use a shorter sequence length.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=128,
)
# Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
)
gpt2_lm.predict(["You know this is still a test string"])
```
Fit on preprocessed data with a randomly initialized GPT2. This is useful
when you want to do data preprocessing inside a `tf.data` pipeline.
```python
# Define preprocessed input.
features = {
"token_ids": tf.constant(
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
),
}
labels = tf.constant(
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
)
sample_weight = tf.constant(
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
)
# Randomly initialize a GPT2 backbone.
backbone = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=2,
num_heads=2,
hidden_dim=128,
intermediate_dim=256,
max_sequence_length=128,
)
# Create a `GPT2CausalLM` without preprocessor and fit the data.
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(
x=features,
y=labels,
sample_weight=sample_weight,
batch_size=2,
)
```
"""

def __init__(self, backbone, preprocessor=None, **kwargs):
inputs = backbone.input
x = backbone(inputs)
# Use token embedding weights to project from the token representation
# to vocabulary logits.
outputs = tf.matmul(
x,
backbone.token_embedding.embeddings,
transpose_b=True,
)

# Instantiate using Functional API Model constructor.
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)

self._backbone = backbone
self._preprocessor = preprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classproperty
def backbone_cls(cls):
return GPT2Backbone

@classproperty
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def _get_token_probability(self, prompt, mask):
model_inputs = {
"token_ids": prompt,
"padding_mask": mask,
}
return self(model_inputs)

def generate(
self,
prompt,
max_length,
sampler="top_k",
):
"""Generate text.
This method generates text based on the given `prompt`. Generation
continues until `max_length` is met, and all tokens generated after
`end_token` are truncated. The sampling approach can be controlled via
the `sampler` argument.
Args:
prompt: a string, string Tensor or string RaggedTensor. The prompt
text for generation.
max_length: int. The max length of the generated sequence.
sampler: a string or `keras_nlp.samplers.Sampler` instance. The
sampler to be used for text generation.
"""
end_token_id = self.preprocessor.tokenizer.end_token_id

sampler = keras_nlp.samplers.get(sampler)
if hasattr(self, "jit_compile"):
# `jit_compile` is a public property as of tf 2.12. hasattr is for
# backward compat.
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
generated = sampler(
self.preprocessor.tokenizer(prompt),
self._get_token_probability,
max_length=max_length,
end_token_id=end_token_id,
)
return self.preprocessor.tokenizer.detokenize(generated)
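A note on the output head defined in `__init__` above: the model reuses the token embedding matrix as the output projection (weight tying), computing vocabulary logits as a matmul of the hidden states against the embedding table. Below is a minimal, standalone sketch of that projection; the dimensions and tensor names are hypothetical and chosen only for illustration.
```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical sizes, for illustration only.
vocab_size, hidden_dim, seq_len = 100, 16, 8

# The same embedding layer provides both the input lookup and, via its
# weights, the output projection -- mirroring the
# `tf.matmul(x, backbone.token_embedding.embeddings, transpose_b=True)`
# call in `GPT2CausalLM.__init__`.
token_embedding = keras.layers.Embedding(vocab_size, hidden_dim)
token_ids = tf.random.uniform((2, seq_len), maxval=vocab_size, dtype=tf.int32)

hidden_states = token_embedding(token_ids)  # shape (2, seq_len, hidden_dim)
logits = tf.matmul(
    hidden_states,
    token_embedding.embeddings,  # shape (vocab_size, hidden_dim)
    transpose_b=True,
)  # shape (2, seq_len, vocab_size): one logit per vocabulary entry
```
In the real model the hidden states come from the GPT2 backbone rather than directly from the embedding lookup; the sketch only illustrates the weight-tied projection itself.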
96 changes: 96 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -0,0 +1,96 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.
This preprocessor is mainly used as the preprocessor for `GPT2CausalLM`.
This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of
its functionality. The only change is that `GPT2CausalLMPreprocessor` sets
the `y` (label) and `sample_weight` fields by shifting the input sequence
one step to the left, and drops the last token as it does not have a
successor. For example, if the tokenized input is `[1, 2, 3, 0, 0]` with
`padding_mask = [1, 1, 1, 0, 0]`, then after preprocessing we
will have `x = [1, 2, 3, 0]` and `y = [2, 3, 0, 0]`, with
`padding_mask = [1, 1, 1, 0]` and `sample_weight = [1, 1, 0, 0]`.
Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
sequence_length: The length of the packed inputs.
Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en"
)
# Tokenize and pack a single sentence.
sentence = tf.constant("League of legends")
preprocessor(sentence)
# Same output.
preprocessor("League of legends")
# Tokenize a batch of sentences.
sentences = tf.constant(["Taco tuesday", "Fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["Taco tuesday", "Fish taco please!"])
# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
# Map a dataset to preprocess unlabeled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
"""

def call(self, x, y=None, sample_weight=None):
if y is not None or sample_weight is not None:
logging.warning(
"`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
"based on your input data, but your data already contains `y` "
"or `sample_weight`. Your `y` and `sample_weight` will be "
"ignored."
)

x = super().call(x)
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# The last token does not have a next token, so we truncate it out.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
# Target `y` will be the next token.
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
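For reference, here is a minimal sketch of the shift that `call()` performs, applied to the toy values from the class docstring; the tensors are illustrative only.
```python
import tensorflow as tf

# Toy tokenized input from the docstring example.
token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# Inputs drop the last token, since it has no successor to predict.
x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# Labels are the token ids shifted one step to the left, and the sample
# weights mask out positions whose "next token" is only padding.
y = token_ids[..., 1:]                 # [[2, 3, 0, 0]]
sample_weight = padding_mask[..., 1:]  # [[1, 1, 0, 0]]
```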