diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py
index bbcbdbb04c..a58072dfce 100644
--- a/keras_nlp/src/models/backbone.py
+++ b/keras_nlp/src/models/backbone.py
@@ -213,7 +213,7 @@ class like `keras_nlp.models.Backbone.from_preset()`, or from
                 f"`from_preset` directly on `{preset_cls.__name__}` instead."
             )
 
-        backbone = load_serialized_object(preset, CONFIG_FILE)
+        backbone = load_serialized_object(preset, CONFIG_FILE, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
             backbone.load_weights(get_file(preset, MODEL_WEIGHTS_FILE))
diff --git a/keras_nlp/src/models/backbone_test.py b/keras_nlp/src/models/backbone_test.py
index 639d730b19..a1ecebf408 100644
--- a/keras_nlp/src/models/backbone_test.py
+++ b/keras_nlp/src/models/backbone_test.py
@@ -27,7 +27,7 @@
 from keras_nlp.src.utils.preset_utils import load_config
 
 
-class TestTask(TestCase):
+class TestBackbone(TestCase):
     def test_preset_accessors(self):
         bert_presets = set(BertBackbone.presets.keys())
         gpt2_presets = set(GPT2Backbone.presets.keys())
@@ -46,6 +46,22 @@ def test_from_preset(self):
             GPT2Backbone,
         )
 
+    @pytest.mark.large
+    def test_from_preset_with_kwargs(self):
+        # Test `dtype`
+        backbone = Backbone.from_preset(
+            "bert_tiny_en_uncased", load_weights=False, dtype="bfloat16"
+        )
+        self.assertIsInstance(backbone, BertBackbone)
+        self.assertEqual(backbone.dtype_policy.name, "bfloat16")
+
+        # Test kwargs forwarding
+        backbone = Backbone.from_preset(
+            "bert_tiny_en_uncased", load_weights=False, dropout=0.5
+        )
+        self.assertIsInstance(backbone, BertBackbone)
+        self.assertAllClose(backbone.dropout, 0.5)
+
     @pytest.mark.large
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py
index 8a59fe274f..a0123bfd66 100644
--- a/keras_nlp/src/models/task.py
+++ b/keras_nlp/src/models/task.py
@@ -258,13 +258,11 @@ def from_preset(
                 )
             cls = subclasses[0]
         # Forward dtype to the backbone.
-        config_overrides = {}
+        backbone_kwargs = {}
         if "dtype" in kwargs:
-            config_overrides["dtype"] = kwargs.pop("dtype")
+            backbone_kwargs = {"dtype": kwargs.pop("dtype")}
         backbone = backbone_preset_cls.from_preset(
-            preset,
-            load_weights=load_weights,
-            config_overrides=config_overrides,
+            preset, load_weights=load_weights, **backbone_kwargs
         )
         if "preprocessor" in kwargs:
             preprocessor = kwargs.pop("preprocessor")
diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py
index e1e235fa01..11b5475b71 100644
--- a/keras_nlp/src/models/task_test.py
+++ b/keras_nlp/src/models/task_test.py
@@ -71,6 +71,16 @@ def test_from_preset(self):
         # TODO: Add a classifier task loading test when there is a classifier
         # with new design available on Kaggle.
 
+    @pytest.mark.large
+    def test_from_preset_with_kwargs(self):
+        # Test `dtype`
+        model = CausalLM.from_preset(
+            "gpt2_base_en", load_weights=False, dtype="bfloat16"
+        )
+        self.assertIsInstance(model, GPT2CausalLM)
+        self.assertEqual(model.dtype_policy.name, "bfloat16")
+        self.assertEqual(model.backbone.dtype_policy.name, "bfloat16")
+
     @pytest.mark.large
     def test_from_preset_errors(self):
         with self.assertRaises(ValueError):
diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py
index 522a4e0dc9..f797bf9f18 100644
--- a/keras_nlp/src/utils/preset_utils.py
+++ b/keras_nlp/src/utils/preset_utils.py
@@ -561,13 +561,16 @@ def check_format(preset):
     return "keras"
 
 
-def load_serialized_object(
-    preset,
-    config_file=CONFIG_FILE,
-    config_overrides={},
-):
+def load_serialized_object(preset, config_file=CONFIG_FILE, **kwargs):
+    # Remaining `kwargs` (e.g. `dropout`) are merged into the config below.
     config = load_config(preset, config_file)
-    config["config"] = {**config["config"], **config_overrides}
+
+    # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
+    # Ensure that `dtype` is properly configured.
+    dtype = kwargs.pop("dtype", None)
+    config = set_dtype_in_config(config, dtype)
+
+    config["config"] = {**config["config"], **kwargs}
     return keras.saving.deserialize_keras_object(config)
 
 
@@ -590,3 +593,25 @@ def jax_memory_cleanup(layer):
         for weight in layer.weights:
             if getattr(weight, "_value", None) is not None:
                 weight._value.delete()
+
+
+def set_dtype_in_config(config, dtype=None):
+    if dtype is None:
+        return config
+
+    config = config.copy()
+    if "dtype" not in config["config"]:
+        # Forward `dtype` to the config.
+        config["config"]["dtype"] = dtype
+    elif (
+        # Override only when `dtype` is a serialized `DTypePolicyMap`.
+        isinstance(config["config"]["dtype"], dict)
+        and "DTypePolicyMap" in config["config"]["dtype"]["class_name"]
+    ):
+        # If it is `DTypePolicyMap` in `config`, forward `dtype` as its default
+        # policy.
+        policy_map_config = config["config"]["dtype"]["config"]
+        policy_map_config["default_policy"] = dtype
+        for policy in policy_map_config["policy_map"].values():
+            policy["config"]["source_name"] = dtype
+    return config
diff --git a/keras_nlp/src/utils/preset_utils_test.py b/keras_nlp/src/utils/preset_utils_test.py
index 9185f4b6da..558ddcc62e 100644
--- a/keras_nlp/src/utils/preset_utils_test.py
+++ b/keras_nlp/src/utils/preset_utils_test.py
@@ -23,10 +23,12 @@
 from keras_nlp.src.models import BertBackbone
 from keras_nlp.src.models import BertTokenizer
 from keras_nlp.src.tests.test_case import TestCase
+from keras_nlp.src.utils.keras_utils import has_quantization_support
 from keras_nlp.src.utils.preset_utils import CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import METADATA_FILE
 from keras_nlp.src.utils.preset_utils import TOKENIZER_CONFIG_FILE
 from keras_nlp.src.utils.preset_utils import check_format
+from keras_nlp.src.utils.preset_utils import load_serialized_object
 
 
 class PresetUtilsTest(TestCase):
@@ -113,3 +115,18 @@ def test_incorrect_metadata(self):
 
         with self.assertRaisesRegex(ValueError, "doesn't have `keras_version`"):
             check_format(preset_dir)
+
+    @parameterized.named_parameters(
+        ("gemma2_2b_en", "gemma2_2b_en", "bfloat16", False),
+        ("llama2_7b_en_int8", "llama2_7b_en_int8", "bfloat16", True),
+    )
+    @pytest.mark.extra_large
+    def test_load_serialized_object(self, preset, dtype, is_quantized):
+        if is_quantized and not has_quantization_support():
+            self.skipTest("This version of Keras doesn't support quantization.")
+
+        model = load_serialized_object(preset, dtype=dtype)
+        if is_quantized:
+            self.assertEqual(model.dtype_policy.name, "map_bfloat16")
+        else:
+            self.assertEqual(model.dtype_policy.name, "bfloat16")