
Fix pearson aggregation #998

Merged
12 commits merged into Lightning-AI:master on Apr 29, 2022

Conversation

ben-davidson-6
Contributor

What does this PR do?

Fixes aggregation across devices for PearsonCorrCoeff

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@ben-davidson-6
Contributor Author

The problem previously was that the aggregation assumed the per-device var_x, var_y and corr_xy were actual variances and correlations, when in fact they were just running sums. The sums are described in the following Wikipedia article, https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance:

  • Welford's online algorithm for variance (denoted M_{2,n} on Wikipedia)
  • the online calculation for covariance (denoted C_n on Wikipedia)
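Those running sums can be sketched with a minimal scalar version of the online update (a hypothetical standalone sketch, not the torchmetrics code; here `m2_x`/`m2_y` correspond to M_{2,n} and `cxy` to C_n):

```python
def welford_update(n, mean_x, mean_y, m2_x, m2_y, cxy, x, y):
    """One step of Welford's online algorithm for a pair (x, y).

    m2_x / m2_y accumulate sums of squared deviations (M_{2,n});
    cxy accumulates the sum of co-deviations (C_n). None of these are
    variances or correlations until they are normalized at the end.
    """
    n += 1
    dx = x - mean_x
    dy = y - mean_y
    mean_x += dx / n
    mean_y += dy / n
    # the second deltas use the *updated* means
    m2_x += dx * (x - mean_x)
    m2_y += dy * (y - mean_y)
    cxy += dx * (y - mean_y)
    return n, mean_x, mean_y, m2_x, m2_y, cxy
```

The variance is then `m2_x / (n - 1)` and the Pearson correlation `cxy / sqrt(m2_x * m2_y)`; the cross-device aggregation has to combine the raw sums first and normalize only once at the end.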

@codecov

codecov bot commented Apr 28, 2022

Codecov Report

Merging #998 (ccf86e1) into master (9011ec9) will decrease coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff          @@
##           master   #998   +/-   ##
=====================================
- Coverage      95%    95%   -0%     
=====================================
  Files         179    179           
  Lines        7644   7652    +8     
=====================================
- Hits         7255   7250    -5     
- Misses        389    402   +13     

@SkafteNicki SkafteNicki added the bug / fix Something isn't working label Apr 29, 2022
@SkafteNicki SkafteNicki added this to the v0.9 milestone Apr 29, 2022
Member

@SkafteNicki SkafteNicki left a comment


LGTM

Review comments on tests/regression/test_pearson.py (outdated, resolved) and torchmetrics/regression/pearson.py (resolved).
@mergify mergify bot added the ready label Apr 29, 2022
Member

@Borda Borda left a comment


lgtm

@Borda Borda enabled auto-merge (squash) April 29, 2022 12:57
@Borda Borda merged commit 32facee into Lightning-AI:master Apr 29, 2022
@Borda Borda modified the milestones: v0.9, v0.8 May 3, 2022
Borda pushed a commit that referenced this pull request May 5, 2022
* break tests
* fix test
* changelog
* Apply suggestions from code review

Co-authored-by: Nicki Skafte Detlefsen <[email protected]>

(cherry picked from commit 32facee)
@ben-davidson-6
Contributor Author

Can I pretty please get a tshirt for my contribution :)

https://discord.com/channels/1077906959069626439/1077914894428549171/1210976775127957504

@Borda
Member

Borda commented Feb 24, 2024

Can I pretty please get a tshirt for my contribution :)

sure, pls fill the form :)

@ben-davidson-6
Contributor Author

where is the form sorry, I can't find it.

@Cstrausman89

where is the form sorry, I can't find it.

Sorry about that! I just DM'd it to you on Discord. Thanks again!

@alexrgilbert

I am a bit confused by the current implementation of the _final_aggregation function used by PearsonCorrCoef, as updated in this PR. As a side note, I believe the reference link is outdated (it does not reflect the current implementation).

A quick description of a parallel algorithm for aggregating the running statistics needed for the Pearson correlation is given on the Wikipedia page for variance calculation algorithms (below the sections noted in the earlier comment by @ben-davidson-6). More detailed derivations and analysis can be found in the papers by Chan et al. and Schubert et al. (cited by the Wikipedia article).

While the current implementation can indeed be simplified into the equations given in these references, it is significantly more complex (and harder to understand). Is there a reason I am overlooking (e.g., numerical precision, avoiding overflow) for this specific implementation?

Below is a simplified implementation that I believe matches the output of the current one (I have tested it on my own data and it passes the torchmetrics unit tests) but is more closely aligned with the source algorithms. If there is no reason for the current implementation, would it be worthwhile to replace it with this simpler one?

An additional, IMPORTANT point: the current implementation modifies the states in place, so if compute is called multiple times without caching, the value will drift. This would also be fixed by the implementation below.

import torch


def _final_aggregation(
    means_x: torch.Tensor,
    means_y: torch.Tensor,
    vars_x: torch.Tensor,
    vars_y: torch.Tensor,
    corrs_xy: torch.Tensor,
    nbs: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Aggregate the statistics from multiple devices.

    Formula taken from here: `Parallel algorithm for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm>`_

    """
    if len(means_x) == 1:
        return means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    mx1, my1, vx1, vy1, cxy1, n1 = means_x[0], means_y[0], vars_x[0], vars_y[0], corrs_xy[0], nbs[0]
    for i in range(1, len(means_x)):
        mx2, my2, vx2, vy2, cxy2, n2 = means_x[i], means_y[i], vars_x[i], vars_y[i], corrs_xy[i], nbs[i]
        # count
        nb = n1 + n2
        # mean_x
        mean_x = (n1 * mx1 + n2 * mx2) / nb
        # mean_y
        mean_y = (n1 * my1 + n2 * my2) / nb
        # intermediates for running variances
        n12_b = n1 * n2 / nb
        delta_x = mx2 - mx1
        delta_y = my2 - my1
        # var_x
        var_x = vx1 + vx2 + n12_b * delta_x ** 2
        # var_y
        var_y = vy1 + vy2 + n12_b * delta_y ** 2
        # corr_xy
        corr_xy = cxy1 + cxy2 + n12_b * delta_x * delta_y

        mx1, my1, vx1, vy1, cxy1, n1 = mean_x, mean_y, var_x, var_y, corr_xy, nb
    return mean_x, mean_y, var_x, var_y, corr_xy, nb
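As a quick sanity check of the merge formulas (a hypothetical standalone check in plain Python, not part of torchmetrics), merging the running sums of two halves of a series should reproduce the single-pass statistics over the whole series:

```python
def merge_stats(n1, mx1, my1, vx1, vy1, cxy1, n2, mx2, my2, vx2, vy2, cxy2):
    """Pairwise merge of two sets of running sums (Chan et al. style).

    Inputs are counts, means, sums of squared deviations, and
    co-deviation sums; returns the merged versions of the same.
    """
    n = n1 + n2
    dx, dy = mx2 - mx1, my2 - my1
    f = n1 * n2 / n
    return (
        n,
        (n1 * mx1 + n2 * mx2) / n,
        (n1 * my1 + n2 * my2) / n,
        vx1 + vx2 + f * dx * dx,
        vy1 + vy2 + f * dy * dy,
        cxy1 + cxy2 + f * dx * dy,
    )


def running_sums(xs, ys):
    """Single-pass reference: count, means, and (co)deviation sums."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    vx = sum((x - mx) ** 2 for x in xs)
    vy = sum((y - my) ** 2 for y in ys)
    cxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    return n, mx, my, vx, vy, cxy
```

Splitting a series in two, computing each half's sums, merging, and comparing against the single-pass sums over the full series should agree to floating-point precision, which is exactly the invariant the cross-device aggregation relies on.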

