Allow passing labels in (n_samples, n_classes) to AveragePrecision #359
Comments
Hi! Thanks for your contribution, great first issue!
Hi @discort, thanks for your issue.
Thanks for getting back to me @SkafteNicki. I found that

```python
def main():
    import torch
    from pytorch_lightning.metrics import AveragePrecision

    ap = AveragePrecision(
        num_classes=7,
        pos_label=1,
        compute_on_step=False,
        dist_sync_on_step=False)
    preds = torch.randn(8, 7)
    target = torch.randint(0, 2, (8, 7))
    ap(preds, target)
    res = ap.compute()
    print('res=', res)

if __name__ == "__main__":
    main()
```

output:
but in

```python
def main():
    import torch
    from torchmetrics import AveragePrecision

    ap = AveragePrecision(
        num_classes=7,
        pos_label=1,
        compute_on_step=False,
        dist_sync_on_step=False)
    preds = torch.randn(8, 7)
    target = torch.randint(0, 2, (8, 7))
    ap(preds, target)
    res = ap.compute()
    print('res=', res)

if __name__ == "__main__":
    main()
```

output:
I can check what changed if you are unable to solve it quickly.
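For reference, a minimal NumPy sketch (not from the thread; the data and the `binary_ap` helper are hypothetical) of what per-class average precision over `(n_samples, n_classes)` labels computes, macro-averaged across classes:

```python
import numpy as np

def binary_ap(scores, targets):
    """Average precision for one binary class (sum-over-positives form)."""
    order = np.argsort(-scores)            # rank samples by descending score
    hits = targets[order]                  # 1 where a positive sits at that rank
    cum_tp = np.cumsum(hits)
    precision = cum_tp / np.arange(1, len(hits) + 1)
    return (precision * hits).sum() / hits.sum()

# hypothetical multilabel data, shape (n_samples, n_classes)
targets = np.array([[1, 0, 1],
                    [0, 1, 0],
                    [1, 1, 0],
                    [0, 0, 1]])
scores = np.array([[0.9, 0.1, 0.8],
                   [0.2, 0.7, 0.3],
                   [0.8, 0.6, 0.1],
                   [0.1, 0.2, 0.9]])

# one AP per class column, then a macro average
per_class = [binary_ap(scores[:, c], targets[:, c])
             for c in range(targets.shape[1])]
macro_ap = float(np.mean(per_class))
```

With this (perfectly ranked) toy data every class gets AP 1.0, so the macro average is 1.0 as well.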
🚀 Feature
Allow passing labels in (n_samples, n_classes) to AveragePrecision, like in sklearn.metrics average_precision_score.
Motivation
Pitch
In sklearn there is the ability to calculate average precision for a multilabel target:
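A minimal sketch of that sklearn usage (the arrays here are made up for illustration): with 2D `y_true` and `y_score` of shape `(n_samples, n_classes)`, `average_precision_score` macro-averages the per-class AP by default.

```python
import numpy as np
from sklearn.metrics import average_precision_score

# hypothetical multilabel targets and scores, shape (n_samples, n_classes)
y_true = np.array([[1, 0, 1],
                   [0, 1, 0],
                   [1, 1, 0],
                   [0, 0, 1]])
y_score = np.array([[0.9, 0.1, 0.8],
                    [0.2, 0.7, 0.3],
                    [0.8, 0.6, 0.1],
                    [0.1, 0.2, 0.9]])

# sklearn accepts the 2D multilabel target directly and, by default,
# macro-averages the per-class average precision scores.
score = average_precision_score(y_true, y_score)
```

Since each class here is perfectly ranked, the macro-averaged score comes out as 1.0; this is exactly the behavior the issue asks AveragePrecision to support.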
Alternatives
Additional context
Version: