Skip to content

Commit

Permalink
Avoid nan loss when there are labels with no samples in the training …
Browse files Browse the repository at this point in the history
…data. (#12)
  • Loading branch information
chbeltz authored Dec 16, 2024
1 parent 1c71f91 commit 20f3779
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 46 deletions.
47 changes: 29 additions & 18 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,39 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
operating-system: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.7, 3.8, 3.9]
torch-version: [1.10.2, 1.11.0, 1.12.0]
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.9, "3.10"]
torch-version: [1.13.1, 2.5.1]
fail-fast: false

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Restore Ubuntu cache
uses: actions/cache@v1
if: matrix.operating-system == 'ubuntu-latest'
uses: actions/cache@v4
if: matrix.os == 'ubuntu-latest'
with:
path: ~/.cache/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore MacOS cache
uses: actions/cache@v1
if: matrix.operating-system == 'macos-latest'
uses: actions/cache@v4
if: matrix.os == 'macos-latest'
with:
path: ~/Library/Caches/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore Windows cache
uses: actions/cache@v1
if: matrix.operating-system == 'windows-latest'
uses: actions/cache@v4
if: matrix.os == 'windows-latest'
with:
path: ~\AppData\Local\pip\Cache
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
Expand All @@ -52,9 +52,14 @@ jobs:
- name: Update pip
run: python -m pip install --upgrade pip

- name: Install package in development mode
run: pip install -e .[dev]

- name: Show installed packages
run: pip list

- name: Lint with flake8, black and isort
run: |
pip install -e .[dev]
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
black . --check --config pyproject.toml
Expand All @@ -66,17 +71,23 @@ jobs:
run: >
pip install numpy
- name: Install PyTorch on Linux and Windows
- name: Install PyTorch==1.13.1 on Linux and Windows
if: >
matrix.operating-system == 'ubuntu-latest' ||
matrix.operating-system == 'windows-latest'
(matrix.os == 'ubuntu-latest' ||
matrix.os == 'windows-latest') &&
matrix.torch-version == '1.13.1'
run: >
pip install torch==${{ matrix.torch-version }}+cpu
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch on MacOS
if: matrix.operating-system == 'macos-latest'
run: pip install torch==${{ matrix.torch-version }}
- name: Install PyTorch==2.5.1 on Linux and Windows
if: >
(matrix.os == 'ubuntu-latest' ||
matrix.os == 'windows-latest') &&
matrix.torch-version == '2.5.1'
run: >
pip install torch==${{ matrix.torch-version }}
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install balanced-loss package from local setup.py
run: >
Expand Down
41 changes: 25 additions & 16 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,39 @@ jobs:

strategy:
matrix:
operating-system: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.7, 3.8, 3.9]
torch-version: [1.10.2, 1.11.0, 1.12.0]
os: [ubuntu-latest, windows-latest, macos-latest]
python-version: [3.9, "3.10"]
torch-version: [1.13.1, 2.5.1]
fail-fast: false

steps:
- name: Checkout
uses: actions/checkout@v2
uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Restore Ubuntu cache
uses: actions/cache@v1
if: matrix.operating-system == 'ubuntu-latest'
uses: actions/cache@v4
if: matrix.os == 'ubuntu-latest'
with:
path: ~/.cache/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore MacOS cache
uses: actions/cache@v1
if: matrix.operating-system == 'macos-latest'
uses: actions/cache@v4
if: matrix.os == 'macos-latest'
with:
path: ~/Library/Caches/pip
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}-

