IOU Class Metric Module #4563

Closed
zmurez opened this issue Nov 7, 2020 · 5 comments · Fixed by #4704

Comments


zmurez commented Nov 7, 2020

🚀 Feature

Is there a reason why IOU doesn't have a class metric module (currently the only classification class metrics implemented are: Accuracy, Precision, Recall, Fbeta)? Is this already on the roadmap?

I implemented a version below. Does this look good? If so, should I submit a PR?
I was initially worried that the reason it wasn't implemented yet was non-trivial issues with syncing across devices in ddp, but I don't see any issues with my implementation... Did I miss something?

import torch
from typing import Any, Callable, Optional

import pytorch_lightning as pl
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes
from pytorch_lightning.metrics.functional.reduction import reduce


class IOU(Metric):
    """
    Computes IOU.
    Args:
        ...
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``. default: True
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step. default: False
        process_group:
            Specify the process group on which synchronization is called. default: None (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When `None`, DDP
            will be used to perform the allgather. default: None
    """

    def __init__(
        self,
        num_classes: int,
        ignore_index: Optional[int] = None,
        absent_score: float = 0.0,
        reduction: str = 'elementwise_mean',
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        #dist_sync_fn: Callable = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            #dist_sync_fn=dist_sync_fn,
        )

        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.absent_score = absent_score
        self.reduction = reduction
        self.add_state("tps", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("fps", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("fns", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        self.add_state("sups", default=torch.zeros(num_classes), dist_reduce_fx="sum")
        

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """
        Update state with predictions and targets.
        Args:
            preds: Predictions from model
            target: Ground truth values
        """

        tps, fps, _, fns, sups = stat_scores_multiple_classes(preds, target, self.num_classes)

        self.tps += tps
        self.fps += fps
        self.fns += fns
        self.sups += sups

    def compute(self):
        """
        Computes the intersection over union (IOU) over state.
        """

        scores = torch.zeros(self.num_classes, device=self.tps.device, dtype=torch.float32)

        for class_idx in range(self.num_classes):
            if class_idx == self.ignore_index:
                continue

            tp = self.tps[class_idx]
            fp = self.fps[class_idx]
            fn = self.fns[class_idx]
            sup = self.sups[class_idx]

            # If this class is absent in the target (no support) AND absent in the pred (no true or false
            # positives), then use the absent_score for this class.
            if sup + tp + fp == 0:
                scores[class_idx] = self.absent_score
                continue

            denom = tp + fp + fn
            # Note that we do not need to worry about division-by-zero here since we know (sup + tp + fp != 0) from above,
            # which means ((tp+fn) + tp + fp != 0), which means (2tp + fp + fn != 0). Since all vars are non-negative, we
            # can conclude (tp + fp + fn > 0), meaning the denominator is non-zero for each class.
            score = tp.to(torch.float) / denom
            scores[class_idx] = score

        # Remove the ignored class index from the scores.
        if self.ignore_index is not None and self.ignore_index >= 0 and self.ignore_index < self.num_classes:
            scores = torch.cat([
                scores[:self.ignore_index],
                scores[self.ignore_index + 1:],
            ])

        return reduce(scores, reduction=self.reduction)

Thanks

@zmurez added the feature (Is an improvement or enhancement) and help wanted (Open to be worked on) labels Nov 7, 2020

github-actions bot (Contributor) commented Nov 7, 2020

Hi! Thanks for your contribution, great first issue!

@SkafteNicki (Member) commented

Hi @zmurez
It is definitely on the roadmap that all metrics in the functional backend should have a counterpart in the class backend. After we revamped the class metrics in v1.0, we have been slowly adding class interfaces; we just have not reached IOU yet.
That said, we would be more than happy to receive a PR :]

Do note the following:
We are also in the process of unifying the functional and class-based backends. In practice, this means that all computations should happen in the functional backend and the class backend should just call the functional backend. Please see how the code is organized for all regression metrics.
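
For illustration, here is a minimal sketch of that split, assuming hypothetical helper names _iou_update and _iou_compute in the functional backend (these names are not the library's actual API, and options such as ignore_index and reduction are omitted for brevity):

import torch
from typing import Tuple

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.metrics.functional.classification import stat_scores_multiple_classes


def _iou_update(preds: torch.Tensor, target: torch.Tensor, num_classes: int) -> Tuple[torch.Tensor, ...]:
    # Functional backend: gather the sufficient statistics for IOU from one batch.
    tps, fps, _, fns, sups = stat_scores_multiple_classes(preds, target, num_classes)
    return tps, fps, fns, sups


def _iou_compute(tps, fps, fns, sups, absent_score: float = 0.0) -> torch.Tensor:
    # Functional backend: turn accumulated statistics into per-class IOU scores.
    denom = (tps + fps + fns).clamp(min=1)
    scores = tps.float() / denom
    # Classes absent from both target and predictions get the absent_score.
    scores[(sups + tps + fps) == 0] = absent_score
    return scores


class IOU(Metric):
    # Class backend: only state bookkeeping; all math is delegated to the functional helpers.
    def __init__(self, num_classes: int, absent_score: float = 0.0):
        super().__init__()
        self.num_classes = num_classes
        self.absent_score = absent_score
        for name in ("tps", "fps", "fns", "sups"):
            self.add_state(name, default=torch.zeros(num_classes), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        tps, fps, fns, sups = _iou_update(preds, target, self.num_classes)
        self.tps += tps
        self.fps += fps
        self.fns += fns
        self.sups += sups

    def compute(self):
        return _iou_compute(self.tps, self.fps, self.fns, self.sups, self.absent_score)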

@SkafteNicki added this to the 1.2 milestone Nov 7, 2020
@SkafteNicki linked a pull request Nov 17, 2020 that will close this issue

@heng-yuwen commented

Hi, I wonder what dist_sync_on_step means here. Shouldn't we pass True when using a method like DDP?

@SkafteNicki (Member) commented

@123mutourener that's up to the user to decide. If I set it to False, I will only get the metric from process 0. Assuming a uniform distribution of samples over all processes, this may be a good enough proxy for what you want to log. Of course, setting it to True will give a more precise result, but it will slow down training (due to the added communication between processes).
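
For example, a minimal single-process snippet using the IOU metric proposed above (the class count and tensors are made-up values for illustration; the trade-off only matters when running under DDP):

import torch

num_classes = 21  # arbitrary example value

# dist_sync_on_step=False: step-level values reflect only the local process's batches;
# state is still summed across processes when compute() is called.
iou = IOU(num_classes=num_classes, dist_sync_on_step=False)

# dist_sync_on_step=True: state is synced across processes before each step-level value,
# which gives more precise per-step logging but adds communication overhead.
iou_synced = IOU(num_classes=num_classes, dist_sync_on_step=True)

preds = torch.randint(0, num_classes, (100,))
target = torch.randint(0, num_classes, (100,))

step_value = iou(preds, target)   # value for this batch (local process only unless synced)
epoch_value = iou.compute()       # aggregated over all update() calls (and all processes in DDP)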

@heng-yuwen commented

@SkafteNicki Thanks! That is helpful.
