diff --git a/tfx/dsl/component/experimental/decorators_test.py b/tfx/dsl/component/experimental/decorators_test.py index 21f3113a32..5757a7bb36 100644 --- a/tfx/dsl/component/experimental/decorators_test.py +++ b/tfx/dsl/component/experimental/decorators_test.py @@ -42,6 +42,7 @@ from tfx.types.system_executions import SystemExecution _TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo'] +_TestEmptyBeamPipeline = beam.Pipeline() class _InputArtifact(types.Artifact): @@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable= def verify_beam_pipeline_arg_non_none_default_value( a: int, - beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(), + beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline, ) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types del beam_pipeline return {'b': float(a)} diff --git a/tfx/dsl/component/experimental/decorators_typeddict_test.py b/tfx/dsl/component/experimental/decorators_typeddict_test.py index 0e4ef8f41f..b631b812c5 100644 --- a/tfx/dsl/component/experimental/decorators_typeddict_test.py +++ b/tfx/dsl/component/experimental/decorators_typeddict_test.py @@ -40,6 +40,7 @@ from tfx.types.system_executions import SystemExecution _TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo'] +_TestEmptyBeamPipeline = beam.Pipeline() class _InputArtifact(types.Artifact): @@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): # def verify_beam_pipeline_arg_non_none_default_value( a: int, - beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(), + beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline, ) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types del beam_pipeline return {'b': float(a)} diff --git a/tfx/examples/bert/utils/bert_models.py b/tfx/examples/bert/utils/bert_models.py index d67fa1c6b0..a75f129f21 100644 --- a/tfx/examples/bert/utils/bert_models.py +++ b/tfx/examples/bert/utils/bert_models.py @@ -59,16 +59,15 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer, def compile_bert_classifier( model: tf.keras.Model, - loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy( - from_logits=True), + loss: tf.keras.losses.Loss | None = None, learning_rate: float = 2e-5, metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None): """Compile the BERT classifier using suggested parameters. Args: model: A keras model. Most likely the output of build_bert_classifier. - loss: tf.keras.losses. The suggested loss function expects integer labels - (e.g. 0, 1, 2). If the labels are one-hot encoded, consider using + loss: Default None will use tf.keras.losses. The suggested loss function expects + integer labels (e.g. 0, 1, 2). If the labels are one-hot encoded, consider using tf.keras.lossesCategoricalCrossEntropy with from_logits set to true. learning_rate: Suggested learning rate to be used in tf.keras.optimizer.Adam. The three suggested learning_rates for @@ -79,6 +78,8 @@ def compile_bert_classifier( Returns: None. """ + if loss is None: + loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) if metrics is None: metrics = ["sparse_categorical_accuracy"]