diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7e648db..79de69f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,39 +11,39 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - operating-system: [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.7, 3.8, 3.9] - torch-version: [1.10.2, 1.11.0, 1.12.0] + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: [3.9, "3.10"] + torch-version: [1.13.1, 2.5.1] fail-fast: false steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Restore Ubuntu cache - uses: actions/cache@v1 - if: matrix.operating-system == 'ubuntu-latest' + uses: actions/cache@v4 + if: matrix.os == 'ubuntu-latest' with: path: ~/.cache/pip key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- - name: Restore MacOS cache - uses: actions/cache@v1 - if: matrix.operating-system == 'macos-latest' + uses: actions/cache@v4 + if: matrix.os == 'macos-latest' with: path: ~/Library/Caches/pip key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- - name: Restore Windows cache - uses: actions/cache@v1 - if: matrix.operating-system == 'windows-latest' + uses: actions/cache@v4 + if: matrix.os == 'windows-latest' with: path: ~\AppData\Local\pip\Cache key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} @@ -52,9 +52,14 @@ jobs: - name: Update pip run: python -m pip install --upgrade pip + - name: Install package in development mode + run: pip install -e .[dev] + + - name: Show installed packages + run: pip list + - name: Lint with flake8, black and isort run: | - pip install -e .[dev] # stop the build if there are Python syntax errors or undefined names flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics black . --check --config pyproject.toml @@ -66,17 +71,23 @@ jobs: run: > pip install numpy - - name: Install PyTorch on Linux and Windows + - name: Install PyTorch==1.13.1 on Linux and Windows if: > - matrix.operating-system == 'ubuntu-latest' || - matrix.operating-system == 'windows-latest' + (matrix.os == 'ubuntu-latest' || + matrix.os == 'windows-latest') && + matrix.torch-version == '1.13.1' run: > pip install torch==${{ matrix.torch-version }}+cpu -f https://download.pytorch.org/whl/torch_stable.html - - name: Install PyTorch on MacOS - if: matrix.operating-system == 'macos-latest' - run: pip install torch==${{ matrix.torch-version }} + - name: Install PyTorch==2.5.1 on Linux and Windows + if: > + (matrix.os == 'ubuntu-latest' || + matrix.os == 'windows-latest') && + matrix.torch-version == '2.5.1' + run: > + pip install torch==${{ matrix.torch-version }} + -f https://download.pytorch.org/whl/torch_stable.html - name: Install balanced-loss package from local setup.py run: > diff --git a/.github/workflows/package_testing.yml b/.github/workflows/package_testing.yml index 56d0a97..5cf6ecf 100644 --- a/.github/workflows/package_testing.yml +++ b/.github/workflows/package_testing.yml @@ -10,39 +10,39 @@ jobs: strategy: matrix: - operating-system: [ubuntu-latest, windows-latest, macos-latest] - python-version: [3.7, 3.8, 3.9] - torch-version: [1.10.2, 1.11.0, 1.12.0] + os: [ubuntu-latest, windows-latest, macos-latest] + python-version: [3.9, "3.10"] + torch-version: [1.13.1, 2.5.1] fail-fast: false steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Restore Ubuntu cache - uses: actions/cache@v1 - if: matrix.operating-system == 'ubuntu-latest' + uses: actions/cache@v4 + if: matrix.os == 'ubuntu-latest' with: path: ~/.cache/pip key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- - name: Restore MacOS cache - uses: actions/cache@v1 - if: matrix.operating-system == 'macos-latest' + uses: actions/cache@v4 + if: matrix.os == 'macos-latest' with: path: ~/Library/Caches/pip key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} restore-keys: ${{ matrix.os }}-${{ matrix.python-version }}- - name: Restore Windows cache - uses: actions/cache@v1 - if: matrix.operating-system == 'windows-latest' + uses: actions/cache@v4 + if: matrix.os == 'windows-latest' with: path: ~\AppData\Local\pip\Cache key: ${{ matrix.os }}-${{ matrix.python-version }}-${{ hashFiles('**/setup.py')}} @@ -55,16 +55,26 @@ jobs: run: > pip install numpy - - name: Install PyTorch on Linux and Windows + - name: Install PyTorch==1.13.1 on Linux and Windows if: > - matrix.operating-system == 'ubuntu-latest' || - matrix.operating-system == 'windows-latest' + (matrix.os == 'ubuntu-latest' || + matrix.os == 'windows-latest') && + matrix.torch-version == '1.13.1' run: > pip install torch==${{ matrix.torch-version }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Install PyTorch==2.5.1 on Linux and Windows + if: > + (matrix.os == 'ubuntu-latest' || + matrix.os == 'windows-latest') && + matrix.torch-version == '2.5.1' + run: > + pip install torch==${{ matrix.torch-version }} + -f https://download.pytorch.org/whl/torch_stable.html + - name: Install PyTorch on MacOS - if: matrix.operating-system == 'macos-latest' + if: matrix.os == 'macos-latest' run: pip install torch==${{ matrix.torch-version }} - name: Install latest balanced-loss package @@ -74,4 +84,3 @@ jobs: - name: Unittest balanced-loss run: | python -m unittest - diff --git a/.github/workflows/publish_pypi.yml b/.github/workflows/publish_pypi.yml index f7e7ea2..7000cf3 100644 --- a/.github/workflows/publish_pypi.yml +++ b/.github/workflows/publish_pypi.yml @@ -9,9 +9,9 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: '3.x' - name: Install dependencies diff --git a/README.md b/README.md index c5d6d44..5988a05 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ When training dataset labels are imbalanced, one thing to do is to balance the l ![alt-text](https://user-images.githubusercontent.com/34196005/180266198-e27d8cba-f5e1-49ca-9f82-d8656333e3c4.png) - ## Installation ```bash @@ -134,6 +133,7 @@ What is the difference between this repo and vandit15's? - This repo implements loss functions as `torch.nn.Module` - In addition to class balanced losses, this repo also supports the standard versions of the cross entropy/focal loss etc. over the same API - All typos and errors in vandit15's source are fixed +- Continiously tested on PyTorch 1.13.1 and 2.5.1 ## References diff --git a/balanced_loss/__init__.py b/balanced_loss/__init__.py index 5245528..1548bc8 100644 --- a/balanced_loss/__init__.py +++ b/balanced_loss/__init__.py @@ -1,3 +1,3 @@ from .losses import Loss -__version__ = "0.1.0" +__version__ = "0.1.1" diff --git a/balanced_loss/losses.py b/balanced_loss/losses.py index 9e1bf50..eceab97 100644 --- a/balanced_loss/losses.py +++ b/balanced_loss/losses.py @@ -44,6 +44,7 @@ def __init__( fl_gamma=2, samples_per_class=None, class_balanced=False, + safe: bool = False, ): """ Compute the Class Balanced Loss between `logits` and the ground truth `labels`. @@ -60,6 +61,7 @@ def __init__( samples_per_class: A python list of size [num_classes]. Required if class_balance is True. class_balanced: bool. Whether to use class balanced loss. + safe: bool. Whether to allow labels with no samples. Returns: Loss instance """ @@ -73,12 +75,9 @@ def __init__( self.fl_gamma = fl_gamma self.samples_per_class = samples_per_class self.class_balanced = class_balanced + self.safe = safe - def forward( - self, - logits: torch.tensor, - labels: torch.tensor, - ): + def forward(self, logits: torch.tensor, labels: torch.tensor): """ Compute the Class Balanced Loss between `logits` and the ground truth `labels`. Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits) @@ -97,8 +96,16 @@ def forward( if self.class_balanced: effective_num = 1.0 - np.power(self.beta, self.samples_per_class) + # Avoid division by 0 error for test cases without all labels present. + if self.safe: + effective_num_classes = np.sum(effective_num != 0) + effective_num[effective_num == 0] = np.inf + + else: + effective_num_classes = num_classes + weights = (1.0 - self.beta) / np.array(effective_num) - weights = weights / np.sum(weights) * num_classes + weights = weights / np.sum(weights) * effective_num_classes weights = torch.tensor(weights, device=logits.device).float() if self.loss_type != "cross_entropy": diff --git a/setup.py b/setup.py index ff789ae..b22d7dc 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ def get_version(): setuptools.setup( name="balanced-loss", version=get_version(), - author="", + author="fcakyon", license="MIT", description="Easy to use class-balanced cross-entropy and focal loss implementation for Pytorch.", long_description=get_long_description(), @@ -54,9 +54,9 @@ def get_version(): "Intended Audience :: Developers", "Intended Audience :: Science/Research", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Education",