-
Notifications
You must be signed in to change notification settings - Fork 251
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add f_net_classifier and f_net_classifier_test #670
Merged
Merged
Changes from 38 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
ba4dd4e
add f_net_classifier and f_net_classifier_test
ADITYADAS1999 ba51c2a
Update f_net_classifier.py
ADITYADAS1999 084ce59
Update f_net_classifier.py
ADITYADAS1999 b7cc09d
remove spaces
ADITYADAS1999 fe0f6f4
Update f_net_classifier_test.py
ADITYADAS1999 992160d
Update f_net_classifier.py
ADITYADAS1999 7cd46df
Revert "Update f_net_classifier.py"
ADITYADAS1999 ff674c5
Revert "Update f_net_classifier_test.py"
ADITYADAS1999 0a746e8
remove unnecessary imports
ADITYADAS1999 7bcde56
add necessary imports
ADITYADAS1999 92699a0
Merge branch 'keras-team:master' into new_branch
ADITYADAS1999 cba52ff
Update f_net_classifier.py
ADITYADAS1999 c634ae6
Delete f_net_classifier.py
ADITYADAS1999 583fc00
Delete f_net_classifier_test.py
ADITYADAS1999 06cf94f
Add files via upload
ADITYADAS1999 39f2573
Update f_net_classifier_test.py
ADITYADAS1999 4dd5664
Update f_net_classifier_test.py
ADITYADAS1999 ee4c8e9
Update f_net_classifier.py
ADITYADAS1999 3a8fa0a
Update f_net_classifier.py
ADITYADAS1999 614f8df
Update f_net_classifier.py
ADITYADAS1999 8a39cd3
Delete f_net_classifier.py
ADITYADAS1999 a7c5659
Add files via upload
ADITYADAS1999 bf44198
Update f_net_classifier.py
ADITYADAS1999 5d7b45d
Update f_net_classifier.py
ADITYADAS1999 6e1e279
Update f_net_classifier.py
ADITYADAS1999 850192d
Update f_net_classifier.py
ADITYADAS1999 d33ca41
Update f_net_classifier.py
ADITYADAS1999 4844c62
Delete f_net_classifier_test.py
ADITYADAS1999 ec847f3
Delete f_net_classifier.py
ADITYADAS1999 7b6225b
Add files via upload
ADITYADAS1999 8d5165d
suggested changes
ADITYADAS1999 f159ad2
suggested changes
ADITYADAS1999 9cab7d2
Update f_net_presets_test.py
ADITYADAS1999 95db3eb
Update f_net_presets_test.py
ADITYADAS1999 4953fe7
Delete f_net_presets_test.py
ADITYADAS1999 8a6459d
Add files via upload
ADITYADAS1999 5ccca25
Update f_net_presets_test.py
ADITYADAS1999 1986080
Revert "Update f_net_presets_test.py"
ADITYADAS1999 9380fdd
Fix presets tests
mattdangerw File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""FNet classification model.""" | ||
|
||
import copy | ||
|
||
from tensorflow import keras | ||
|
||
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone | ||
from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer | ||
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor | ||
from keras_nlp.models.f_net.f_net_presets import backbone_presets | ||
from keras_nlp.models.task import Task | ||
from keras_nlp.utils.python_utils import classproperty | ||
|
||
|
||
@keras.utils.register_keras_serializable(package="keras_nlp") | ||
class FNetClassifier(Task): | ||
"""An end-to-end f_net model for classification tasks. | ||
|
||
This model attaches a classification head to a | ||
`keras_nlp.model.FNetBackbone` model, mapping from the backbone | ||
outputs to logit output suitable for a classification task. For usage of | ||
this model with pre-trained weights, see the `from_preset()` method. | ||
|
||
This model can optionally be configured with a `preprocessor` layer, in | ||
which case it will automatically apply preprocessing to raw inputs during | ||
`fit()`, `predict()`, and `evaluate()`. This is done by default when | ||
creating the model with `from_preset()`. | ||
|
||
Disclaimer: Pre-trained models are provided on an "as is" basis, without | ||
warranties or conditions of any kind. | ||
|
||
Args: | ||
backbone: A `keras_nlp.models.FNetBackbone` instance. | ||
num_classes: int. Number of classes to predict. | ||
hidden_dim: int. The size of the pooler layer. | ||
dropout: float. The dropout probability value, applied after the dense | ||
layer. | ||
preprocessor: A `keras_nlp.models.FNetPreprocessor` or `None`. If | ||
`None`, this model will not apply preprocessing, and inputs should | ||
be preprocessed before calling the model. | ||
|
||
Example usage: | ||
```python | ||
preprocessed_features = { | ||
"token_ids": tf.ones(shape=(2, 12), dtype=tf.int64), | ||
"segment_ids": tf.constant( | ||
[[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) | ||
), | ||
"padding_mask": tf.constant( | ||
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12) | ||
), | ||
} | ||
labels = [0, 3] | ||
|
||
# Randomly initialize a Fnet backbone | ||
backbone = keras_nlp.models.FNetBackbone( | ||
vocabulary_size=32000, | ||
num_layers=12, | ||
num_heads=12, | ||
hidden_dim=768, | ||
intermediate_dim=3072, | ||
max_sequence_length=12, | ||
) | ||
|
||
# Create a Fnet classifier and fit your data. | ||
classifier = keras_nlp.models.FnetClassifier( | ||
backbone, | ||
num_classes=4, | ||
preprocessor=None, | ||
) | ||
classifier.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
) | ||
classifier.fit(x=preprocessed_features, y=labels, batch_size=2) | ||
|
||
# Access backbone programatically (e.g., to change `trainable`) | ||
classifier.backbone.trainable = False | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
backbone, | ||
num_classes=2, | ||
dropout=0.1, | ||
preprocessor=None, | ||
**kwargs, | ||
): | ||
inputs = backbone.input | ||
pooled = backbone(inputs)["pooled_output"] | ||
pooled = keras.layers.Dropout(dropout)(pooled) | ||
outputs = keras.layers.Dense( | ||
num_classes, | ||
kernel_initializer=f_net_kernel_initializer(), | ||
name="logits", | ||
)(pooled) | ||
# Instantiate using Functional API Model constructor | ||
super().__init__( | ||
inputs=inputs, | ||
outputs=outputs, | ||
include_preprocessing=preprocessor is not None, | ||
**kwargs, | ||
) | ||
# All references to `self` below this line | ||
self._backbone = backbone | ||
self._preprocessor = preprocessor | ||
self.num_classes = num_classes | ||
self.dropout = dropout | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_classes": self.num_classes, | ||
"dropout": self.dropout, | ||
} | ||
) | ||
return config | ||
|
||
@classproperty | ||
def backbone_cls(cls): | ||
return FNetBackbone | ||
|
||
@classproperty | ||
def preprocessor_cls(cls): | ||
return FNetPreprocessor | ||
|
||
@classproperty | ||
def presets(cls): | ||
return copy.deepcopy(backbone_presets) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# Copyright 2023 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Tests for FNet classification model.""" | ||
|
||
import io | ||
import os | ||
|
||
import sentencepiece | ||
import tensorflow as tf | ||
from absl.testing import parameterized | ||
from tensorflow import keras | ||
|
||
from keras_nlp.models.f_net.f_net_backbone import FNetBackbone | ||
from keras_nlp.models.f_net.f_net_classifier import FNetClassifier | ||
from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor | ||
from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer | ||
|
||
|
||
class FNetClassifierTest(tf.test.TestCase, parameterized.TestCase): | ||
def setUp(self): | ||
self.backbone = FNetBackbone( | ||
vocabulary_size=1000, | ||
num_layers=2, | ||
hidden_dim=64, | ||
intermediate_dim=128, | ||
max_sequence_length=128, | ||
name="encoder", | ||
) | ||
|
||
bytes_io = io.BytesIO() | ||
vocab_data = tf.data.Dataset.from_tensor_slices( | ||
["the quick brown fox", "the earth is round"] | ||
) | ||
|
||
sentencepiece.SentencePieceTrainer.train( | ||
sentence_iterator=vocab_data.as_numpy_iterator(), | ||
model_writer=bytes_io, | ||
vocab_size=10, | ||
model_type="WORD", | ||
pad_id=3, | ||
unk_id=0, | ||
bos_id=4, | ||
eos_id=5, | ||
pad_piece="<pad>", | ||
unk_piece="<unk>", | ||
bos_piece="[CLS]", | ||
eos_piece="[SEP]", | ||
) | ||
|
||
self.proto = bytes_io.getvalue() | ||
|
||
self.preprocessor = FNetPreprocessor( | ||
tokenizer=FNetTokenizer(proto=self.proto), | ||
sequence_length=12, | ||
) | ||
|
||
self.classifier = FNetClassifier( | ||
self.backbone, | ||
4, | ||
preprocessor=self.preprocessor, | ||
) | ||
self.classifier_no_preprocessing = FNetClassifier( | ||
self.backbone, | ||
4, | ||
preprocessor=None, | ||
) | ||
|
||
self.raw_batch = tf.constant( | ||
[ | ||
"the quick brown fox.", | ||
"the slow brown fox.", | ||
"the smelly brown fox.", | ||
"the old brown fox.", | ||
] | ||
) | ||
self.preprocessed_batch = self.preprocessor(self.raw_batch) | ||
self.raw_dataset = tf.data.Dataset.from_tensor_slices( | ||
(self.raw_batch, tf.ones((4,))) | ||
).batch(2) | ||
self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) | ||
|
||
def test_valid_call_classifier(self): | ||
self.classifier(self.preprocessed_batch) | ||
|
||
@parameterized.named_parameters( | ||
("jit_compile_false", False), ("jit_compile_true", True) | ||
) | ||
def test_fnet_classifier_predict(self, jit_compile): | ||
self.classifier.compile(jit_compile=jit_compile) | ||
self.classifier.predict(self.raw_batch) | ||
|
||
@parameterized.named_parameters( | ||
("jit_compile_false", False), ("jit_compile_true", True) | ||
) | ||
def test_fnet_classifier_predict_no_preprocessing(self, jit_compile): | ||
self.classifier_no_preprocessing.compile(jit_compile=jit_compile) | ||
self.classifier_no_preprocessing.predict(self.preprocessed_batch) | ||
|
||
@parameterized.named_parameters( | ||
("jit_compile_false", False), ("jit_compile_true", True) | ||
) | ||
def test_fnet_classifier_fit(self, jit_compile): | ||
self.classifier.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
jit_compile=jit_compile, | ||
) | ||
self.classifier.fit(self.raw_dataset) | ||
|
||
@parameterized.named_parameters( | ||
("jit_compile_false", False), ("jit_compile_true", True) | ||
) | ||
def test_fnet_classifier_fit_no_preprocessing(self, jit_compile): | ||
self.classifier_no_preprocessing.compile( | ||
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
jit_compile=jit_compile, | ||
) | ||
self.classifier_no_preprocessing.fit(self.preprocessed_dataset) | ||
|
||
@parameterized.named_parameters( | ||
("tf_format", "tf", "model"), | ||
("keras_format", "keras_v3", "model.keras"), | ||
) | ||
def test_saved_model(self, save_format, filename): | ||
model_output = self.classifier.predict(self.raw_batch) | ||
save_path = os.path.join(self.get_temp_dir(), filename) | ||
self.classifier.save(save_path, save_format=save_format) | ||
restored_model = keras.models.load_model(save_path) | ||
|
||
# Check we got the real object back. | ||
self.assertIsInstance(restored_model, FNetClassifier) | ||
|
||
# Check that output matches. | ||
restored_output = restored_model.predict(self.raw_batch) | ||
self.assertAllClose(model_output, restored_output) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You will also need to add some preset tests in
f_net_presets_tests.py
, you can take a look at #668, which is doing a similar thing for AlBERT.