From 931ac288928df1ecb30305f7ffb53ad858502695 Mon Sep 17 00:00:00 2001 From: Matt Watson <1389937+mattdangerw@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:29:28 -0700 Subject: [PATCH] Add an option to disable default compilation (#1787) --- keras_nlp/src/models/causal_lm.py | 2 -- keras_nlp/src/models/classifier.py | 2 -- keras_nlp/src/models/masked_lm.py | 2 -- keras_nlp/src/models/task.py | 10 +++++++++- keras_nlp/src/tests/test_case.py | 4 ++++ 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/keras_nlp/src/models/causal_lm.py b/keras_nlp/src/models/causal_lm.py index e72959aa67..7fa61d6ba7 100644 --- a/keras_nlp/src/models/causal_lm.py +++ b/keras_nlp/src/models/causal_lm.py @@ -72,8 +72,6 @@ class CausalLM(Task): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Default compilation. - self.compile() def compile( self, diff --git a/keras_nlp/src/models/classifier.py b/keras_nlp/src/models/classifier.py index ebfa811d1f..b56a156346 100644 --- a/keras_nlp/src/models/classifier.py +++ b/keras_nlp/src/models/classifier.py @@ -54,8 +54,6 @@ class Classifier(Task): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Default compilation. - self.compile() def compile( self, diff --git a/keras_nlp/src/models/masked_lm.py b/keras_nlp/src/models/masked_lm.py index 0969ccf00f..52703cdb7c 100644 --- a/keras_nlp/src/models/masked_lm.py +++ b/keras_nlp/src/models/masked_lm.py @@ -45,8 +45,6 @@ class MaskedLM(Task): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Default compilation. - self.compile() def compile( self, diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 6a46f09ee0..abee4ecf29 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -56,12 +56,17 @@ class Task(PipelineModel): to load a pre-trained config and weights. Calling `from_preset()` on a task will automatically instantiate a `keras_nlp.models.Backbone` and `keras_nlp.models.Preprocessor`. + + Args: + compile: boolean, defaults to `True`. If `True` will compile the model + with default parameters on construction. Model can still be + recompiled with a new loss, optimizer and metrics before training. """ backbone_cls = None preprocessor_cls = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, compile=True, **kwargs): super().__init__(*args, **kwargs) self._functional_layer_ids = set( id(layer) for layer in self._flatten_layers() @@ -69,6 +74,9 @@ def __init__(self, *args, **kwargs): self._initialized = True if self.backbone is not None: self.dtype_policy = self._backbone.dtype_policy + if compile: + # Default compilation. + self.compile() def preprocess_samples(self, x, y=None, sample_weight=None): if self.preprocessor is not None: diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 634335eb78..8902d336b7 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -495,6 +495,10 @@ def run_task_test( task.preprocessor = None task.fit(ds.map(preprocessor)) task.preprocessor = preprocessor + # Turn off default compilation, should error during `fit()`. + task = cls(**init_kwargs, compile=False) + with self.assertRaisesRegex(ValueError, "You must call `compile"): + task.fit(ds) def run_preset_test( self,