* initial commit
* Add keras_nlp.samplers
* Change padding to left to right
* more samplers
* Add GPT2 text generation stuff
* correct top-p and beam sampler
* initial commit
* Add keras_nlp.samplers
* Change padding to left to right
* Add serialization support, and move some args from constructor to call
* Add string example
* small changes
* Address comments: fix docstring, remove multicase support
* Address comments: move token_probability_fn to the second place
* some initials
* add more sampler class, and a few changes on the base sampler class
* dummy
* add some arg defaults
* small fix
* fix docstring
* some changes
* add classes
* fix serialization
* fix docstring
* address comments
* one more
* fix docstring
* minor fix
1 parent 9c5f850 · commit 5737da9
Showing 8 changed files with 931 additions and 2 deletions.
@@ -0,0 +1,239 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GPT2 Causal LM (Language Model)."""

import copy

import tensorflow as tf
from tensorflow import keras

import keras_nlp
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm_preprocessor import (
    GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(Task):
"""An end-to-end GPT2 model for causal langauge modeling. | ||
A causal language model (LM) predicts the next token based on previous | ||
tokens the next token based on previous tokens, which is the way GPT2 gets | ||
pretrained. You can finetune `GPT2CausalLM` to generate text similar to | ||
the custom dataset. `GPT2CausalLM` also has a method `generate()`, which | ||
generates text based on given prompt. | ||
This model can optionally be configured with a `preprocessor` layer, in | ||
which case it will automatically apply preprocessing to raw inputs during | ||
`fit()`, `predict()`, and `evaluate()`. This is done by default when | ||
creating the model with `from_preset()`. | ||
Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||
warranties or conditions of any kind. The underlying model is provided by a | ||
third party and subject to a separate license, available | ||
[here](https://github.com/openai/gpt-2). | ||
Args: | ||
backbone: A `keras_nlp.models.GPT2Backbone` instance. | ||
preprocessor: A `keras_nlp.models.GPT2CausalLMPreprocessor` or `None`. | ||
If `None`, this model will not apply preprocessing, and inputs | ||
should be preprocessed before calling the model. | ||
Examples: | ||
Use `generate()` method to do text generation. | ||
```python | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") | ||
gpt2_lm.generate("I want to say", max_length=30) | ||
# Generate with batched prompts. | ||
gpt2_lm.generate(["This is a", "Where are you"], max_length=30) | ||
``` | ||
Use a custom sampler for text generation. | ||
```python | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") | ||
# Use string identifier to set sampler. | ||
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p") | ||
# Construct a sampler instance. | ||
sampler = keras_nlp.samplers.BeamSampler(num_beams=2) | ||
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler) | ||
``` | ||
Map raw string to languages model logit predictions. | ||
```python | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en") | ||
gpt2_lm.predict(["You know this is just a test string"]) | ||
``` | ||
Load a pretrained GPT2 and fit on a string dataset. | ||
```python | ||
features = [ | ||
"I don't listen to music while coding.", | ||
"But I watch youtube while coding!", | ||
] | ||
ds = tf.data.Dataset.from_tensor_slices(features) | ||
# Create a `GPT2CausalLM` and fit your data. | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( | ||
"gpt2_base_en", | ||
) | ||
gpt2_lm.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
) | ||
gpt2_lm.fit(ds, batch_size=2) | ||
``` | ||
Load a pretrained `GPT2CausalLM` with custom preprocessor, and predict on | ||
string inputs. | ||
```python | ||
# Use a shorter sequence length. | ||
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( | ||
"gpt2_base_en", | ||
sequence_length=128, | ||
) | ||
# Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor. | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset( | ||
"gpt2_base_en", | ||
preprocessor=preprocessor, | ||
) | ||
gpt2_lm.predict(["You know this is still a test string"]) | ||
``` | ||
Fit your preprocessed data with randomly initialized GPT2. This is useful | ||
when you want to do data preprocessing inside `tf.data` pipeline. | ||
```python | ||
# Define preprocessed input. | ||
features = { | ||
"token_ids": tf.constant( | ||
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6) | ||
), | ||
"padding_mask": tf.constant( | ||
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6) | ||
), | ||
} | ||
labels = tf.constant( | ||
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6) | ||
) | ||
sample_weight = tf.constant( | ||
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6) | ||
) | ||
# Randomly initialize a GPT2 backbone. | ||
backbone = keras_nlp.models.GPT2Backbone( | ||
vocabulary_size=50257, | ||
num_layers=2, | ||
num_heads=2, | ||
hidden_dim=128, | ||
intermediate_dim=256, | ||
max_sequence_length=128, | ||
) | ||
# Create a `GPT2CausalLM` without preprocessor and fit the data. | ||
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None) | ||
gpt2_lm.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
) | ||
gpt2_lm.fit( | ||
x=features, | ||
y=labels, | ||
sample_weight=sample_weight, | ||
batch_size=2, | ||
) | ||
``` | ||
""" | ||

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Use token embedding weights to project from the token representation
        # to vocabulary logits.
        outputs = tf.matmul(
            x,
            backbone.token_embedding.embeddings,
            transpose_b=True,
        )

        # Instantiate using Functional API Model constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self._backbone = backbone
        self._preprocessor = preprocessor

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classproperty
    def backbone_cls(cls):
        return GPT2Backbone

    @classproperty
    def preprocessor_cls(cls):
        return GPT2CausalLMPreprocessor

    def _get_token_probability(self, prompt, mask):
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(
        self,
        prompt,
        max_length,
        sampler="top_k",
    ):
        """Generate text.

        This method generates text based on the given `prompt`. Generation
        will continue until `max_length` is met, and all tokens generated
        after the end token will be truncated. The sampling approach can be
        controlled via the `sampler` argument.

        Args:
            prompt: a string, string Tensor or string RaggedTensor. The
                prompt text for generation.
            max_length: int. The max length of the generated sequence.
            sampler: a string or `keras_nlp.samplers.Sampler` instance. The
                sampler to be used for text generation.
        """
        end_token_id = self.preprocessor.tokenizer.end_token_id

        sampler = keras_nlp.samplers.get(sampler)
        if hasattr(self, "jit_compile"):
            # `jit_compile` is a public property as of tf 2.12. `hasattr` is
            # for backward compat.
            sampler.jit_compile = self.jit_compile
        sampler.run_eagerly = self.run_eagerly
        generated = sampler(
            self.preprocessor.tokenizer(prompt),
            self._get_token_probability,
            max_length=max_length,
            end_token_id=end_token_id,
        )
        return self.preprocessor.tokenizer.detokenize(generated)
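For context on the contract between `generate()` and the new `keras_nlp.samplers` classes above: the model exposes `self._get_token_probability` as a callback that maps `(token_ids, padding_mask)` to per-position vocabulary logits, and the sampler repeatedly invokes that callback to pick the next token until `max_length` is reached or the end token appears. The snippet below is a minimal, standalone sketch of that loop using a greedy rule and a random stand-in callback; the toy `token_probability_fn`, `VOCAB_SIZE`, and `END_TOKEN_ID` are illustrative assumptions, not the library's sampler implementation.

```python
import tensorflow as tf

VOCAB_SIZE = 8      # toy vocabulary, for illustration only
END_TOKEN_ID = 0    # stand-in for `tokenizer.end_token_id`


def token_probability_fn(token_ids, padding_mask):
    # Stand-in for `GPT2CausalLM._get_token_probability`: returns logits of
    # shape (batch, sequence_length, vocab_size). Here they are random.
    batch = tf.shape(token_ids)[0]
    length = tf.shape(token_ids)[1]
    return tf.random.normal((batch, length, VOCAB_SIZE))


def greedy_generate(prompt, max_length, end_token_id=END_TOKEN_ID):
    """Toy greedy decoding loop showing how a sampler uses the callback."""
    token_ids = prompt  # dense int tensor of shape (batch, prompt_length)
    while int(tf.shape(token_ids)[1]) < max_length:
        mask = tf.ones_like(token_ids, dtype=tf.bool)
        logits = token_probability_fn(token_ids, mask)
        # Greedy rule: take the most likely token at the last position.
        next_token = tf.argmax(
            logits[:, -1, :], axis=-1, output_type=token_ids.dtype
        )
        token_ids = tf.concat([token_ids, next_token[:, None]], axis=-1)
        if bool(tf.reduce_all(next_token == end_token_id)):
            break
    return token_ids


print(greedy_generate(tf.constant([[2, 5, 3]]), max_length=8))
```

A real `keras_nlp.samplers.Sampler` (top-k, top-p, or beam search) replaces the greedy `argmax` step and additionally handles padding, truncation after the end token, and optional compilation via the `jit_compile` flag set in `generate()`.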
@@ -0,0 +1,96 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 Causal LM preprocessor layer."""

from absl import logging
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor. | ||
This preprocessor is majorly used as the preprocesor for `GPT2CausalLM`. | ||
This class subclasses `keras_nlp.models.GPT2Preprocessor` and keeps most of | ||
its functionality. The only change is `GPT2CausalLMPreprocessor` sets | ||
`y` (label) and `sample_weights` field by shifting the input sequence one | ||
step towards left, and drop the last token as it does not have a successor, | ||
e.g., if the tokenized input is `[1, 2, 3, 0, 0]` with | ||
`padding_mask = [1, 1, 1, 0, 0]`, then after preprocessing, we | ||
will have `x = [1, 2, 3, 0]` and `y = [2, 3, 0, 0]`, with | ||
`padding_mask = [1, 1, 1, 0]` and `sample_weights = [1, 1, 0, 0]`. | ||
Args: | ||
tokenizer: A `keras_nlp.models.GPT2Tokenizer` instance. | ||
sequence_length: The length of the packed inputs. | ||
Examples: | ||
```python | ||
# Load the preprocessor from a preset. | ||
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset( | ||
"gpt2_base_en" | ||
) | ||
# Tokenize and pack a single sentence. | ||
sentence = tf.constant("League of legends") | ||
preprocessor(sentence) | ||
# Same output. | ||
preprocessor("League of legends") | ||
# Tokenize a batch of sentences. | ||
sentences = tf.constant(["Taco tuesday", "Fish taco please!"]) | ||
preprocessor(sentences) | ||
# Same output. | ||
preprocessor(["Taco tuesday", "Fish taco please!"]) | ||
# Map a dataset to preprocess a single sentence. | ||
features = tf.constant( | ||
[ | ||
"Avatar 2 is amazing!", | ||
"Well, I am not sure.", | ||
] | ||
) | ||
labels = tf.constant([1, 0]) | ||
ds = tf.data.Dataset.from_tensor_slices((features, labels)) | ||
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
# Map a dataset to preprocess unlabled sentences. | ||
ds = tf.data.Dataset.from_tensor_slices(features) | ||
ds = ds.map(preprocessor, num_parallel_calls=tf.data.AUTOTUNE) | ||
""" | ||

    def call(self, x, y=None, sample_weight=None):
        if y is not None or sample_weight is not None:
            logging.warning(
                "`GPT2CausalLMPreprocessor` generates `y` and `sample_weight` "
                "based on your input data, but your data already contains `y` "
                "or `sample_weight`. Your `y` and `sample_weight` will be "
                "ignored."
            )

        x = super().call(x)
        token_ids, padding_mask = x["token_ids"], x["padding_mask"]
        # The last token does not have a next token, so we truncate it out.
        x = {
            "token_ids": token_ids[..., :-1],
            "padding_mask": padding_mask[..., :-1],
        }
        # Target `y` will be the next token.
        y = token_ids[..., 1:]
        sample_weight = padding_mask[..., 1:]
        return pack_x_y_sample_weight(x, y, sample_weight)
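To make the shifting in `call()` concrete, here is a small standalone check that reproduces the docstring example with plain TensorFlow tensors. The tensor values come straight from the docstring; the snippet is only an illustration, not part of the layer.

```python
import tensorflow as tf

# Tokenized input from the docstring example above.
token_ids = tf.constant([[1, 2, 3, 0, 0]])
padding_mask = tf.constant([[1, 1, 1, 0, 0]])

# Features drop the last position, which has no successor to predict.
x = {
    "token_ids": token_ids[..., :-1],        # [[1, 2, 3, 0]]
    "padding_mask": padding_mask[..., :-1],  # [[1, 1, 1, 0]]
}
# Labels are the inputs shifted one step to the left.
y = token_ids[..., 1:]                 # [[2, 3, 0, 0]]
# Padding positions get zero weight so they do not contribute to the loss.
sample_weight = padding_mask[..., 1:]  # [[1, 1, 0, 0]]

print(x["token_ids"].numpy(), y.numpy(), sample_weight.numpy())
```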