Commit 6f276c7

Address comments

james77777778 committed Jul 3, 2024
1 parent 8f65809

Showing 7 changed files with 88 additions and 177 deletions.
30 changes: 11 additions & 19 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -107,15 +107,13 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)

if not self.tie_weights:
if self.quantization_mode != "int8":
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
initializer=self.embeddings_initializer,
dtype=self.dtype,
)
if not self.tie_weights and self.quantization_mode != "int8":
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
initializer=self.embeddings_initializer,
dtype=self.dtype,
)

def call(self, inputs, reverse=False):
if reverse:
@@ -148,11 +146,8 @@ def save_own_variables(self, store):
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
else:
raise self._quantization_mode_error(self.quantization_mode)
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable

@@ -163,11 +158,8 @@ def load_own_variables(self, store):
if not self.tie_weights:
# Last weights in the stores are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode is not None:
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
else:
raise self._quantization_mode_error(self.quantization_mode)
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
):
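For reference, a rough usage sketch (not part of this commit; it assumes keras_nlp is installed and the layer is importable from the path above) of the untied configuration that the simplified build() branch handles, where the separate reverse_embeddings weight backs the reverse projection:

import numpy as np

from keras_nlp.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)

# With tie_weights=False, build() creates a separate `reverse_embeddings`
# variable of shape (output_dim, input_dim) used only for the reverse call.
layer = ReversibleEmbedding(input_dim=100, output_dim=16, tie_weights=False)
token_ids = np.random.randint(0, 100, size=(2, 10))
hidden = layer(token_ids)             # forward lookup -> (2, 10, 16)
logits = layer(hidden, reverse=True)  # reverse projection -> (2, 10, 100)
print(hidden.shape, logits.shape)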
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bloom/bloom_backbone_test.py
@@ -39,6 +39,9 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
# TODO: Set to `True`. Error msg: Layer LayerNormalization does not
# have a `quantized_call()` method implemented.
run_quantization_check=False,
)

@pytest.mark.large
75 changes: 0 additions & 75 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -11,11 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import keras
import pytest
from absl.testing import parameterized
from keras import ops

from keras_nlp.src.models.gemma.gemma_backbone import GemmaBackbone
@@ -173,42 +170,6 @@ def test_distribution_with_lora(self):
if "attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))

@parameterized.named_parameters(("int8", "int8"), ("float8", "float8"))
def test_quantize(self, mode):
model = GemmaBackbone(**self.init_kwargs)
y_float = model(self.input_data)
model.quantize(mode)

# Verify weights dtype
selected_layer = model.transformer_layers[0].attention.query_dense
if mode == "int8":
self.assertDTypeEqual(selected_layer._kernel, "int8")
elif mode == "float8":
self.assertLen(selected_layer.trainable_weights, 7)
self.assertTrue(hasattr(selected_layer, "kernel_amax_history"))

# Try eager call and verify output correctness
y_quantized = model(self.input_data)
mse = ops.mean(ops.square(y_float - y_quantized))
if mode == "int8":
# A weak correctness test
self.assertLess(mse, 1e-2)
elif mode == "float8":
# float8 quantization requires extra calibration, so we skip the
# assertion
pass

# Try saving and reloading the model
temp_filepath = os.path.join(
self.get_temp_dir(), "quantized_model.keras"
)
model.save(temp_filepath)
reloaded_model = keras.models.load_model(temp_filepath)
self.assertAllClose(
model.predict(self.input_data),
reloaded_model.predict(self.input_data),
)


class Gemma2BackboneTest(TestCase):
def setUp(self):
@@ -249,39 +210,3 @@ def test_saved_model(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
)

@parameterized.named_parameters(("int8", "int8"), ("float8", "float8"))
def test_quantize(self, mode):
model = GemmaBackbone(**self.init_kwargs)
y_float = model(self.input_data)
model.quantize(mode)

# Verify weights dtype
selected_layer = model.transformer_layers[0].attention.query_dense
if mode == "int8":
self.assertDTypeEqual(selected_layer._kernel, "int8")
elif mode == "float8":
self.assertLen(selected_layer.trainable_weights, 7)
self.assertTrue(hasattr(selected_layer, "kernel_amax_history"))

# Try eager call and verify output correctness
y_quantized = model(self.input_data)
mse = ops.mean(ops.square(y_float - y_quantized))
if mode == "int8":
# A weak correctness test
self.assertLess(mse, 1e-2)
elif mode == "float8":
# float8 quantization requires extra calibration, so we skip the
# assertion
pass

# Try saving and reloading the model
temp_filepath = os.path.join(
self.get_temp_dir(), "quantized_model.keras"
)
model.save(temp_filepath)
reloaded_model = keras.models.load_model(temp_filepath)
self.assertAllClose(
model.predict(self.input_data),
reloaded_model.predict(self.input_data),
)
4 changes: 4 additions & 0 deletions keras_nlp/src/models/opt/opt_backbone_test.py
@@ -40,6 +40,10 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
# TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1
# variables, but received 0 variables during loading. Expected:
# ['embeddings']
run_quantization_check=False,
)

@pytest.mark.large
117 changes: 35 additions & 82 deletions keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py
@@ -13,9 +13,7 @@
# limitations under the License.
import os

import keras
import numpy as np
from absl.testing import parameterized

from keras_nlp.src.models.pali_gemma.pali_gemma_backbone import (
PaliGemmaBackbone,
@@ -35,12 +33,7 @@ def setUp(self):
self.vocabulary_size = 256
self.text_sequence_length = 64
self.image_size = 16
self.dummy_text = [
"the quick brown fox" for _ in range(self.batch_size)
]
self.dummy_images = np.random.uniform(
size=(self.batch_size, self.image_size, self.image_size, 3)
)
self.image_sequence_length = int((self.image_size / 4) ** 2)

proto = "gemma_test_vocab.spm"
tokenizer = PaliGemmaTokenizer(
@@ -65,95 +58,55 @@ def setUp(self):
"vit_hidden_dim": 8,
"vit_intermediate_dim": 16,
}
self.backbone = PaliGemmaBackbone(**self.init_kwargs)
self.dummy_imgs = np.random.rand(

dummy_images = np.random.rand(
self.batch_size, self.image_size, self.image_size, 3
)
self.dummy_text_token_ids = np.random.rand(
dummy_text_token_ids = np.random.rand(
self.batch_size, self.text_sequence_length
)
self.dummy_text = [
"answer en the quick brown fox" for i in range(self.batch_size)
]
dummy_text = ["answer en the quick brown fox"] * self.batch_size
self.input_data = {
"token_ids": dummy_text_token_ids,
"images": dummy_images,
"padding_mask": np.ones(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
"response_mask": np.zeros(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
}
self.raw_input_data = {
"images": dummy_images,
"prompts": dummy_text,
"responses": dummy_text,
}

def test_pali_gemma_backbone(self):
output = self.backbone(
{
"token_ids": self.dummy_text_token_ids,
"images": self.dummy_imgs,
"padding_mask": np.ones(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
"response_mask": np.zeros(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
}
)
self.assertEqual(
(
def test_backbone_basics(self):
self.run_backbone_test(
cls=PaliGemmaBackbone,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(
self.batch_size,
self.text_sequence_length + self.backbone.image_sequence_length,
self.text_sequence_length + self.image_sequence_length,
8,
),
output.shape,
variable_length_data=[self.input_data],
run_mixed_precision_check=False, # TODO: Set to `True`
)

def test_pali_gemma_backbone_with_preprocessing(self):
x, _, _ = self.preprocessor(
{
"images": self.dummy_images,
"prompts": self.dummy_text,
"responses": self.dummy_text,
}
)
output = self.backbone(x)
model = PaliGemmaBackbone(**self.init_kwargs)
x, _, _ = self.preprocessor(self.raw_input_data)
output = model(x)
self.assertEqual(
(
self.batch_size,
self.text_sequence_length + self.backbone.image_sequence_length,
self.text_sequence_length + self.image_sequence_length,
8,
),
output.shape,
)

@parameterized.named_parameters(("int8", "int8"), ("float8", "float8"))
def test_quantize(self, mode):
input_data = {
"token_ids": self.dummy_text_token_ids,
"images": self.dummy_imgs,
"padding_mask": np.ones(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
"response_mask": np.zeros(
(self.batch_size, self.text_sequence_length),
dtype="int32",
),
}
model = PaliGemmaBackbone(**self.init_kwargs)
model(input_data)
model.quantize(mode)

# Verify weights dtype
selected_layer = model.transformer_layers[0].attention.query_dense
if mode == "int8":
self.assertDTypeEqual(selected_layer._kernel, "int8")
elif mode == "float8":
self.assertLen(selected_layer.trainable_weights, 7)
self.assertTrue(hasattr(selected_layer, "kernel_amax_history"))

# Try eager call
model(input_data)

# Try saving and reloading the model
temp_filepath = os.path.join(
self.get_temp_dir(), "quantized_model.keras"
)
model.save(temp_filepath)
reloaded_model = keras.models.load_model(temp_filepath)
self.assertAllClose(
model.predict(input_data),
reloaded_model.predict(input_data),
)
1 change: 1 addition & 0 deletions keras_nlp/src/models/pali_gemma/pali_gemma_vit.py
@@ -536,6 +536,7 @@ def __init__(
classifier_activation
)
self.image_sequence_length = int((image_size / patch_size) ** 2)
self.dtype_policy = keras.dtype_policies.get(dtype)

def get_config(self):
config = super().get_config()
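For context, a small sketch (not part of this commit) of what the added `keras.dtype_policies.get(dtype)` call resolves, assuming a Keras 3 version that exposes `keras.dtype_policies.get`: strings, None, and existing policy objects are all normalized to a DTypePolicy instance.

import keras

print(keras.dtype_policies.get("bfloat16").name)  # "bfloat16"
print(keras.dtype_policies.get(None).name)  # global policy, typically "float32"

policy = keras.dtype_policies.DTypePolicy("mixed_float16")
print(keras.dtype_policies.get(policy).name)  # "mixed_float16"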
35 changes: 34 additions & 1 deletion keras_nlp/src/tests/test_case.py
@@ -336,6 +336,32 @@ def run_precision_test(self, cls, init_kwargs, input_data):
self.assertEqual(policy.compute_dtype, sublayer.compute_dtype)
self.assertEqual(policy.variable_dtype, sublayer.variable_dtype)

def run_quantization_test(self, cls, init_kwargs, input_data):
policy = keras.DTypePolicy("float32")
for mode in ["int8", "float8"]:
layer = cls(**{**init_kwargs, "dtype": policy})
layer.quantize(mode)
# Try eager call
if isinstance(layer, keras.Model):
_ = layer(input_data)
elif isinstance(input_data, dict):
_ = layer(**input_data)
else:
_ = layer(input_data)
# Verify sublayer's dtype policy
for sublayer in layer._flatten_layers():
if type(sublayer) is keras.layers.Dense:
self.assertEqual(
f"{mode}_from_float32", sublayer.dtype_policy.name
)
# Try saving and reloading the model
temp_filepath = os.path.join(self.get_temp_dir(), "layer.keras")
layer.save(temp_filepath)
reloaded_layer = keras.models.load_model(temp_filepath)
self.assertAllClose(
layer.predict(input_data), reloaded_layer.predict(input_data)
)
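As a standalone illustration (not part of this commit) of the behavior this helper exercises, here is a rough sketch of quantizing a plain keras.layers.Dense in place; it assumes a Keras 3 release with int8 quantization support:

import numpy as np
import keras

layer = keras.layers.Dense(4, dtype="float32")
layer.build((None, 8))
layer.quantize("int8")

# After quantization the dtype policy name becomes "<mode>_from_<dtype>",
# which is exactly what the loop over sublayers above asserts.
print(layer.dtype_policy.name)  # "int8_from_float32"

# Eager calls still accept float32 inputs; dequantization happens inside
# the quantized call path.
y = layer(np.random.rand(2, 8).astype("float32"))
print(y.shape)  # (2, 4)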

def run_model_saving_test(
self,
cls,
@@ -364,6 +390,7 @@ def run_backbone_test(
expected_output_shape,
variable_length_data=None,
run_mixed_precision_check=True,
run_quantization_check=True,
):
"""Run basic tests for a backbone, including compilation."""
backbone = cls(**init_kwargs)
@@ -405,7 +432,13 @@
name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower()
self.assertRegexpMatches(backbone.name, name)

self.run_precision_test(cls, init_kwargs, input_data)
# Check mixed precision.
if run_mixed_precision_check:
self.run_precision_test(cls, init_kwargs, input_data)

# Check quantization.
if run_quantization_check:
self.run_quantization_test(cls, init_kwargs, input_data)

def run_task_test(
self,
