
Commit

Address comments from code-review.
Frightera committed Mar 26, 2023
1 parent 49c03a2 commit f560336
Showing 2 changed files with 64 additions and 37 deletions.
55 changes: 35 additions & 20 deletions keras/backend.py
@@ -5589,26 +5589,38 @@ def categorical_focal_crossentropy(
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
helps to apply a focal factor to down-weight easy examples and focus more on
-hard examples. By default, the focal tensor is computed as follows:
+hard examples. The general formula for the focal loss (FL)
+is as follows:
-It has pt defined as:
-pt = p, if y = 1 else 1 - p
+`FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
-The authors use alpha-balanced variant of focal loss in the paper:
-FL(pt) = −α_t * (1 − pt)^gamma * log(pt)
+where `p_t` is defined as follows:
+`p_t = output if y_true == 1, else 1 - output`
-Extending this to multi-class case is straightforward:
-FL(pt) = α_t * (1 − pt)^gamma * CE, where minus comes from
-negative log-likelihood and included in CE.
+`(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+`gamma` reduces the importance given to simple examples in a smooth manner.
+The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+`FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
+where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+loss won't be able to handle class imbalance properly, as all
+classes will have the same weight. It can be a constant or a list of
+constants. If `alpha` is a list, it must have the same length as the number
+of classes.
-`modulating_factor` is (1 − pt)^gamma, where `gamma` is a focusing
-parameter. When `gamma` = 0, there is no focal effect on the categorical
-crossentropy. And if alpha = 1, at the same time the loss is equivalent
-to the categorical crossentropy.
+The formula above can be generalized to:
+`FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)`
+where the minus sign comes from `CrossEntropy(target, output)` (CE).
+Extending this to the multi-class case is straightforward:
+`FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)`
Args:
-target: A tensor with the same shape as `output`.
-output: A tensor.
+target: Ground truth values from the dataset.
+output: Predictions of the model.
alpha: A weight balancing factor for all classes, default is `0.25` as
mentioned in the reference. It can be a list of floats or a scalar.
In the multi-class case, alpha may be set by inverse class
@@ -5619,6 +5631,9 @@ def categorical_focal_crossentropy(
from_logits: Whether `output` is expected to be a logits tensor. By
default, we consider that `output` encodes a probability
distribution.
+axis: Int specifying the channels axis. `axis=-1` corresponds to data
+format `channels_last`, and `axis=1` corresponds to data format
+`channels_first`.
Returns:
A tensor.
@@ -5631,13 +5646,13 @@ def categorical_focal_crossentropy(
output, from_logits, "Softmax", "categorical_focal_crossentropy"
)

-output = tf.__internal__.smart_cond.smart_cond(
-    from_logits,
-    lambda: softmax(output),
-    lambda: output,
-)
+if from_logits:
+    output = tf.nn.softmax(output, axis=axis)

-# scale preds so that the class probas of each sample sum to 1
+# Adjust the predictions so that the probability of
+# each class for every sample adds up to 1.
+# This is needed to ensure that the cross entropy is
+# computed correctly.
output = output / tf.reduce_sum(output, axis=axis, keepdims=True)

epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
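For reference, the formula documented above can be sketched directly in TensorFlow. The snippet below is a minimal illustration under stated assumptions (the helper name `categorical_focal_loss_sketch` and the example tensors are made up for demonstration); it mirrors the documented math, not the exact backend implementation.

```python
import tensorflow as tf

def categorical_focal_loss_sketch(target, output, alpha=0.25, gamma=2.0, axis=-1):
    """Minimal sketch of alpha-balanced categorical focal crossentropy on probabilities."""
    # Re-normalize so class probabilities sum to 1 along the class axis.
    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)
    # Clip to avoid log(0).
    eps = tf.keras.backend.epsilon()
    output = tf.clip_by_value(output, eps, 1.0 - eps)
    # Categorical crossentropy term: -y_true * log(y_pred) (the minus lives here).
    ce = -target * tf.math.log(output)
    # Modulating factor (1 - p_t)^gamma.
    modulating_factor = tf.pow(1.0 - output, gamma)
    # Alpha-balanced focal loss, summed over the class axis.
    return tf.reduce_sum(alpha * modulating_factor * ce, axis=axis)

# Example: two one-hot labeled samples with three classes.
y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.1, 0.8, 0.1], [0.2, 0.3, 0.5]])
print(categorical_focal_loss_sketch(y_true, y_pred))
```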
46 changes: 29 additions & 17 deletions keras/losses.py
@@ -926,29 +926,41 @@ def __init__(
class CategoricalFocalCrossentropy(LossFunctionWrapper):
"""Computes the alpha balanced focal crossentropy loss.
+Use this crossentropy loss function when there are two or more label
+classes and you want to handle class imbalance without using
+`class_weights`. We expect labels to be provided in a `one_hot`
+representation.
According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
helps to apply a focal factor to down-weight easy examples and focus more on
-hard examples. By default, the focal tensor is computed as follows:
+hard examples. The general formula for the focal loss (FL)
+is as follows:
-It has pt defined as:
-pt = p, if y = 1 else 1 - p
+`FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
-The authors use alpha-balanced variant of focal loss in the paper:
-FL(pt) = −α_t * (1 − pt)^gamma * log(pt)
+where `p_t` is defined as follows:
+`p_t = output if y_true == 1, else 1 - output`
-Extending this to multi-class case is straightforward:
-FL(pt) = α_t * (1 − pt)^gamma * CE, where minus comes from
-negative log-likelihood and included in CE.
+`(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+`gamma` reduces the importance given to simple examples in a smooth manner.
-`modulating_factor` is (1 − pt)^gamma, where `gamma` is a focusing
-parameter. When `gamma` = 0, there is no focal effect on the categorical
-crossentropy. And if alpha = 1, at the same time the loss is equivalent to
-the categorical crossentropy.
+The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+`FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
-Use this crossentropy loss function when there are two or more label
-classes and if you want to handle class imbalance without using
-`class_weights`.
-We expect labels to be provided in a `one_hot` representation.
+where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+loss won't be able to handle class imbalance properly, as all
+classes will have the same weight. It can be a constant or a list of
+constants. If `alpha` is a list, it must have the same length as the number
+of classes.
+The formula above can be generalized to:
+`FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(y_true, y_pred)`
+where the minus sign comes from `CrossEntropy(y_true, y_pred)` (CE).
+Extending this to the multi-class case is straightforward:
+`FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(y_true, y_pred)`
In the snippet below, there are `# classes` floating point values per
example. The shapes of both `y_pred` and `y_true` are
@@ -981,7 +993,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
Usage with the `compile()` API:
```python
-model.compile(optimizer='sgd',
+model.compile(optimizer='adam',
loss=tf.keras.losses.CategoricalFocalCrossentropy())
```
Args:
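As a hedged usage sketch of the loss class documented above (assuming a TF/Keras build that already ships `CategoricalFocalCrossentropy`; the example tensors are made up for demonstration):

```python
import tensorflow as tf

# Alpha-balanced focal crossentropy with the defaults mentioned in the docstring.
loss_fn = tf.keras.losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0)

# One-hot targets and predicted class probabilities.
y_true = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = tf.constant([[0.1, 0.8, 0.1], [0.2, 0.3, 0.5]])

# Returns a scalar: the mean focal crossentropy over the batch.
print(loss_fn(y_true, y_pred).numpy())

# It can also be passed to compile(), as shown in the docstring:
# model.compile(optimizer='adam',
#               loss=tf.keras.losses.CategoricalFocalCrossentropy())
```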
