Fix compatibility for Keras 3.1.0
james77777778 committed Jul 10, 2024
1 parent be524fc commit 9a45a68
Showing 6 changed files with 109 additions and 17 deletions.
45 changes: 45 additions & 0 deletions .github/workflows/actions.yml
@@ -54,6 +54,51 @@ jobs:
run: |
pip uninstall -y tensorflow-text tensorflow
cd integration_tests && pytest . -k "NoTensorflow"
run_tests_with_keras_3_1_0:
name: Test the code with Keras 3.1.0
strategy:
fail-fast: false
matrix:
backend: [tensorflow, jax, torch]
runs-on: ubuntu-latest
env:
KERAS_BACKEND: ${{ matrix.backend }}
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
uses: actions/setup-python@v5
with:
python-version: 3.9
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies with Keras 3.1.0
run: |
pip install -r requirements.txt --progress-bar off
pip install --no-deps -e "." --progress-bar off
pip uninstall -y keras
pip install keras==3.1.0 --progress-bar off
- name: Test with pytest
run: |
pytest keras_nlp/
- name: Run integration tests
run: |
python pip_build.py --install
cd integration_tests && pytest . -k "not NoTensorflow"
- name: Run no tensorflow integration test
if: ${{ matrix.backend != 'tensorflow'}}
run: |
pip uninstall -y tensorflow-text tensorflow
cd integration_tests && pytest . -k "NoTensorflow"
check_format:
name: Check the code format
runs-on: ubuntu-latest
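
The job above force-downgrades to keras==3.1.0 after installing the package, so every backend runs against the oldest Keras release this commit keeps supporting. A minimal sketch (not part of the commit) of the runtime version probe that the compatibility guards in the rest of this diff are built on; the helper name is hypothetical:

import keras
from packaging.version import parse


def keras_at_least(min_version: str) -> bool:
    # Compare the running Keras release against a pinned minimum.
    return parse(keras.version()) >= parse(min_version)


# Under the job above (keras==3.1.0) both checks are expected to be False:
# 3.2.0 gates the reverse-embedding save/load path, 3.4.0 gates quantization.
print(keras_at_least("3.2.0"))
print(keras_at_least("3.4.0"))
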
23 changes: 18 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -14,6 +14,7 @@

import keras
from keras import ops
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export

@@ -107,7 +108,10 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)
if not self.tie_weights and self.quantization_mode != "int8":
if (
not self.tie_weights
and getattr(self, "quantization_mode", None) != "int8"
):
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
@@ -142,11 +146,15 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved in the super() call.
# After Keras 3.2, the reverse weight must be saved manually.
if parse(keras.version()) < parse("3.2.0"):
return
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable
@@ -158,7 +166,7 @@ def load_own_variables(self, store):
if not self.tie_weights:
# Last weights in the stores are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode == "int8":
if getattr(self, "quantization_mode", None) == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
@@ -226,10 +234,15 @@ def _int8_call(self, inputs, reverse=False):

return super()._int8_call(inputs)

def quantize(self, mode):
def quantize(self, mode, type_check=True):
import gc

if type(self) is not ReversibleEmbedding:
if parse(keras.version()) < parse("3.4.0"):
raise ValueError(
"`quantize` in KerasNLP requires Keras >= 3.4.0 to function "
f"correctly. Received: '{keras.version()}'"
)
if type_check and type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
"method implemented."
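
A hedged usage sketch of the quantization guard added above, assuming a KerasNLP build that contains this change; the constructor arguments mirror the test file that follows:

import keras
from packaging.version import parse

from keras_nlp.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)

layer = ReversibleEmbedding(input_dim=100, output_dim=32, tie_weights=False)
layer.build()

if parse(keras.version()) >= parse("3.4.0"):
    # Supported path: quantizes the embeddings (and the untied reverse weights).
    layer.quantize("int8")
else:
    # On older Keras the new guard raises up front instead of failing later.
    try:
        layer.quantize("int8")
    except ValueError as error:
        print(error)
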
7 changes: 7 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
@@ -19,6 +19,7 @@
from absl.testing import parameterized
from keras import ops
from keras import random
from packaging.version import parse

from keras_nlp.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
@@ -103,6 +104,9 @@ def test_reverse_dtype(self):
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
if parse(keras.version()) < parse("3.4.0"):
self.skipTest("This test needs keras>=3.4.0.")

layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
@@ -151,6 +155,9 @@ def test_quantize_int8(self, tie_weights):
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
if parse(keras.version()) < parse("3.4.0"):
self.skipTest("This test needs keras>=3.4.0.")

self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
39 changes: 29 additions & 10 deletions keras_nlp/src/models/backbone.py
@@ -15,6 +15,7 @@
import os

import keras
from packaging.version import parse

from keras_nlp.src.api_export import keras_nlp_export
from keras_nlp.src.utils.preset_utils import CONFIG_FILE
@@ -75,7 +76,14 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
@@ -100,6 +108,14 @@ def token_embedding(self):
def token_embedding(self, value):
self._token_embedding = value

def quantize(self, mode, **kwargs):
if parse(keras.version()) < parse("3.4.0"):
raise ValueError(
"`quantize` in KerasNLP requires Keras >= 3.4.0 to function "
f"correctly. Received: keras.version()={keras.version()}"
)
return super().quantize(mode, **kwargs)

def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
@@ -109,15 +125,18 @@ def get_config(self):
}

# Add quantization support by utilizing `DTypePolicyMap`
if isinstance(self.dtype_policy, keras.dtype_policies.DTypePolicyMap):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
if hasattr(keras.dtype_policies, "DTypePolicyMap"):
if isinstance(
self.dtype_policy, keras.dtype_policies.DTypePolicyMap
):
config.update({"dtype": self.dtype_policy})
else:
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
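
A standalone sketch of the dtype-policy fallback introduced above (the same branch is duplicated in pali_gemma_vit.py below), runnable on any Keras 3.x release; the function name is hypothetical and only mirrors the logic added to Backbone.__init__:

import keras


def resolve_dtype_policy(dtype=None):
    # Keras >= 3.2 ships keras.dtype_policies.get(), which already accepts
    # None, policy names, and DTypePolicy instances.
    if hasattr(keras.dtype_policies, "get"):
        return keras.dtype_policies.get(dtype)
    # Older Keras: normalize to a policy name and build the policy directly.
    if isinstance(dtype, keras.dtype_policies.DTypePolicy):
        dtype = dtype.name
    dtype = dtype or keras.config.dtype_policy().name
    return keras.dtype_policies.DTypePolicy(dtype)


print(resolve_dtype_policy("bfloat16").name)  # bfloat16
print(resolve_dtype_policy(None).name)  # falls back to the global policy
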
9 changes: 8 additions & 1 deletion keras_nlp/src/models/pali_gemma/pali_gemma_vit.py
@@ -536,7 +536,14 @@ def __init__(
classifier_activation
)
self.image_sequence_length = int((image_size / patch_size) ** 2)
self.dtype_policy = keras.dtype_policies.get(dtype)
# Before Keras 3.2, there is no `keras.dtype_policies.get`.
if hasattr(keras.dtype_policies, "get"):
self.dtype_policy = keras.dtype_policies.get(dtype)
else:
if isinstance(dtype, keras.dtype_policies.DTypePolicy):
dtype = dtype.name
dtype = dtype or keras.config.dtype_policy().name
self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype)

def get_config(self):
config = super().get_config()
3 changes: 2 additions & 1 deletion keras_nlp/src/tests/test_case.py
@@ -22,6 +22,7 @@
from absl.testing import parameterized
from keras import ops
from keras import tree
from packaging.version import parse

from keras_nlp.src import layers as keras_nlp_layers
from keras_nlp.src.tokenizers.tokenizer import Tokenizer
@@ -445,7 +446,7 @@ def run_backbone_test(
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check:
if run_quantization_check and parse(keras.version()) >= parse("3.4.0"):
self.run_quantization_test(backbone, cls, init_kwargs, input_data)

def run_task_test(
