From 562b9dd95a751005aee29fbab4a08045c95385d2 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Wed, 7 Aug 2024 09:52:24 +0800 Subject: [PATCH] Support kwargs to Backbone.from_preset and fix the dtype forwarding in Task.from_preset --- keras_nlp/src/models/backbone.py | 2 +- keras_nlp/src/models/backbone_test.py | 18 +++++++++++- keras_nlp/src/models/task.py | 8 ++--- keras_nlp/src/models/task_test.py | 10 +++++++ keras_nlp/src/utils/preset_utils.py | 37 ++++++++++++++++++++---- keras_nlp/src/utils/preset_utils_test.py | 17 +++++++++++ 6 files changed, 79 insertions(+), 13 deletions(-) diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index bbcbdbb04c..a58072dfce 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -213,7 +213,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from f"`from_preset` directly on `{preset_cls.__name__}` instead." ) - backbone = load_serialized_object(preset, CONFIG_FILE) + backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs) if load_weights: jax_memory_cleanup(backbone) backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE)) diff --git a/keras_nlp/src/models/backbone_test.py b/keras_nlp/src/models/backbone_test.py index 639d730b19..a1ecebf408 100644 --- a/keras_nlp/src/models/backbone_test.py +++ b/keras_nlp/src/models/backbone_test.py @@ -27,7 +27,7 @@ from keras_nlp.src.utils.preset_utils import load_config -class TestTask(TestCase): +class TestBackbone(TestCase): def test_preset_accessors(self): bert_presets = set(BertBackbone.presets.keys()) gpt2_presets = set(GPT2Backbone.presets.keys()) @@ -46,6 +46,22 @@ def test_from_preset(self): GPT2Backbone, ) + @pytest.mark.large + def test_from_preset_with_kwargs(self): + # Test `dtype` + backbone = Backbone.from_preset( + "bert_tiny_en_uncased", load_weights=False, dtype="bfloat16" + ) + self.assertIsInstance(backbone, BertBackbone) + self.assertEqual(backbone.dtype_policy.name, "bfloat16") + + # Test kwargs forwarding + backbone = Backbone.from_preset( + "bert_tiny_en_uncased", load_weights=False, dropout=0.5 + ) + self.assertIsInstance(backbone, BertBackbone) + self.assertAllClose(backbone.dropout, 0.5) + @pytest.mark.large def test_from_preset_errors(self): with self.assertRaises(ValueError): diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 8a59fe274f..a0123bfd66 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -258,13 +258,11 @@ def from_preset( ) cls = subclasses[0] # Forward dtype to the backbone. - config_overrides = {} + backbone_kwargs = {} if "dtype" in kwargs: - config_overrides["dtype"] = kwargs.pop("dtype") + backbone_kwargs = {"dtype": kwargs.pop("dtype")} backbone = backbone_preset_cls.from_preset( - preset, - load_weights=load_weights, - config_overrides=config_overrides, + preset, load_weights=load_weights, **backbone_kwargs ) if "preprocessor" in kwargs: preprocessor = kwargs.pop("preprocessor") diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py index e1e235fa01..11b5475b71 100644 --- a/keras_nlp/src/models/task_test.py +++ b/keras_nlp/src/models/task_test.py @@ -71,6 +71,16 @@ def test_from_preset(self): # TODO: Add a classifier task loading test when there is a classifier # with new design available on Kaggle. + @pytest.mark.large + def test_from_preset_with_kwargs(self): + # Test `dtype` + model = CausalLM.from_preset( + "gpt2_base_en", load_weights=False, dtype="bfloat16" + ) + self.assertIsInstance(model, GPT2CausalLM) + self.assertEqual(model.dtype_policy.name, "bfloat16") + self.assertEqual(model.backbone.dtype_policy.name, "bfloat16") + @pytest.mark.large def test_from_preset_errors(self): with self.assertRaises(ValueError): diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 522a4e0dc9..f797bf9f18 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -561,13 +561,16 @@ def check_format(preset): return "keras" -def load_serialized_object( - preset, - config_file=CONFIG_FILE, - config_overrides={}, -): +def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs): + kwargs = kwargs or {} config = load_config(preset, config_file) - config["config"] = {**config["config"], **config_overrides} + + # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`. + # Ensure that `dtype` is properly configured. + dtype = kwargs.pop("dtype", None) + config = set_dtype_in_config(config, dtype) + + config["config"] = {**config["config"], **kwargs} return keras.saving.deserialize_keras_object(config) @@ -590,3 +593,25 @@ def jax_memory_cleanup(layer): for weight in layer.weights: if getattr(weight, "_value", None) is not None: weight._value.delete() + + +def set_dtype_in_config(config, dtype=None): + if dtype is None: + return config + + config = config.copy() + if "dtype" not in config["config"]: + # Forward `dtype` to the config. + config["config"]["dtype"] = dtype + elif ( + "dtype" in config["config"] + and isinstance(config["config"]["dtype"], dict) + and "DTypePolicyMap" in config["config"]["dtype"]["class_name"] + ): + # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default + # policy. + policy_map_config = config["config"]["dtype"]["config"] + policy_map_config["default_policy"] = dtype + for k in policy_map_config["policy_map"].keys(): + policy_map_config["policy_map"][k]["config"]["source_name"] = dtype + return config diff --git a/keras_nlp/src/utils/preset_utils_test.py b/keras_nlp/src/utils/preset_utils_test.py index 9185f4b6da..558ddcc62e 100644 --- a/keras_nlp/src/utils/preset_utils_test.py +++ b/keras_nlp/src/utils/preset_utils_test.py @@ -23,10 +23,12 @@ from keras_nlp.src.models import BertBackbone from keras_nlp.src.models import BertTokenizer from keras_nlp.src.tests.test_case import TestCase +from keras_nlp.src.utils.keras_utils import has_quantization_support from keras_nlp.src.utils.preset_utils import CONFIG_FILE from keras_nlp.src.utils.preset_utils import METADATA_FILE from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE from keras_nlp.src.utils.preset_utils import check_format +from keras_nlp.src.utils.preset_utils import load_serialized_object class PresetUtilsTest(TestCase): @@ -113,3 +115,18 @@ def test_incorrect_metadata(self): with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"): check_format(preset_dir) + + @parameterized.named_parameters( + ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False), + ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True), + ) + @pytest.mark.extra_large + def test_load_serialized_object(self, preset, dtype, is_quantized): + if is_quantized and not has_quantization_support(): + self.skipTest("This version of Keras doesn't support quantization.") + + model = load_serialized_object(preset, dtype=dtype) + if is_quantized: + self.assertEqual(model.dtype_policy.name, "map_bfloat16") + else: + self.assertEqual(model.dtype_policy.name, "bfloat16")