Skip to content

Commit

Permalink
Utilize to_numpy=True if it is available
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Jul 31, 2024
1 parent 86fd5a0 commit ac0b3c1
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def _int8_call(self, inputs, reverse=False):

def quantize(self, mode, type_check=True):
import gc
import inspect

assert_quantization_support()
if type_check and type(self) is not ReversibleEmbedding:
Expand All @@ -262,19 +263,27 @@ def quantize(self, mode, type_check=True):
)
self._check_quantize_args(mode, self.compute_dtype)

def abs_max_quantize(inputs, axis):
sig = inspect.signature(keras.quantizers.abs_max_quantize)
if "to_numpy" in sig.parameters:
return keras.quantizers.abs_max_quantize(
inputs, axis=axis, to_numpy=True
)
else:
# `keras<=3.4.1` doesn't support `to_numpy`
return keras.quantizers.abs_max_quantize(inputs, axis=axis)

self._tracker.unlock()
if mode == "int8":
embeddings, embeddings_scale = keras.quantizers.abs_max_quantize(
embeddings, embeddings_scale = abs_max_quantize(
self._embeddings, axis=-1
)
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
self._untrack_variable(self._embeddings)
del self._embeddings
if not self.tie_weights:
reverse_embeddings, reverse_embeddings_scale = (
keras.quantizers.abs_max_quantize(
self.reverse_embeddings, axis=0
)
reverse_embeddings, reverse_embeddings_scale = abs_max_quantize(
self.reverse_embeddings, axis=0
)
reverse_embeddings_scale = ops.squeeze(
reverse_embeddings_scale, axis=0
Expand Down

0 comments on commit ac0b3c1

Please sign in to comment.