From ad106168b59698cb41a40e720160093cb686b30f Mon Sep 17 00:00:00 2001 From: smokestacklightnin <125844868+smokestacklightnin@users.noreply.github.com> Date: Sat, 26 Oct 2024 17:59:03 -0700 Subject: [PATCH] Fix some Ruff rule B008 violations. For remaining B008 violations, see [Issue 6945](https://github.com/tensorflow/tfx/issues/6945) --- tfx/examples/bert/utils/bert_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tfx/examples/bert/utils/bert_models.py b/tfx/examples/bert/utils/bert_models.py index d67fa1c6b0..cf59e9ac41 100644 --- a/tfx/examples/bert/utils/bert_models.py +++ b/tfx/examples/bert/utils/bert_models.py @@ -13,6 +13,7 @@ # limitations under the License. """Configurable fine-tuning BERT models for various tasks.""" +from __future__ import annotations from typing import Optional, List, Union import tensorflow as tf @@ -59,8 +60,7 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer, def compile_bert_classifier( model: tf.keras.Model, - loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True), + loss: tf.keras.losses.Loss | None = None, learning_rate: float = 2e-5, metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None): """Compile the BERT classifier using suggested parameters. @@ -79,6 +79,9 @@ def compile_bert_classifier( Returns: None. """ + if loss is None: + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if metrics is None: metrics = ["sparse_categorical_accuracy"]