Commit cb12604: fix docstring

chenmoneygithub committed Jan 23, 2023
1 parent 4fa8fc5
Showing 4 changed files with 95 additions and 153 deletions.
211 changes: 79 additions & 132 deletions keras_nlp/models/gpt2/gpt2_causal_lm.py
@@ -24,12 +24,12 @@
GPT2CausalLMPreprocessor,
)
from keras_nlp.models.gpt2.gpt2_presets import backbone_presets
from keras_nlp.utils.pipeline_model import PipelineModel
from keras_nlp.models.task import Task
from keras_nlp.utils.python_utils import classproperty


@keras.utils.register_keras_serializable(package="keras_nlp")
class EmbeddingMapping(keras.layers.Layer):
class ReverseEmbedding(keras.layers.Layer):
"""A layer multiplying model outputs by the token embedding.
This layer is used to map model outputs to logits over all vocab tokens.
@@ -69,7 +69,7 @@ def from_config(cls, config):


@keras.utils.register_keras_serializable(package="keras_nlp")
class GPT2CausalLM(PipelineModel):
class GPT2CausalLM(Task):
"""GPT2 Causal LM task model.
Causal LM is predicting the next token based on previous tokens, which is
@@ -95,112 +95,110 @@ class GPT2CausalLM(PipelineModel):
Examples:
Example usage.
Use the `generate()` method to do text generation.
```python
features = {
"token_ids": tf.constant(
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
),
}
labels = tf.constant(
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
)
sample_weight = tf.constant(
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.generate("I want to say", max_length=30)
# Randomly initialize a GPT2 backbone.
backbone = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=2,
num_heads=2,
hidden_dim=128,
intermediate_dim=256,
max_sequence_length=128,
)
# Create a `GPT2CausalLM` and fit the data.
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(
x=features,
y=labels,
sample_weight=sample_weight,
batch_size=2,
)
# Generate with batched prompts.
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
```
Use a custom sampler for text generation.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
# Use string identifier to set sampler.
gpt2_lm.generate("I want to say", max_length=30, sampler="top_p")
# Construct a sampler instance.
sampler = keras_nlp.samplers.BeamSampler(num_beams=2)
gpt2_lm.generate("I want to say", max_length=30, sampler=sampler)
```
Load a pretrained `GPT2CausalLM` and get outputs on raw string inputs.
```python
str_inputs = "You know this is just a test string"
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.predict([str_inputs])
```
Raw string inputs.
Load a pretrained GPT2 and fit on a string dataset.
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]
ds = tf.data.Dataset.from_tensor_slices(features)
# Create a `GPT2CausalLM` and fit your data.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(x=features, batch_size=2)
```
Raw string inputs with customized preprocessing.
Load a pretrained `GPT2CausalLM` with a custom preprocessor, and predict on
string inputs.
```python
# Create a dataset with raw string features in an `(x, y)` format.
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]
str_inputs = "You know this is still a test string"
# Use a shorter sequence length.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
"gpt2_base_en",
sequence_length=128,
)
# Create a `GPT2CausalLM` and fit your data.
# Create a `GPT2CausalLM`, using pretrained GPT2 and custom preprocessor.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(x=features, batch_size=2)
gpt2_lm.predict([str_inputs])
```
# Use tf dataset.
Fit a randomly initialized GPT2 on preprocessed data. This is useful
when you want to do data preprocessing inside a `tf.data` pipeline.
```python
features = [
"I don't listen to music while coding.",
"But I watch youtube while coding!",
]
ds = tf.data.Dataset.from_tensor_slices(features)
# Define preprocessed input.
features = {
"token_ids": tf.constant(
[[1, 2, 3, 4, 0, 0]] * 2, shape=(2, 6)
),
"padding_mask": tf.constant(
[[1, 1, 1, 1, 0, 0]] * 2, shape=(2, 6)
),
}
labels = tf.constant(
[[2, 3, 4, 0, 0, 0]] * 2, shape=(2, 6)
)
sample_weight = tf.constant(
[[1, 1, 1, 0, 0, 0]] * 2, shape=(2, 6)
)
# Create a `GPT2CausalLM` and fit your data.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
"gpt2_base_en",
preprocessor=preprocessor,
# Randomly initialize a GPT2 backbone.
backbone = keras_nlp.models.GPT2Backbone(
vocabulary_size=50257,
num_layers=2,
num_heads=2,
hidden_dim=128,
intermediate_dim=256,
max_sequence_length=128,
)
# Create a `GPT2CausalLM` without a preprocessor and fit the data.
gpt2_lm = keras_nlp.models.GPT2CausalLM(backbone, preprocessor=None)
gpt2_lm.compile(
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
gpt2_lm.fit(x=features, batch_size=2)
```
# Use the `generate()` method to generate text.
```python
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset("gpt2_base_en")
gpt2_lm.generate("I want to say", max_length=30)
# Generate with batched prompts.
gpt2_lm.generate(["This is a", "Where are you"], max_length=30)
gpt2_lm.fit(
x=features,
y=labels,
sample_weight=sample_weight,
batch_size=2,
)
```
"""
@@ -209,7 +207,7 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
inputs = backbone.input
x = backbone(inputs)
embedding_layer = backbone.get_layer("token_embedding")
embedding_map_layer = EmbeddingMapping(embedding_layer)
embedding_map_layer = ReverseEmbedding(embedding_layer)
outputs = embedding_map_layer(x)

# Instantiate using Functional API Model constructor
@@ -223,47 +221,17 @@ def __init__(self, backbone, preprocessor=None, **kwargs):
self._backbone = backbone
self._preprocessor = preprocessor

def preprocess_samples(self, x, y=None, sample_weight=None):
return self.preprocessor(x, y=y, sample_weight=sample_weight)

@property
def backbone(self):
"""The associated `keras_nlp.models.GPT2Backbone`."""
return self._backbone

@property
def preprocessor(self):
"""A `keras_nlp.models.GPT2CausalLMPreprocessor` for preprocessing."""
return self._preprocessor

@classproperty
def presets(cls):
return copy.deepcopy(backbone_presets)

@classmethod
def from_preset(
cls,
preset,
load_weights=True,
**kwargs,
):
if "preprocessor" not in kwargs:
kwargs["preprocessor"] = GPT2CausalLMPreprocessor.from_preset(
preset
)

# Check if preset is backbone-only model.
if preset in GPT2Backbone.presets:
backbone = GPT2Backbone.from_preset(preset, load_weights)
return cls(backbone, **kwargs)

# Otherwise must be one of class presets.
# Currently no classifier-level presets, so we raise ValueError.
if preset not in cls.presets:
raise ValueError(
"`preset` must be one of "
f"""{", ".join(cls.presets)}. Received: {preset}."""
)
@classproperty
def backbone_cls(cls):
return GPT2Backbone

@classproperty
def preprocessor_cls(cls):
return GPT2CausalLMPreprocessor

def _get_token_probability(self, prompt, mask):
model_inputs = {
@@ -301,8 +269,7 @@ def generate(
sampler = keras_nlp.samplers.get(sampler)
if hasattr(self, "jit_compile"):
sampler.jit_compile = self.jit_compile
if hasattr(self, "run_eagerly"):
sampler.run_eagerly = self.run_eagerly
sampler.run_eagerly = self.run_eagerly
prompt = self.preprocessor.tokenizer(prompt)
generated = sampler(
prompt,
@@ -311,23 +278,3 @@
end_token_id=end_token_id,
)
return self.preprocessor.tokenizer.detokenize(generated)

def get_config(self):
return {
"backbone": keras.layers.serialize(self.backbone),
"preprocessor": keras.layers.serialize(self.preprocessor),
"name": self.name,
"trainable": self.trainable,
}

@classmethod
def from_config(cls, config):
if "backbone" in config and isinstance(config["backbone"], dict):
config["backbone"] = keras.layers.deserialize(config["backbone"])
if "preprocessor" in config and isinstance(
config["preprocessor"], dict
):
config["preprocessor"] = keras.layers.deserialize(
config["preprocessor"]
)
return cls(**config)
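The net effect of the changes in this file: `GPT2CausalLM` no longer implements `from_preset`, `get_config`, and `from_config` itself. Those now live on the shared `Task` base class, and the subclass only declares which backbone and preprocessor it pairs with via the `backbone_cls` and `preprocessor_cls` classproperties. The `Task` internals are not part of this diff, so the following is only a minimal sketch of how such a base class could dispatch on those hooks, using plain classmethods in place of `classproperty` for simplicity; every detail beyond the two hooks is an assumption for illustration.

```python
# Hypothetical sketch of the `Task` base-class pattern (NOT the actual
# keras_nlp.models.task.Task source, which this diff does not show).


class Task:
    @classmethod
    def backbone_cls(cls):
        # Subclasses point this at their backbone class, e.g. GPT2Backbone.
        raise NotImplementedError

    @classmethod
    def preprocessor_cls(cls):
        # Subclasses point this at their preprocessor class.
        raise NotImplementedError

    @classmethod
    def from_preset(cls, preset, load_weights=True, **kwargs):
        # Build a matching preprocessor unless the caller supplied one.
        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = cls.preprocessor_cls().from_preset(preset)
        # Delegate backbone construction to the declared backbone class.
        backbone = cls.backbone_cls().from_preset(preset, load_weights)
        return cls(backbone, **kwargs)
```

With a shared base like this, every task model gets a consistent `from_preset` for free, which is why the per-model boilerplate above could be deleted.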
26 changes: 8 additions & 18 deletions keras_nlp/models/gpt2/gpt2_causal_lm_preprocessor.py
@@ -51,10 +51,10 @@ class GPT2CausalLMPreprocessor(GPT2Preprocessor):
preprocessor("league of legends")
# Tokenize a batch of sentences.
sentences = tf.constant(["taco tuesday", "gi gi gi gi"])
sentences = tf.constant(["taco tuesday", "fish taco please!"])
preprocessor(sentences)
# Same output.
preprocessor(["taco tuesday", "gi gi gi gi"])
preprocessor(["taco tuesday", "fish taco please!"])
# Map a dataset to preprocess a single sentence.
features = tf.constant(
@@ -76,20 +76,10 @@ def call(self, x, y=None, sample_weight=None):

x = super().call(x)
token_ids, padding_mask = x["token_ids"], x["padding_mask"]
if len(token_ids.shape) == 1:
x = {
"token_ids": token_ids[:-1],
"padding_mask": padding_mask[:-1],
}
y = token_ids[1:]
sample_weight = padding_mask[1:]
else:
x = {
"token_ids": token_ids[:, :-1],
"padding_mask": padding_mask[:, :-1],
}

y = token_ids[:, 1:]
sample_weight = padding_mask[:, 1:]

x = {
"token_ids": token_ids[..., :-1],
"padding_mask": padding_mask[..., :-1],
}
y = token_ids[..., 1:]
sample_weight = padding_mask[..., 1:]
return pack_x_y_sample_weight(x, y, sample_weight)
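The rewritten `call` above collapses the old rank-1/rank-2 branches into a single path: `...` (Ellipsis) indexing always slices the last axis, whatever the tensor's rank. A quick standalone check of that equivalence (not part of the diff):

```python
import tensorflow as tf

unbatched = tf.constant([1, 2, 3, 4, 0, 0])      # shape (6,)
batched = tf.constant([[1, 2, 3, 4, 0, 0]] * 2)  # shape (2, 6)

# For rank-1 input, `[..., :-1]` behaves like the old `[:-1]` branch.
assert unbatched[..., :-1].shape == (5,)
# For rank-2 input, it behaves like the old `[:, :-1]` branch.
assert batched[..., :-1].shape == (2, 5)
```

The two shifted views make the training targets: `x` keeps tokens `0..n-2` while `y` holds tokens `1..n-1`, so each position is trained to predict its successor.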
7 changes: 6 additions & 1 deletion keras_nlp/models/gpt2/gpt2_preprocessor.py
@@ -115,7 +115,12 @@ class GPT2Preprocessor(Preprocessor):
```
"""

def __init__(self, tokenizer, sequence_length, **kwargs):
def __init__(
self,
tokenizer,
sequence_length,
**kwargs,
):

super().__init__(**kwargs)

4 changes: 2 additions & 2 deletions keras_nlp/samplers/sampler.py
@@ -87,7 +87,7 @@ def token_probability_fn(inputs, mask):
prompt = tf.fill((8, 1), 1)
sampler = keras_nlp.samplers.Greedy()
sampler = keras_nlp.samplers.GreedySampler()
# Print the generated sequence (token ids).
print(sampler(prompt, token_probability_fn, max_length=10, end_token_id=2))
```
@@ -118,7 +118,7 @@ def token_probability_fn(inputs, mask):
return model(inputs)
prompt = tokenizer("the quick brown fox")
sampler = keras_nlp.samplers.Greedy()
sampler = keras_nlp.samplers.GreedySampler()
generated = sampler(
prompt,
token_probability_fn,
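Both hunks in this file only rename `Greedy()` to `GreedySampler()` in the docstring examples; the calling contract is unchanged: the sampler receives a prompt plus a `token_probability_fn` that maps `(inputs, mask)` to per-token probabilities. A minimal runnable stand-in for that contract, using a toy uniform distribution in place of a real model (the sampler call signature is taken from the docstring above; the rest is illustrative):

```python
import tensorflow as tf
import keras_nlp

VOCAB_SIZE = 10


def token_probability_fn(inputs, mask):
    # Toy stand-in for `model(inputs)`: uniform probabilities over the vocab.
    batch = tf.shape(inputs)[0]
    length = tf.shape(inputs)[1]
    return tf.ones((batch, length, VOCAB_SIZE)) / VOCAB_SIZE


prompt = tf.fill((8, 1), 1)
sampler = keras_nlp.samplers.GreedySampler()
print(sampler(prompt, token_probability_fn, max_length=10, end_token_id=2))
```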
