From f56033645287df77f6fd0c0741bb16b74090f0eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kaan=20B=C4=B1=C3=A7akc=C4=B1?=
Date: Sun, 26 Mar 2023 21:25:03 +0100
Subject: [PATCH] Address comments from code-review.

---
 keras/backend.py | 55 ++++++++++++++++++++++++++++++------------------
 keras/losses.py  | 46 +++++++++++++++++++++++++---------------
 2 files changed, 64 insertions(+), 37 deletions(-)

diff --git a/keras/backend.py b/keras/backend.py
index c3fcdc8ece3..63e7bcd20bf 100644
--- a/keras/backend.py
+++ b/keras/backend.py
@@ -5589,26 +5589,38 @@ def categorical_focal_crossentropy(
 
     According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
     helps to apply a focal factor to down-weight easy examples and focus more on
-    hard examples. By default, the focal tensor is computed as follows:
+    hard examples. The general formula for the focal loss (FL) is as follows:
 
-    It has pt defined as:
-    pt = p, if y = 1 else 1 - p
+    `FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
+
+    where `p_t` is defined as follows:
+    `p_t = output if target == 1, else 1 - output`
 
-    The authors use alpha-balanced variant of focal loss in the paper:
-    FL(pt) = −α_t * (1 − pt)^gamma * log(pt)
+    `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+    `gamma` reduces the importance given to simple examples in a smooth manner.
+
+    The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+    `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
+
+    where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+    loss won't be able to handle class imbalance properly as all
+    classes will have the same weight. This can be a constant or a list of
+    constants. If `alpha` is a list, it must have the same length as the number
+    of classes.
 
-    Extending this to multi-class case is straightforward:
-    FL(pt) = α_t * (1 − pt)^gamma * CE, where minus comes from
-    negative log-likelihood and included in CE.
-
-    `modulating_factor` is (1 − pt)^gamma, where `gamma` is a focusing
-    parameter. When `gamma` = 0, there is no focal effect on the categorical
-    crossentropy. And if alpha = 1, at the same time the loss is equivalent
-    to the categorical crossentropy.
+    The formula above can be generalized to:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(target, output)`
+
+    where the minus sign is absorbed into `CrossEntropy(target, output)` (CE).
+
+    Extending this to the multi-class case is straightforward:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(target, output)`
 
     Args:
-        target: A tensor with the same shape as `output`.
-        output: A tensor.
+        target: Ground truth values from the dataset.
+        output: Predictions of the model.
         alpha: A weight balancing factor for all classes, default is `0.25` as
             mentioned in the reference. It can be a list of floats or a
             scalar. In the multi-class case, alpha may be set by inverse class
@@ -5619,6 +5631,9 @@ def categorical_focal_crossentropy(
         from_logits: Whether `output` is expected to be a logits tensor. By
            default, we consider that `output` encodes a probability
            distribution.
+        axis: Int specifying the channels axis. `axis=-1` corresponds to data
+            format `channels_last`, and `axis=1` corresponds to data format
+            `channels_first`.
 
     Returns:
         A tensor.
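To make the reworked docstring formula concrete, here is a minimal standalone sketch (not the patched Keras function itself) of the alpha-balanced focal factor applied on top of a plain categorical crossentropy, assuming `output` already holds probabilities and `target` is one-hot encoded; the helper name and example tensors are illustrative only.

```python
import tensorflow as tf


def focal_categorical_crossentropy_sketch(
    target, output, alpha=0.25, gamma=2.0, axis=-1
):
    # Rescale predictions so the class probabilities of each sample sum to 1,
    # then clip to avoid log(0).
    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)
    output = tf.clip_by_value(output, 1e-7, 1.0 - 1e-7)
    # Plain categorical crossentropy contribution per class: -target * log(output).
    cce = -target * tf.math.log(output)
    # Focal modulating factor (1 - p_t)^gamma and class weighting alpha.
    modulating_factor = tf.pow(1.0 - output, gamma)
    # FL = alpha * (1 - p_t)^gamma * CE, reduced over the class axis.
    return tf.reduce_sum(alpha * modulating_factor * cce, axis=axis)


target = tf.constant([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
output = tf.constant([[0.05, 0.9, 0.05], [0.1, 0.8, 0.1]])
print(focal_categorical_crossentropy_sketch(target, output).numpy())
```

Because `target` is one-hot, only the true class contributes, where `output` equals `p_t`, so the expression reduces to the `−alpha * (1 − p_t)^gamma * log(p_t)` form described above.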
@@ -5631,13 +5646,13 @@ def categorical_focal_crossentropy(
         output, from_logits, "Softmax", "categorical_focal_crossentropy"
     )
 
-    output = tf.__internal__.smart_cond.smart_cond(
-        from_logits,
-        lambda: softmax(output),
-        lambda: output,
-    )
+    if from_logits:
+        output = tf.nn.softmax(output, axis=axis)
 
-    # scale preds so that the class probas of each sample sum to 1
+    # Adjust the predictions so that the probability of
+    # each class for every sample adds up to 1.
+    # This is needed to ensure that the cross entropy is
+    # computed correctly.
     output = output / tf.reduce_sum(output, axis=axis, keepdims=True)
 
     epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
diff --git a/keras/losses.py b/keras/losses.py
index a8c32d460b4..adf918a5102 100644
--- a/keras/losses.py
+++ b/keras/losses.py
@@ -926,29 +926,41 @@ def __init__(
 class CategoricalFocalCrossentropy(LossFunctionWrapper):
     """Computes the alpha balanced focal crossentropy loss.
 
+    Use this crossentropy loss function when there are two or more label
+    classes and you want to handle class imbalance without using
+    `class_weights`. We expect labels to be provided in a `one_hot`
+    representation.
+
     According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it
     helps to apply a focal factor to down-weight easy examples and focus more on
-    hard examples. By default, the focal tensor is computed as follows:
+    hard examples. The general formula for the focal loss (FL) is as follows:
 
-    It has pt defined as:
-    pt = p, if y = 1 else 1 - p
+    `FL(p_t) = −(1 − p_t)^gamma * log(p_t)`
+
+    where `p_t` is defined as follows:
+    `p_t = y_pred if y_true == 1, else 1 - y_pred`
 
-    The authors use alpha-balanced variant of focal loss in the paper:
-    FL(pt) = −α_t * (1 − pt)^gamma * log(pt)
+    `(1 − p_t)^gamma` is the `modulating_factor`, where `gamma` is a focusing
+    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.
+    `gamma` reduces the importance given to simple examples in a smooth manner.
+
+    The authors use the alpha-balanced variant of focal loss (FL) in the paper:
+    `FL(p_t) = −alpha * (1 − p_t)^gamma * log(p_t)`
+
+    where `alpha` is the weight factor for the classes. If `alpha` = 1, the
+    loss won't be able to handle class imbalance properly as all
+    classes will have the same weight. This can be a constant or a list of
+    constants. If `alpha` is a list, it must have the same length as the number
+    of classes.
 
-    Extending this to multi-class case is straightforward:
-    FL(pt) = α_t * (1 − pt)^gamma * CE, where minus comes from
-    negative log-likelihood and included in CE.
-
-    `modulating_factor` is (1 − pt)^gamma, where `gamma` is a focusing
-    parameter. When `gamma` = 0, there is no focal effect on the categorical
-    crossentropy. And if alpha = 1, at the same time the loss is equivalent to
-    the categorical crossentropy.
+    The formula above can be generalized to:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CrossEntropy(y_true, y_pred)`
+
+    where the minus sign is absorbed into `CrossEntropy(y_true, y_pred)` (CE).
+
+    Extending this to the multi-class case is straightforward:
+    `FL(p_t) = alpha * (1 − p_t)^gamma * CategoricalCE(y_true, y_pred)`
 
-    Use this crossentropy loss function when there are two or more label
-    classes and if you want to handle class imbalance without using
-    `class_weights`.
-    We expect labels to be provided in a `one_hot` representation.
-
     In the snippet below, there is `# classes` floating pointing values per
     example. The shape of both `y_pred` and `y_true` are
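As a quick sanity check on the `from_logits` handling changed above, passing raw logits with `from_logits=True` should match applying softmax first and passing probabilities. This is a hypothetical snippet, assuming a Keras build that already ships `CategoricalFocalCrossentropy`:

```python
import tensorflow as tf

y_true = tf.constant([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
logits = tf.constant([[1.5, 2.5, 0.5], [2.0, -1.0, 0.0]])

# The same loss configured once for logits and once for probabilities.
fl_from_logits = tf.keras.losses.CategoricalFocalCrossentropy(from_logits=True)
fl_from_probs = tf.keras.losses.CategoricalFocalCrossentropy(from_logits=False)

loss_a = fl_from_logits(y_true, logits)
loss_b = fl_from_probs(y_true, tf.nn.softmax(logits, axis=-1))
print(float(loss_a), float(loss_b))  # The two values should agree closely.
```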
@@ -981,7 +993,7 @@ class CategoricalFocalCrossentropy(LossFunctionWrapper):
     Usage with the `compile()` API:
 
     ```python
-    model.compile(optimizer='sgd',
+    model.compile(optimizer='adam',
                   loss=tf.keras.losses.CategoricalFocalCrossentropy())
     ```
 
     Args:
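Finally, a hedged end-to-end sketch of the `compile()` usage documented above; the model, shapes, and random data are made up for illustration and assume a Keras version that exports `CategoricalFocalCrossentropy`:

```python
import numpy as np
import tensorflow as tf

# Tiny 3-class classifier whose softmax output matches the loss's default
# expectation of probabilities (from_logits=False).
model = tf.keras.Sequential(
    [
        tf.keras.layers.Dense(16, activation="relu", input_shape=(8,)),
        tf.keras.layers.Dense(3, activation="softmax"),
    ]
)
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0),
    metrics=["accuracy"],
)

# Dummy features and one-hot targets, as the loss expects.
x = np.random.rand(32, 8).astype("float32")
y = tf.keras.utils.to_categorical(
    np.random.randint(0, 3, size=(32,)), num_classes=3
)
model.fit(x, y, epochs=1, verbose=0)
```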