Skip to content

Commit

Permalink
Merge pull request #6952 from janasangeetha/sangeethajana-patch2
Browse files Browse the repository at this point in the history
Fix Ruff B008 errors
  • Loading branch information
lego0901 authored Nov 15, 2024
2 parents 8393d14 + 2f6a67d commit 7615e5a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
3 changes: 2 additions & 1 deletion tfx/dsl/component/experimental/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from tfx.types.system_executions import SystemExecution

_TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo']
_TestEmptyBeamPipeline = beam.Pipeline()


class _InputArtifact(types.Artifact):
Expand Down Expand Up @@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> OutputDict(b=float): # pytype: disable=

def verify_beam_pipeline_arg_non_none_default_value(
a: int,
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline,
) -> OutputDict(b=float): # pytype: disable=invalid-annotation,wrong-arg-types
del beam_pipeline
return {'b': float(a)}
Expand Down
3 changes: 2 additions & 1 deletion tfx/dsl/component/experimental/decorators_typeddict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tfx.types.system_executions import SystemExecution

_TestBeamPipelineArgs = ['--my_testing_beam_pipeline_args=foo']
_TestEmptyBeamPipeline = beam.Pipeline()


class _InputArtifact(types.Artifact):
Expand Down Expand Up @@ -140,7 +141,7 @@ def verify_beam_pipeline_arg(a: int) -> TypedDict('Output6', dict(b=float)): #

def verify_beam_pipeline_arg_non_none_default_value(
a: int,
beam_pipeline: BeamComponentParameter[beam.Pipeline] = beam.Pipeline(),
beam_pipeline: BeamComponentParameter[beam.Pipeline] = _TestEmptyBeamPipeline,
) -> TypedDict('Output7', dict(b=float)): # pytype: disable=wrong-arg-types
del beam_pipeline
return {'b': float(a)}
Expand Down
9 changes: 5 additions & 4 deletions tfx/examples/bert/utils/bert_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,15 @@ def build_bert_classifier(bert_layer: tf.keras.layers.Layer,

def compile_bert_classifier(
model: tf.keras.Model,
loss: tf.keras.losses.Loss = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True),
loss: tf.keras.losses.Loss | None = None,
learning_rate: float = 2e-5,
metrics: Optional[List[Union[str, tf.keras.metrics.Metric]]] = None):
"""Compile the BERT classifier using suggested parameters.
Args:
model: A keras model. Most likely the output of build_bert_classifier.
loss: tf.keras.losses. The suggested loss function expects integer labels
(e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
loss: Default None will use tf.keras.losses. The suggested loss function expects
integer labels (e.g. 0, 1, 2). If the labels are one-hot encoded, consider using
tf.keras.lossesCategoricalCrossEntropy with from_logits set to true.
learning_rate: Suggested learning rate to be used in
tf.keras.optimizer.Adam. The three suggested learning_rates for
Expand All @@ -79,6 +78,8 @@ def compile_bert_classifier(
Returns:
None.
"""
if loss is None:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
if metrics is None:
metrics = ["sparse_categorical_accuracy"]

Expand Down

0 comments on commit 7615e5a

Please sign in to comment.