From 1f829a8b74c818e55f3e955a4d397cca0d77f1ea Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Sat, 20 Jul 2024 00:11:46 +0000 Subject: [PATCH] Slightly more defensive handling of type for backbone We have some internal customers doing some weird legacy stuff with backbones. --- keras_nlp/src/models/backbone.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index 4477df4a2c..bbcbdbb04c 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -76,14 +76,14 @@ def __init__(self, *args, dtype=None, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True - # Before Keras 3.2, there is no `keras.dtype_policies.get`. - if hasattr(keras.dtype_policies, "get"): - self.dtype_policy = keras.dtype_policies.get(dtype) - else: - if isinstance(dtype, keras.dtype_policies.DTypePolicy): - dtype = dtype.name - dtype = dtype or keras.config.dtype_policy().name - self.dtype_policy = keras.dtype_policies.DTypePolicy(dtype) + if dtype is not None: + try: + self.dtype_policy = keras.dtype_policies.get(dtype) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + if isinstance(dtype, keras.DTypePolicy): + dtype = dtype.name + self.dtype_policy = keras.DTypePolicy(dtype) def __setattr__(self, name, value): # Work around setattr issues for Keras 2 and Keras 3 torch backend. @@ -121,7 +121,7 @@ def get_config(self): } # Add quantization support by utilizing `DTypePolicyMap` - if hasattr(keras.dtype_policies, "DTypePolicyMap"): + try: if isinstance( self.dtype_policy, keras.dtype_policies.DTypePolicyMap ): @@ -133,6 +133,9 @@ def get_config(self): policy_map[layer.path] = layer.dtype_policy if len(policy_map) > 0: config.update({"dtype": policy_map}) + # Before Keras 3.2, there is no `keras.dtype_policies.get`. + except AttributeError: + pass return config @classmethod