fix binary classification for tensorflow segformer

fix binary classification for tf segformer huggingface#2

fix huggingface#5

Revert "fix huggingface#5"

This reverts commit 15b516055c25faa3297196095de19b41ff0149fe.

Revert "fix huggingface#4"

This reverts commit 0b534e62d03db5ef74f77b61837e0561a1fc129a.

fix huggingface#5

fix

fix

fix
nikolaJovisic committed Aug 23, 2023
1 parent 3629190 commit e6f02d4
Showing 1 changed file with 15 additions and 15 deletions.
src/transformers/models/segformer/modeling_tf_segformer.py

@@ -757,20 +757,23 @@ def hf_compute_loss(self, logits, labels):
         # `labels` is of shape (batch_size, height, width)
         label_interp_shape = shape_list(labels)[1:]
 
-        upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+        logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
+
+
+
         # compute weighted loss
-        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
+        loss_fct = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction="none")
 
-        def masked_loss(real, pred):
-            unmasked_loss = loss_fct(real, pred)
-            mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
-            masked_loss = unmasked_loss * mask
-            # Reduction strategy in the similar spirit with
-            # https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
-            reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
-            return tf.reshape(reduced_masked_loss, (1,))
+        logits = tf.reshape(logits, [-1])
+        labels = tf.reshape(labels, [-1])
 
-        return masked_loss(labels, upsampled_logits)
+        logits = tf.sigmoid(logits)
+
+        loss = loss_fct(labels, logits)
+        loss = tf.reduce_sum(loss)
+        loss = tf.reshape(loss, (1,))
+
+        return loss
 
     @unpack_inputs
     @add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@@ -829,10 +832,7 @@ def call(
 
         loss = None
         if labels is not None:
-            if not self.config.num_labels > 1:
-                raise ValueError("The number of labels should be greater than one")
-            else:
-                loss = self.hf_compute_loss(logits=logits, labels=labels)
+            loss = self.hf_compute_loss(logits=logits, labels=labels)
 
         # make logits of shape (batch_size, num_labels, height, width) to
         # keep them consistent across APIs
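
For context, the new loss path can be exercised outside the model. The sketch below is a minimal standalone reproduction of the first hunk, under stated assumptions: shape_list and the model's config are replaced by plain tf.shape, and the function name binary_segmentation_loss plus the toy shapes are illustrative, not part of the repo. Note that the diff applies tf.sigmoid while still constructing BinaryCrossentropy with from_logits=True, so the loss applies a second sigmoid internally; the sketch reproduces that behavior exactly as committed.

    import tensorflow as tf

    def binary_segmentation_loss(logits, labels):
        # hypothetical standalone port of hf_compute_loss above;
        # `labels` is of shape (batch_size, height, width),
        # `logits` of shape (batch_size, h, w, num_labels)
        label_interp_shape = tf.shape(labels)[1:]  # stands in for shape_list(labels)[1:]

        # upsample logits to the labels' spatial resolution
        logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")

        loss_fct = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction="none")

        # flatten both tensors so every pixel becomes one binary sample
        logits = tf.reshape(logits, [-1])
        labels = tf.reshape(labels, [-1])

        # kept exactly as in the diff: sigmoid here plus from_logits=True
        # above means the loss sigmoids the predictions a second time
        logits = tf.sigmoid(logits)

        loss = loss_fct(labels, logits)
        loss = tf.reduce_sum(loss)
        return tf.reshape(loss, (1,))

    # toy usage: single-channel logits at 1/4 resolution, binary pixel labels
    dummy_logits = tf.random.normal((2, 16, 16, 1))
    dummy_labels = tf.cast(tf.random.uniform((2, 64, 64)) > 0.5, tf.float32)
    print(binary_segmentation_loss(dummy_logits, dummy_labels))  # tensor of shape (1,)

Flattening both tensors to one dimension treats every pixel as an independent binary sample, and the final reduce_sum/reshape collapses the result to the single-element shape the surrounding API expects; note that, unlike the replaced masked_loss helper, this version no longer excludes pixels equal to semantic_loss_ignore_index.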
