fix serialization
chenmoneygithub committed Jan 20, 2023
1 parent 9945c13 commit 4d9a9b7
Showing 4 changed files with 25 additions and 20 deletions.
23 changes: 16 additions & 7 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -28,8 +28,18 @@
from keras_nlp.utils.python_utils import classproperty


# @keras.utils.register_keras_serializable(package="keras_nlp")
@keras.utils.register_keras_serializable(package="keras_nlp")
class EmbeddingMapping(keras.layers.Layer):
"""A layer multiplying model outputs by the token embedding.
This layer is used to map model outputs to logits over all vocab tokens.
It's used in `GPT2CausalLM` to calculate next token's probability.
Args:
embedding_layer: a `tf.keras.layers.Embedding` instance, the token
embedding layer.
"""

def __init__(self, embedding_layer, name="embedding_mapping", **kwargs):
super().__init__(name=name, **kwargs)
self.embedding_layer = embedding_layer
@@ -198,10 +208,9 @@ class GPT2CausalLM(PipelineModel):
def __init__(self, backbone, preprocessor=None, **kwargs):
inputs = backbone.input
x = backbone(inputs)
# embedding_layer = backbone.get_layer("token_embedding")
# embedding_map_layer = EmbeddingMapping(embedding_layer)
# outputs = embedding_map_layer(x)
outputs = x
embedding_layer = backbone.get_layer("token_embedding")
embedding_map_layer = EmbeddingMapping(embedding_layer)
outputs = embedding_map_layer(x)

# Instantiate using Functional API Model constructor
super().__init__(
@@ -219,12 +228,12 @@ def preprocess_samples(self, x, y=None, sample_weight=None):

@property
def backbone(self):
"""The associated `keras_nlp.models.RobertaBackbone`."""
"""The associated `keras_nlp.models.GPT2Backbone`."""
return self._backbone

@property
def preprocessor(self):
"""A `keras_nlp.models.RobertaMaskedLMPreprocessor` for preprocessing inputs."""
"""A `keras_nlp.models.GPT2CausalLMPreprocessor` for preprocessing."""
return self._preprocessor

@classproperty
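The change above registers `EmbeddingMapping` for Keras serialization and wires it into `GPT2CausalLM`, so backbone outputs are projected onto the vocabulary instead of being returned raw. A minimal sketch of the reverse-embedding idea, assuming logits come from multiplying hidden states with the token embedding matrix (names and dimensions here are illustrative, not the keras_nlp implementation):

```python
import tensorflow as tf
from tensorflow import keras


class ReverseEmbedding(keras.layers.Layer):
    """Sketch of a layer mapping hidden states to logits over the vocab."""

    def __init__(self, embedding_layer, **kwargs):
        super().__init__(**kwargs)
        self.embedding_layer = embedding_layer

    def call(self, inputs):
        # (batch, seq, hidden) x (vocab, hidden)^T -> (batch, seq, vocab)
        return tf.einsum(
            "bsh,vh->bsv", inputs, self.embedding_layer.embeddings
        )


embedding = keras.layers.Embedding(input_dim=50257, output_dim=768)
_ = embedding(tf.constant([[0]]))  # build the layer so `embeddings` exists
hidden_states = tf.random.normal((2, 8, 768))
logits = ReverseEmbedding(embedding)(hidden_states)
print(logits.shape)  # (2, 8, 50257)
```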
2 changes: 2 additions & 0 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -14,11 +14,13 @@

"""GPT2 Causal LM preprocessor layer."""

from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_preprocessor import GPT2Preprocessor
from keras_nlp.utils.keras_utils import pack_x_y_sample_weight


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLMPreprocessor(GPT2Preprocessor):
"""GPT2 Causal LM preprocessor.
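The only change in this file is the `register_keras_serializable` decorator and its import. A minimal sketch, not keras_nlp code, of what registration buys: a registered class can be recovered from a serialized config by its registered name, without passing `custom_objects`:

```python
from tensorflow import keras


# Illustrative class name and package; any custom layer works the same way.
@keras.utils.register_keras_serializable(package="my_package")
class Scale(keras.layers.Layer):
    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        return {**super().get_config(), "factor": self.factor}


config = keras.layers.serialize(Scale(factor=3.0))
layer = keras.layers.deserialize(config)  # resolved via the registry
print(type(layer).__name__, layer.factor)  # Scale 3.0
```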
17 changes: 7 additions & 10 deletions keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -17,6 +17,7 @@
import copy

import tensorflow as tf
from tensorflow import keras

from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.models.gpt2.gpt2_tokenizer import GPT2Tokenizer
@@ -28,13 +29,13 @@
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2Preprocessor(Preprocessor):
"""GPT2 preprocessing layer which tokenizes and packs inputs.
This preprocessing layer will do three things:
This preprocessing layer will do 2 things:
- Tokenize the input using the `tokenizer`.
- Add the id of '<|endoftext|>' to the start and end of the tokenized input.
- Construct a dictionary with keys `"token_ids"`, `"padding_mask"`, that can
be passed directly to a `keras_nlp.models.GPT2Backbone`.
@@ -135,23 +136,19 @@ def call(self, x, y=None, sample_weight=None):
if len(x) > 1:
raise ValueError(
"GPT2 requires each input feature to contain only "
f"one segment, but received: {len(x)}. If you are using GPT2 "
f"one segment, but received {len(x)}. If you are using GPT2 "
"for a multi-segment classification task, please refer to "
"classification models like BERT or RoBERTa."
)
token_ids = self._tokenizer(x[0])
# batch_size = token_ids.nrows()
# start_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
# end_column = tf.fill((batch_size, 1), self._tokenizer.end_token_id)
# token_ids = tf.concat([start_column, token_ids, end_column], axis=1)
input_is_1d = False
if len(token_ids.shape) == 1:
input_is_1d = True
input_is_1d = len(token_ids.shape) == 1
if input_is_1d:
token_ids = tf.RaggedTensor.from_tensor([token_ids])
mask = tf.ones_like(token_ids, dtype=tf.bool)
mask = mask.to_tensor(shape=(None, self.sequence_length))
token_ids = token_ids.to_tensor(shape=(None, self.sequence_length))
if input_is_1d:
# If the input is a single string, we let the output be a 1D tensor.
token_ids = tf.squeeze(token_ids, axis=0)
mask = tf.squeeze(mask, axis=0)
x = {
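The rewritten `call` above drops the commented-out end-token logic, collapses the 1D-input check into a single expression, and pads ragged token ids out to `sequence_length`. A small sketch of that padding pattern, with illustrative values:

```python
import tensorflow as tf

# Ragged token ids are padded (or truncated) to a fixed sequence length,
# and a boolean mask marks which positions hold real tokens.
sequence_length = 6
token_ids = tf.ragged.constant([[1, 2, 3], [4, 5]])

mask = tf.ones_like(token_ids, dtype=tf.bool)
mask = mask.to_tensor(shape=(None, sequence_length))
token_ids = token_ids.to_tensor(shape=(None, sequence_length))

print(token_ids.numpy())
# [[1 2 3 0 0 0]
#  [4 5 0 0 0 0]]
print(mask.numpy())
# [[ True  True  True False False False]
#  [ True  True False False False False]]
```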
3 changes: 0 additions & 3 deletions keras_nlp/models/roberta/roberta_preprocessor_test.py
@@ -98,9 +98,6 @@ def test_tokenize_labeled_dataset(self):
sw = tf.constant([1.0] * 4)
ds = tf.data.Dataset.from_tensor_slices((x, y, sw))
ds = ds.map(self.preprocessor)
import pdb

pdb.set_trace()
x_out, y_out, sw_out = ds.batch(4).take(1).get_single_element()
self.assertAllEqual(
x_out["token_ids"], [[0, 3, 4, 5, 3, 6, 2, 1, 1, 1, 1, 1]] * 4