- name: Restore Windows cache
uses: actions/cache@v1
if: matrix.operating-system == 'windows-latest'
uses: actions/cache@v4
if: matrix.os == 'windows-latest'
with:
path: ~\AppData\Local\pip\Cache
key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}}
Expand All @@ -55,16 +55,26 @@ jobs:
run: >
pip install numpy
- name: Install PyTorch on Linux and Windows
- name: Install PyTorch==1.13.1 on Linux and Windows
if: >
matrix.operating-system == 'ubuntu-latest' ||
matrix.operating-system == 'windows-latest'
(matrix.os == 'ubuntu-latest' ||
matrix.os == 'windows-latest') &&
matrix.torch-version == '1.13.1'
run: >
pip install torch==${{ matrix.torch-version }}+cpu
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch==2.5.1 on Linux and Windows
if: >
(matrix.os == 'ubuntu-latest' ||
matrix.os == 'windows-latest') &&
matrix.torch-version == '2.5.1'
run: >
pip install torch==${{ matrix.torch-version }}
-f https://download.pytorch.org/whl/torch_stable.html
- name: Install PyTorch on MacOS
if: matrix.operating-system == 'macos-latest'
if: matrix.os == 'macos-latest'
run: pip install torch==${{ matrix.torch-version }}

- name: Install latest balanced-loss package
Expand All @@ -74,4 +84,3 @@ jobs:
- name: Unittest balanced-loss
run: |
python -m unittest
4 changes: 2 additions & 2 deletions .github/workflows/publish_pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ jobs:
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: '3.x'
- name: Install dependencies
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ When training dataset labels are imbalanced, one thing to do is to balance the l

![alt-text](https://user-images.githubusercontent.com/34196005/180266198-e27d8cba-f5e1-49ca-9f82-d8656333e3c4.png)


## Installation

```bash
Expand Down Expand Up @@ -134,6 +133,7 @@ What is the difference between this repo and vandit15's?
- This repo implements loss functions as `torch.nn.Module`
- In addition to class balanced losses, this repo also supports the standard versions of the cross entropy/focal loss etc. over the same API
- All typos and errors in vandit15's source are fixed
- Continiously tested on PyTorch 1.13.1 and 2.5.1

## References

Expand Down
2 changes: 1 addition & 1 deletion balanced_loss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .losses import Loss

__version__ = "0.1.0"
__version__ = "0.1.1"
19 changes: 13 additions & 6 deletions balanced_loss/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
fl_gamma=2,
samples_per_class=None,
class_balanced=False,
safe: bool = False,
):
"""
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Expand All @@ -60,6 +61,7 @@ def __init__(
samples_per_class: A python list of size [num_classes].
Required if class_balance is True.
class_balanced: bool. Whether to use class balanced loss.
safe: bool. Whether to allow labels with no samples.
Returns:
Loss instance
"""
Expand All @@ -73,12 +75,9 @@ def __init__(
self.fl_gamma = fl_gamma
self.samples_per_class = samples_per_class
self.class_balanced = class_balanced
self.safe = safe

def forward(
self,
logits: torch.tensor,
labels: torch.tensor,
):
def forward(self, logits: torch.tensor, labels: torch.tensor):
"""
Compute the Class Balanced Loss between `logits` and the ground truth `labels`.
Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
Expand All @@ -97,8 +96,16 @@ def forward(

if self.class_balanced:
effective_num = 1.0 - np.power(self.beta, self.samples_per_class)
# Avoid division by 0 error for test cases without all labels present.
if self.safe:
effective_num_classes = np.sum(effective_num != 0)
effective_num[effective_num == 0] = np.inf

else:
effective_num_classes = num_classes

weights = (1.0 - self.beta) / np.array(effective_num)
weights = weights / np.sum(weights) * num_classes
weights = weights / np.sum(weights) * effective_num_classes
weights = torch.tensor(weights, device=logits.device).float()

if self.loss_type != "cross_entropy":
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_version():
setuptools.setup(
name="balanced-loss",
version=get_version(),
author="",
author="fcakyon",
license="MIT",
description="Easy to use class-balanced cross-entropy and focal loss implementation for Pytorch.",
long_description=get_long_description(),
Expand All @@ -54,9 +54,9 @@ def get_version():
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
"Topic :: Education",
Expand Down

0 comments on commit 20f3779

Please sign in to comment.