
Avoid nan loss when there are labels with no samples in the training data. #12

Merged · 7 commits into fcakyon:main · Dec 16, 2024

Conversation

@chbeltz (Contributor) commented Nov 22, 2024

Hello there.

I ran into problems today when trying to do a test run with training data that lacked samples for one of the labels. This causes the class-balanced focal loss to come out as nan.

import torch
from balanced_loss import Loss

# class index 4 has zero samples in the training data
samples_per_class = [310.0, 2489.0, 114.0, 17.0, 0.0, 725.0]
pred = torch.tensor([[4.1951e-04, 1.6066e-02, 3.2661e-03, 5.0763e-01, 1.0739e-03, 4.7154e-01],
        [7.6719e-03, 1.1280e-01, 5.8755e-02, 5.5621e-02, 6.6679e-01, 9.8361e-02],
        [3.0145e-03, 9.3653e-01, 1.7860e-02, 2.4776e-03, 3.6712e-03, 3.6448e-02],
        [1.0764e-03, 3.8136e-03, 4.5988e-03, 8.3224e-04, 9.8502e-01, 4.6638e-03],
        [9.5827e-03, 2.3838e-02, 5.1518e-02, 1.0943e-02, 2.9569e-02, 8.7455e-01]])
yb = torch.tensor([[0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])

focal_loss = Loss(
    loss_type="focal_loss",
    samples_per_class=samples_per_class,
    beta=0.999,  # class-balanced loss beta
    fl_gamma=2,  # focal loss gamma
    class_balanced=True,
)

# convert one-hot targets to class indices
print(focal_loss(pred, torch.argmax(yb, dim=-1).to(torch.int64)))

currently yields

tensor(nan)

/home/user/florist-environment/lib/python3.10/site-packages/balanced_loss/losses.py:111: RuntimeWarning: divide by zero encountered in divide
  weights = (1.0 - self.beta) / np.array(effective_num)
/home/user/florist-environment/lib/python3.10/site-packages/balanced_loss/losses.py:112: RuntimeWarning: invalid value encountered in divide
  weights = weights / np.sum(weights) * effective_num_classes
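The nan traces back to the class-balanced weighting itself: for a class with n samples the effective number is 1 − β^n, which is 0 when n = 0, so the per-class weight (1 − β)/(1 − β^n) is infinite, and the subsequent normalization turns it into nan. A standalone sketch of just that weight computation (mirroring the two lines from losses.py quoted above; variable names are mine):

import numpy as np

beta = 0.999
samples_per_class = np.array([310.0, 2489.0, 114.0, 17.0, 0.0, 725.0])

# effective number of samples: 1 - beta^n is 0 when n = 0
effective_num = 1.0 - np.power(beta, samples_per_class)

weights = (1.0 - beta) / effective_num  # inf at the zero-sample class
weights = weights / np.sum(weights) * len(samples_per_class)  # inf / inf -> nan

print(weights)  # [ 0.  0.  0.  0. nan  0.] -- the nan then propagates into the loss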

Adding a safe switch to the Loss class fixes this issue without changing the weights of the non-zero-sample labels relative to leaving out the zero-sample labels entirely. The loss, however, will come out larger than it would with the alternative solution of removing the offending label.

I can see that this is an edge case, but it will be helpful for me, and I imagine it might be for others as well. One could also consider raising a ValueError when zero-sample labels are supplied, hinting at the safe switch.
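For illustration, here is a minimal sketch of what I mean by a safe switch; the parameter name and placement are hypothetical, and the actual change in this PR may differ in detail:

import numpy as np

def class_balanced_weights(samples_per_class, beta, safe=False):
    # hypothetical sketch of the "safe" behavior, not the merged code
    samples = np.asarray(samples_per_class, dtype=float)
    effective_num = 1.0 - np.power(beta, samples)

    if safe:
        present = samples > 0
        weights = np.zeros_like(samples)
        weights[present] = (1.0 - beta) / effective_num[present]
        # normalize over the classes that actually occur, so their
        # relative weights match simply dropping the empty classes
        weights = weights / np.sum(weights) * np.count_nonzero(present)
    else:
        weights = (1.0 - beta) / effective_num  # inf/nan when a class has 0 samples
        weights = weights / np.sum(weights) * len(samples)
    return weights

In this sketch the non-zero-sample classes get exactly the weights they would have if the empty classes were dropped from samples_per_class, while the empty classes are simply assigned weight 0.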

@fcakyon (Owner) commented Dec 16, 2024

Hey @chbeltz thanks for your contribution!

Please reformat your code and we are good to merge 💯

@fcakyon self-requested a review on Dec 16, 2024, 18:21
@fcakyon added the enhancement label on Dec 16, 2024
…ature in Loss class for improved readability
…upgrade GitHub Actions to latest versions for improved performance and compatibility
…ce caching logic. Added installation steps for PyTorch versions 1.13.1 and 2.5.1, and included a step to display installed packages. This improves clarity and consistency across CI configurations.
…installation details for PyTorch versions 1.13.1 and 2.5.1. This enhances documentation accuracy and provides users with essential version information.
@fcakyon merged commit 20f3779 into fcakyon:main on Dec 16, 2024
@fcakyon removed their request for review on Dec 16, 2024, 18:47
@fcakyon (Owner) commented Dec 16, 2024

@chbeltz it's live in balanced-loss==0.1.1 🚀
