GPT2 Text Generation APIs #592

Merged

Commits (33 total; diff shown from 28 commits):
7f7ae43 initial commit (chenmoneygithub, Dec 9, 2022)
c53b4a9 Add keras_nlp.samplers (chenmoneygithub, Dec 10, 2022)
e6483a4 Change padding to left to right (chenmoneygithub, Dec 12, 2022)
513121e more samplers (chenmoneygithub, Dec 20, 2022)
0eb68f6 Add GPT2 text generation stuff (chenmoneygithub, Dec 21, 2022)
fa41d23 correct top-p and beam sampler (chenmoneygithub, Jan 4, 2023)
26fd509 initial commit (chenmoneygithub, Dec 9, 2022)
7e4c651 Add keras_nlp.samplers (chenmoneygithub, Dec 10, 2022)
28bcfe1 Change padding to left to right (chenmoneygithub, Dec 12, 2022)
9757f4d Add serialization support, and move some args from constructor to call (chenmoneygithub, Jan 5, 2023)
f7508cb Add string example (chenmoneygithub, Jan 6, 2023)
b658b61 small changes (chenmoneygithub, Jan 6, 2023)
76c430c Address comments: fix docstring, remove multicase support (chenmoneygithub, Jan 9, 2023)
bb430dd Address comments: move token_probability_fn to the second place (chenmoneygithub, Jan 9, 2023)
afd3082 some initials (chenmoneygithub, Jan 10, 2023)
273a6a5 Merge branch 'master' into text-generation-extend (chenmoneygithub, Jan 10, 2023)
31ad970 add more sampler class, and a few changes on the base sampler class (chenmoneygithub, Jan 13, 2023)
331f568 Merge branch 'text-generation-extend' into text-generation-playground (chenmoneygithub, Jan 13, 2023)
5300800 dummy (chenmoneygithub, Jan 13, 2023)
de2ac9c add some arg defaults (chenmoneygithub, Jan 13, 2023)
42c164f Merge branch 'text-generation-extend' into text-generation-playground (chenmoneygithub, Jan 13, 2023)
08f3c1e small fix (chenmoneygithub, Jan 13, 2023)
2b93ad8 fix docstring (chenmoneygithub, Jan 17, 2023)
309d6d4 merge (chenmoneygithub, Jan 18, 2023)
8206103 some changes (chenmoneygithub, Jan 19, 2023)
9945c13 add classes (chenmoneygithub, Jan 20, 2023)
4fa8fc5 fix serialization (chenmoneygithub, Jan 20, 2023)
cb12604 fix docstring (chenmoneygithub, Jan 23, 2023)
f7685ca address comments (chenmoneygithub, Jan 24, 2023)
3bac2ad Merge branch 'master' into text-generation-playground (chenmoneygithub, Jan 25, 2023)
2ed9adb one more (chenmoneygithub, Jan 25, 2023)
f2821b5 fix docstring (chenmoneygithub, Jan 26, 2023)
728a471 minor fix (chenmoneygithub, Jan 27, 2023)
7 changes: 7 additions & 0 deletions keras_nlp/models/__init__.py
@@ -32,6 +32,13 @@
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
280 changes: 280 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -0,0 +1,280 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class ReverseEmbedding(keras.layers.Layer):
"""A layer multiplying model outputs by the token embedding.
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved

This layer is used to map model outputs to logits over all vocab tokens.
It's used in `GPT2CausalLM` to calculate next token's probability.

Args:
embedding_layer: a `tf.keras.layers.Embedding` instance, the token
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved
embedding layer.
"""

def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
super().__init__(name=name, **kwargs)
self.embedding_layer = embedding_layer

def call(self, inputs):
return tf.matmul(
inputs,
self.embedding_layer.embeddings,
transpose_b=True,
)

def get_config(self):
config = super().get_config()
config.update(
{
"embedding_layer": keras.layers.serialize(self.embedding_layer),
}
)
return config

@classmethod
def from_config(cls, config):
config["embedding_layer"] = keras.layers.deserialize(
config["embedding_layer"],
)
return cls(**config)
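
For intuition, here is a minimal standalone sketch (not part of this PR; the shapes are assumed) of what this reverse-embedding matmul does:

```python
import tensorflow as tf
from tensorflow import keras

embedding = keras.layers.Embedding(input_dim=50257, output_dim=128)
embedding.build((None,))  # materialize the (50257, 128) embeddings matrix

hidden_states = tf.random.normal((2, 10, 128))  # (batch, seq_len, hidden_dim)
logits = tf.matmul(hidden_states, embedding.embeddings, transpose_b=True)
print(logits.shape)  # (2, 10, 50257): a score for every vocabulary token
```

Tying the output projection to the input embedding this way reuses the (vocab_size, hidden_dim) weight matrix instead of learning a separate dense output layer.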


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
"""GPT2 Causal LM task model.
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved

Causal LM is predicting the next token based on previous tokens, which is
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved
the way GPT2 gets pretrained. Users can finetune `GPT2CausalLM` to generate
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved
text similar to the custom dataset. `GPT2CausalLM` also has a public method
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved
`generate()`, which generates text based on given prompt.

This model can optionally be configured with a `preprocessor` layer, in
which case it will automatically apply preprocessing to raw inputs during
`fit()`, `predict()`, and `evaluate()`. This is done by default when
creating the model with `from_preset()`.

Disclaimer: Pre-trained models are provided on an "as is" basis, without
warranties or conditions of any kind. The underlying model is provided by a
third party and subject to a separate license, available
[here](https://github.com/openai/gpt-2).

Args:
backbone: A `keras_nlp.models.GPT2Backbone` instance.
preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`.
If `None`, this model will not apply preprocessing, and inputs
should be preprocessed before calling the model.

Examples:

Use the `generate()` method to do text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.generate("I want to say", max_length=30)

# Generate with batched prompts.
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
```

Use a custom sampler for text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Use string identifier to set sampler.
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")

# Construct a sampler instance.
sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
```

Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs.
```python
str_inputs = "You know this is just a test string"
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.predict([str_inputs])
```

Load a pretrained GPT2 and fit it on a string dataset.
```python
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]

# Create a `GPT2CausalLM` and fit your data.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(x=features, batch_size=2)
```
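
If you would rather build an explicit `tf.data` pipeline for the raw strings, an equivalent sketch (an assumption on my part, reusing the `features` list above; the attached preprocessor still handles tokenization) is:

```python
ds = tf.data.Dataset.from_tensor_slices(features).batch(2)
gpt2_lm.fit(ds)
```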

Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict on
string inputs.
```python
str_inputs = "You know this is still a test string"

# Use a shorter sequence length.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=128,
)

# Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
)
gpt2_lm.predict([str_inputs])
```

Fit a randomly initialized GPT2 on preprocessed data. This is useful when
you want to do data preprocessing inside the `tf.data` pipeline.
```python
# Define preprocessed input.
features = {
"token_ids": tf.constant(
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
),
}
labels = tf.constant(
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
)
sample_weight = tf.constant(
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
)

# Randomly initialize a GPT2 backbone.
backbone = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=2,
num_heads=2,
hidden_dim=128,
intermediate_dim=256,
max_sequence_length=128,
)
# Create a `GPT2CausalLM` without preprocessor and fit the data.
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(
x=features,
y=labels,
sample_weight=sample_weight,
batch_size=2,
)
```

"""

def __init__(self, backbone, preprocessor=None, **kwargs):
inputs = backbone.input
x = backbone(inputs)
embedding_layer = backbone.get_layer("token_embedding")
embedding_map_layer = ReverseEmbedding(embedding_layer)
outputs = embedding_map_layer(x)

# Instantiate using the Functional API `Model` constructor.
super().__init__(
inputs=inputs,
outputs=outputs,
include_preprocessing=preprocessor is not None,
**kwargs,
)

self._backbone = backbone
self._preprocessor = preprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classproperty
def backbone_cls(cls):
return GPT2Backbone

@classproperty
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def _get_token_probability(self, prompt, mask):
# Wrap the forward pass in the signature expected by samplers:
# (token ids, padding mask) -> per-position logits over the vocab.
model_inputs = {
"token_ids": prompt,
"padding_mask": mask,
}
return self(model_inputs)

def generate(
self,
prompt,
max_length,
end_token="<|endoftext|>",
sampler="top_k",

Review comment (Member): Just curious, why was top-k chosen as the default?

Reply (chenmoneygithub, Author): It's working well with my finetuning tasks. I feel we want to later change this default to contrastive search, which is not yet available.

):
"""Generate text.

This method generates text based on given `prompt`. Generation will
continue until `max_length` is met, and all tokens generated after
`end_token` will be truncated.
chenmoneygithub marked this conversation as resolved.
Show resolved Hide resolved

Args:
prompt: a string, string Tensor or string RaggedTensor. The prompt
text for generation.
max_length: int. The max length of generated sequence.
end_token: string. The token marking the end of the sequence; tokens
generated after this token will be truncated. Defaults to
"<|endoftext|>", the default end token of GPT2.
sampler: a string or `keras_nlp.samplers.Sampler` instance. The
sampler to be used for text generation.
"""
end_token_id = self.preprocessor.tokenizer.token_to_id(end_token)

if isinstance(sampler, str):
sampler = keras_nlp.samplers.get(sampler)
# Pass the model's compilation settings through to the sampler.
if hasattr(self, "jit_compile"):
sampler.jit_compile = self.jit_compile
sampler.run_eagerly = self.run_eagerly
prompt = self.preprocessor.tokenizer(prompt)
generated = sampler(
prompt,
self._get_token_probability,
max_length=max_length,
end_token_id=end_token_id,
)
return self.preprocessor.tokenizer.detokenize(generated)
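
As context for the reviewer question on the `sampler` default above, here is a rough self-contained sketch of what top-k sampling does (illustrative only, not the `keras_nlp.samplers` implementation):

```python
import tensorflow as tf

def sample_top_k(logits, k=10):
    """Sample the next token id from the k highest-scoring candidates."""
    # logits: (batch_size, vocab_size) next-token logits from the model.
    top_logits, top_indices = tf.math.top_k(logits, k=k)
    # tf.random.categorical renormalizes the k logits via softmax and
    # draws one candidate per batch element.
    sampled = tf.random.categorical(top_logits, num_samples=1)
    # Map positions within the top-k back to vocabulary ids.
    return tf.gather(top_indices, sampled, batch_dims=1)

next_ids = sample_top_k(tf.random.normal((2, 50257)), k=10)  # shape (2, 1)
```

Restricting sampling to the k most likely tokens trims the long low-probability tail of the distribution, which is one reason it tends to work well for fine-tuned generation.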
85 changes: 85 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -0,0 +1,85 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.

This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of
its functionality. The only change is that `GPT2CausalLMPreprocessor` sets
the `y` (label) and `sample_weight` fields by shifting the input sequence
one step to the left, dropping the last token as it has no successor. For
example, if the tokenized input is `[1, 2, 3, 0, 0]` with
`padding_mask=[1, 1, 1, 0, 0]`, then after preprocessing we will have
`x=[1, 2, 3, 0]` and `y=[2, 3, 0, 0]`, with `padding_mask=[1, 1, 1, 0]`
and `sample_weight=[1, 1, 0, 0]`.

Args:
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance.
sequence_length: The length of the packed inputs.

Examples:
```python
# Load the preprocessor from a preset.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en"
)

# Tokenize and pack a single sentence.
sentence = tf.constant("league of legends")
preprocessor(sentence)
# Same output.
preprocessor("league of legends")

# Tokenize a batch of sentences.
sentences = tf.constant(["taco tuesday", "fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["taco tuesday", "fish taco please!"])

# Map a dataset to preprocess a single sentence.
features = tf.constant(
[
"Avatar 2 is amazing!",
"Well, I am not sure.",
]
)
labels = tf.constant([1, 0])
ds = tf.data.Dataset.from_tensor_slices((features, labels))
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)

# Map a dataset to preprocess unlabeled sentences.
ds = tf.data.Dataset.from_tensor_slices(features)
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE)
```
"""

def call(self, x, y=None, sample_weight=None):
x = super().call(x)
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
# Shift one step left: inputs drop the last token, labels drop the
# first token, and sample weights mask out padding positions.
x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
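
A quick usage sketch of this layer (hedged: it assumes the `gpt2_base_en` preset is available, and that packing happens at `sequence_length` before the shift, as in the `call()` above):

```python
import keras_nlp

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length=8,
)
x, y, sample_weight = preprocessor("a quick test")
# x["token_ids"] and x["padding_mask"] have length 7 (sequence_length - 1),
# y is the token ids shifted one step left, and sample_weight masks padding.
```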