Add GPT2 text generation support
chenmoneygithub committed Jan 3, 2023
1 parent 513121e commit 8cd580c
Showing 5 changed files with 221 additions and 0 deletions.
5 changes: 5 additions & 0 deletions keras_nlp/models/__init__.py
@@ -26,6 +26,11 @@
from keras_nlp.models.distil_bert.distil_bert_tokenizer import (
    DistilBertTokenizer,
)
from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_causal_lm import GPT2CausalLM
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.models.roberta.roberta_backbone import RobertaBackbone
from keras_nlp.models.roberta.roberta_classifier import RobertaClassifier
from keras_nlp.models.roberta.roberta_preprocessor import RobertaPreprocessor
110 changes: 110 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -0,0 +1,110 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT task specific models and heads."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_backbone import GPT2Backbone
from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2CausalLMPreprocessor
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(PipelineModel):
    """An end-to-end GPT2 model for causal language modeling.

    The model wraps a `GPT2Backbone` and maps token ids to next-token
    probabilities by projecting the backbone outputs onto the transposed
    token embedding matrix (weight tying), followed by a softmax.
    """

    def __init__(self, backbone, preprocessor=None, **kwargs):
        inputs = backbone.input
        x = backbone(inputs)
        # Reuse the token embedding matrix as the output projection
        # (weight tying), then normalize to per-position probabilities.
        x = tf.matmul(
            x,
            backbone.get_layer("token_embedding").embeddings,
            transpose_b=True,
        )
        outputs = keras.layers.Softmax()(x)
        # Instantiate using the functional API `Model` constructor.
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )

        self.preprocessor = preprocessor
        self.backbone = backbone

    def preprocess_samples(self, x, y=None, sample_weight=None):
        return self.preprocessor(x, y=y, sample_weight=sample_weight)

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classmethod
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = GPT2CausalLMPreprocessor.from_preset(
                preset
            )

        # Check if `preset` is a backbone-only preset.
        if preset in GPT2Backbone.presets:
            backbone = GPT2Backbone.from_preset(preset, load_weights)
            return cls(backbone, **kwargs)

        # Otherwise `preset` must be one of the class presets. There are
        # currently no task-level presets, so we raise a ValueError.
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"""{", ".join(cls.presets)}. Received: {preset}."""
            )

    def _get_generator(self, identifier):
        maps = {
            "greedy": GreedySampler(),
            "top_k": TopKSampler(k=5, from_logits=False),
            "top_p": TopPSampler(p=0.1, from_logits=False),
            "beam": BeamSampler(num_beams=5),
        }
        return maps[identifier]

    def _get_token_probability(self, prompt, mask):
        model_inputs = {
            "token_ids": prompt,
            "padding_mask": mask,
        }
        return self(model_inputs)

    def generate(self, prompt, max_length, generator="top_k"):
        """Generate text from `prompt`, using `generator` as the sampling algorithm."""
        if isinstance(generator, str):
            generator = self._get_generator(generator)
        prompt = self.preprocessor.tokenizer(prompt)
        generated = generator(self._get_token_probability, prompt, max_length)
        return self.preprocessor.tokenizer.detokenize(generated)
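
Taken together, `from_preset` and `generate` give an end-to-end generation flow: tokenize the prompt, sample token ids with the chosen sampler, and detokenize back to text. A minimal usage sketch (the `gpt2_base_en` preset name is an assumption; the actual names live in `gpt2_presets.py`, which is not shown in this diff):

# Sketch only: "gpt2_base_en" is an assumed preset name from
# keras_nlp/models/gpt2/gpt2_presets.py (not shown in this diff).
from keras_nlp.models import GPT2CausalLM

causal_lm = GPT2CausalLM.from_preset("gpt2_base_en")
# Tokenizes the prompts, samples up to 50 tokens with top-k sampling,
# and returns detokenized strings.
output = causal_lm.generate(
    ["My trip to Yosemite was"], max_length=50, generator="top_k"
)

A `Sampler` instance can also be passed as `generator` instead of a string identifier; strings are simply mapped through `_get_generator`.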
102 changes: 102 additions & 0 deletions keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -0,0 +1,102 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPT2 preprocessor layer."""

import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight
from keras_nlp.utils.python_utils import classproperty


class GPT2Preprocessor(keras.layers.Layer):
    """GPT2 preprocessing layer.

    Tokenizes raw strings and packs them into dense `token_ids` and
    `padding_mask` tensors of shape `(batch_size, sequence_length)`.
    """

    def __init__(self, tokenizer, sequence_length, **kwargs):
        super().__init__(**kwargs)

        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    def call(self, x, y=None, sample_weight=None):
        token_ids = self.tokenizer(x)
        # Mark real (non-padding) positions, then pad both the mask and
        # the token ids out to the fixed `sequence_length`.
        mask = tf.ones_like(token_ids, dtype=tf.bool)
        mask = mask.to_tensor(shape=(None, self.sequence_length))
        token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
        x = {
            "token_ids": token_ids,
            "padding_mask": mask,
        }

        return pack_x_y_sample_weight(x, y, sample_weight)

    @classproperty
    def presets(cls):
        return copy.deepcopy(backbone_presets)

    @classmethod
    def from_preset(
        cls,
        preset,
        sequence_length=None,
        **kwargs,
    ):
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"""{", ".join(cls.presets)}. Received: {preset}."""
            )

        tokenizer = GPT2Tokenizer.from_preset(preset)

        # Use the model's `max_sequence_length` if `sequence_length` is
        # unspecified; otherwise check that `sequence_length` is not too long.
        metadata = cls.presets[preset]
        max_sequence_length = metadata["config"]["max_sequence_length"]
        if sequence_length is not None:
            if sequence_length > max_sequence_length:
                raise ValueError(
                    f"`sequence_length` cannot be longer than `{preset}` "
                    f"preset's `max_sequence_length` of {max_sequence_length}. "
                    f"Received: {sequence_length}."
                )
        else:
            sequence_length = max_sequence_length

        return cls(
            tokenizer=tokenizer,
            sequence_length=sequence_length,
            **kwargs,
        )
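
A usage sketch for the base preprocessor (again, `gpt2_base_en` is an assumed preset name):

# Sketch only: "gpt2_base_en" is an assumed preset name.
preprocessor = GPT2Preprocessor.from_preset("gpt2_base_en", sequence_length=128)
features = preprocessor(["a quick brown fox"])
# features["token_ids"] and features["padding_mask"] both have shape
# (batch_size, 128); shorter inputs are padded out to `sequence_length`.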


class GPT2CausalLMPreprocessor(GPT2Preprocessor):
    """GPT2 preprocessor for causal language modeling.

    Packs tokens like `GPT2Preprocessor`, then shifts them by one
    position so that the label at each position is the next token.
    """

    def call(self, x, y=None, sample_weight=None):
        token_ids = self.tokenizer(x)
        mask = tf.ones_like(token_ids, dtype=tf.bool)
        mask = mask.to_tensor(shape=(None, self.sequence_length))
        token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
        # Inputs drop the last position and labels drop the first, so
        # position `i` is trained to predict token `i + 1`.
        x = {
            "token_ids": token_ids[:, :-1],
            "padding_mask": mask[:, 1:],
        }

        y = token_ids[:, 1:]
        sample_weight = mask[:, 1:]

        return pack_x_y_sample_weight(x, y, sample_weight)
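
To make the causal shift concrete, a runnable toy walk-through (the token ids are made up, not real GPT2 vocabulary):

import tensorflow as tf

# One sample tokenized to [5, 8, 3], then padded to sequence_length=5.
token_ids = tf.constant([[5, 8, 3, 0, 0]])
mask = tf.constant([[True, True, True, False, False]])

inputs = token_ids[:, :-1]  # [[5, 8, 3, 0]], last position dropped
labels = token_ids[:, 1:]   # [[8, 3, 0, 0]], first position dropped
weights = mask[:, 1:]       # [[True, True, False, False]]

Each input position `i` is trained against token `i + 1`, and the shifted mask zeroes the padded targets out of the loss.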
3 changes: 3 additions & 0 deletions keras_nlp/samplers/__init__.py
@@ -12,4 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.samplers.beam_sampler import BeamSampler
from keras_nlp.samplers.greedy_sampler import GreedySampler
from keras_nlp.samplers.top_k_sampler import TopKSampler
from keras_nlp.samplers.top_p_sampler import TopPSampler
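
All four samplers are used through the call signature seen in `GPT2CausalLM.generate` above: `sampler(token_probability_fn, prompt, max_length)`, where `token_probability_fn` maps `(prompt, mask)` to per-position token probabilities. A self-contained sketch with a dummy uniform distribution standing in for a real model (illustrative only):

import tensorflow as tf

from keras_nlp.samplers import TopKSampler

VOCAB_SIZE = 10

def token_probability_fn(prompt, mask):
    # Dummy stand-in for a language model: a uniform distribution over
    # a toy vocabulary at every position.
    batch_size = tf.shape(prompt)[0]
    seq_len = tf.shape(prompt)[1]
    return tf.ones((batch_size, seq_len, VOCAB_SIZE)) / VOCAB_SIZE

sampler = TopKSampler(k=5, from_logits=False)
prompt = tf.constant([[1, 2, 3]])  # toy token ids
generated = sampler(token_probability_fn, prompt, max_length=8)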
1 change: 1 addition & 0 deletions keras_nlp/samplers/greedy_sampler.py
@@ -14,6 +14,7 @@
"""Greedy Sampler."""

import tensorflow as tf
from tensorflow import keras

from keras_nlp.samplers.sampler import Sampler
from keras_nlp.samplers.sampler import base_sampler_keyword_args
