Skip to content

Commit

Permalink
Merge pull request #17651 from Frightera:frightera_categorical_focal_…
Browse files Browse the repository at this point in the history
…loss_v2

PiperOrigin-RevId: 520482950
  • Loading branch information
tensorflower-gardener committed Mar 30, 2023
2 parents 42047cc + f560336 commit 0f8e81f
Show file tree
Hide file tree
Showing 12 changed files with 699 additions and 0 deletions.
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.backend.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'target\', \'output\', \'alpha\', \'gamma\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'-1\'], "
}
member_method {
name: "clear_session"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
path: "tensorflow.keras.losses.CategoricalFocalCrossentropy"
tf_class {
is_instance: "<class \'keras.losses.CategoricalFocalCrossentropy\'>"
is_instance: "<class \'keras.losses.LossFunctionWrapper\'>"
is_instance: "<class \'keras.losses.Loss\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\', \'auto\', \'categorical_focal_crossentropy\'], "
}
member_method {
name: "call"
argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
8 changes: 8 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.losses.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ tf_module {
name: "CategoricalCrossentropy"
mtype: "<type \'type\'>"
}
member {
name: "CategoricalFocalCrossentropy"
mtype: "<type \'type\'>"
}
member {
name: "CategoricalHinge"
mtype: "<type \'type\'>"
Expand Down Expand Up @@ -100,6 +104,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_hinge"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v1/tensorflow.keras.metrics.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "cosine"
argspec: "args=[\'y_true\', \'y_pred\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.backend.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'target\', \'output\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'target\', \'output\', \'alpha\', \'gamma\', \'from_logits\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'-1\'], "
}
member_method {
name: "clear_session"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
path: "tensorflow.keras.losses.CategoricalFocalCrossentropy"
tf_class {
is_instance: "<class \'keras.losses.CategoricalFocalCrossentropy\'>"
is_instance: "<class \'keras.losses.LossFunctionWrapper\'>"
is_instance: "<class \'keras.losses.Loss\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\', \'reduction\', \'name\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\', \'auto\', \'categorical_focal_crossentropy\'], "
}
member_method {
name: "call"
argspec: "args=[\'self\', \'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_config"
argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_config"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}
8 changes: 8 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.losses.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ tf_module {
name: "CategoricalCrossentropy"
mtype: "<type \'type\'>"
}
member {
name: "CategoricalFocalCrossentropy"
mtype: "<type \'type\'>"
}
member {
name: "CategoricalHinge"
mtype: "<type \'type\'>"
Expand Down Expand Up @@ -104,6 +108,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_hinge"
argspec: "args=[\'y_true\', \'y_pred\'], varargs=None, keywords=None, defaults=None"
Expand Down
4 changes: 4 additions & 0 deletions keras/api/golden/v2/tensorflow.keras.metrics.pbtxt
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,10 @@ tf_module {
name: "categorical_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "categorical_focal_crossentropy"
argspec: "args=[\'y_true\', \'y_pred\', \'alpha\', \'gamma\', \'from_logits\', \'label_smoothing\', \'axis\'], varargs=None, keywords=None, defaults=[\'0.25\', \'2.0\', \'False\', \'0.0\', \'-1\'], "
}
member_method {
name: "deserialize"
argspec: "args=[\'config\', \'custom_objects\', \'use_legacy_format\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
Expand Down
97 changes: 97 additions & 0 deletions keras/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5574,6 +5574,103 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1):
return -tf.reduce_sum(target * tf.math.log(output), axis)


@keras_export("keras.backend.categorical_focal_crossentropy")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
def categorical_focal_crossentropy(
    target,
    output,
    alpha=0.25,
    gamma=2.0,
    from_logits=False,
    axis=-1,
):
    """Computes the alpha-balanced categorical focal crossentropy.

    Focal loss
    ([Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf)) rescales the
    crossentropy so that easy, well-classified examples contribute less and
    hard examples contribute more.

    For every class this function evaluates

    `alpha * (1 - output)^gamma * (-target * log(output))`

    and then sums the per-class terms over `axis`.

    `(1 - output)^gamma` is the modulating factor and `gamma` is the
    focusing parameter: the larger `gamma`, the more strongly confident
    predictions are down-weighted. `alpha` is the class weighting factor.
    With `gamma=0` and `alpha=1` both factors equal 1 and the expression
    reduces to the sum of `-target * log(output)`.

    Before the loss is evaluated, `output` is turned into probabilities
    with a softmax when it holds logits, rescaled so that it sums to 1
    along `axis`, and clipped to `[epsilon, 1 - epsilon]`.

    Args:
        target: Ground truth values from the dataset. Its shape must be
            compatible with the shape of `output`.
        output: Predictions of the model.
        alpha: A weight balancing factor for all classes, default is `0.25`
            as mentioned in the reference. It can be a list of floats or a
            scalar. If it is a list, it must have the same length as the
            number of classes. In the multi-class case, alpha may be set by
            inverse class frequency by using `compute_class_weight` from
            `sklearn.utils`.
        gamma: A focusing parameter, default is `2.0` as mentioned in the
            reference. It helps to gradually reduce the importance given to
            simple examples in a smooth manner.
        from_logits: Whether `output` is expected to be a logits tensor. By
            default, we consider that `output` encodes a probability
            distribution.
        axis: Int specifying the channels axis. `axis=-1` corresponds to
            data format `channels_last`, and `axis=1` corresponds to data
            format `channels_first`.

    Returns:
        A tensor with the focal crossentropy, summed over `axis`.
    """
    labels = tf.convert_to_tensor(target)
    predictions = tf.convert_to_tensor(output)
    labels.shape.assert_is_compatible_with(predictions.shape)

    # `_get_logits` hands back the tensor to work on together with an
    # updated logits flag; its exact rules are defined outside this block.
    predictions, is_logits = _get_logits(
        predictions,
        from_logits,
        "Softmax",
        "categorical_focal_crossentropy",
    )
    probs = tf.nn.softmax(predictions, axis=axis) if is_logits else predictions

    # Rescale so the class probabilities of every sample sum to 1; the
    # crossentropy below relies on a proper distribution.
    probs = probs / tf.reduce_sum(probs, axis=axis, keepdims=True)

    # Keep probabilities strictly inside (0, 1) so the log stays finite.
    eps = _constant_to_tensor(epsilon(), probs.dtype.base_dtype)
    probs = tf.clip_by_value(probs, eps, 1.0 - eps)

    # Per-class crossentropy terms, not yet reduced.
    per_class_ce = -labels * tf.math.log(probs)

    # Focal modulation scaled by the class weighting `alpha`.
    focal_weight = tf.multiply(tf.pow(1.0 - probs, gamma), alpha)

    return tf.reduce_sum(tf.multiply(focal_weight, per_class_ce), axis=axis)


@keras_export("keras.backend.sparse_categorical_crossentropy")
@tf.__internal__.dispatch.add_dispatch_support
@doc_controls.do_not_generate_docs
Expand Down
47 changes: 47 additions & 0 deletions keras/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2244,6 +2244,19 @@ def test_binary_focal_crossentropy_with_sigmoid(self):
)
self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

@test_combinations.generate(
    test_combinations.combine(mode=["graph", "eager"])
)
def test_categorical_focal_crossentropy_with_softmax(self):
    """Checks focal crossentropy on softmax probabilities, graph and eager."""
    labels = backend.constant([[0, 1, 0]])
    probs = backend.softmax(backend.constant([[8.0, 1.0, 1.0]]))
    # Two chained identity ops, as in the sibling tests. NOTE(review):
    # presumably this keeps the backend from tracing the tensor back to
    # the softmax op and reusing its logits -- confirm against
    # `_get_logits`.
    probs = tf.identity(probs)
    probs = tf.identity(probs)
    loss = backend.categorical_focal_crossentropy(labels, probs, gamma=2.0)
    result = self.evaluate(loss)
    self.assertArrayNear(result, [1.747], 1e-3)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
Expand All @@ -2260,6 +2273,21 @@ def test_binary_focal_crossentropy_from_logits(self):
)
self.assertArrayNear(result[0], [7.995, 0.022, 0.701], 1e-3)

@test_combinations.generate(
    test_combinations.combine(mode=["graph", "eager"])
)
def test_categorical_focal_crossentropy_from_logits(self):
    """Checks focal crossentropy when raw logits are passed in directly."""
    labels = backend.constant([[0, 1, 0]])
    raw_scores = backend.constant([[8.0, 1.0, 1.0]])
    # Default alpha/gamma are used; only `from_logits` is switched on.
    loss = backend.categorical_focal_crossentropy(
        target=labels, output=raw_scores, from_logits=True
    )
    result = self.evaluate(loss)
    self.assertArrayNear(result, [1.7472], 1e-3)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
Expand All @@ -2279,6 +2307,25 @@ def test_binary_focal_crossentropy_no_focal_effect_with_zero_gamma(self):
non_focal_result = self.evaluate(backend.binary_crossentropy(t, p))
self.assertArrayNear(focal_result[0], non_focal_result[0], 1e-3)

@test_combinations.generate(
    test_combinations.combine(mode=["graph", "eager"])
)
def test_categorical_focal_crossentropy_no_focal_effect(self):
    """Checks that gamma=0 and alpha=1 give plain categorical crossentropy."""
    labels = backend.constant([[0, 1, 0]])
    probs = backend.softmax(backend.constant([[8.0, 1.0, 1.0]]))
    # Two chained identity ops, as in the sibling tests. NOTE(review):
    # presumably this hides the softmax op from the backend -- confirm
    # against `_get_logits`.
    probs = tf.identity(probs)
    probs = tf.identity(probs)

    # gamma=0 makes the modulating factor 1 and alpha=1 removes the class
    # weighting, so the focal loss should match the unweighted loss.
    focal_loss = backend.categorical_focal_crossentropy(
        target=labels, output=probs, gamma=0.0, alpha=1.0
    )
    focal_result = self.evaluate(focal_loss)

    plain_loss = backend.categorical_crossentropy(labels, probs)
    non_focal_result = self.evaluate(plain_loss)

    self.assertArrayNear(focal_result, non_focal_result, 1e-3)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
Expand Down
Loading

0 comments on commit 0f8e81f

Please sign in to comment.