Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored May 31, 2024
1 parent cff6701 commit c13055b
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def _generalized_dice_update(

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if input_format == "index":
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
Expand Down
1 change: 0 additions & 1 deletion src/torchmetrics/functional/segmentation/mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def _mean_iou_update(

if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
if input_format == "index":
target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)

if not include_background:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ def _reference_generalized_dice(
"""Calculate reference metric for `MeanIoU`."""
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
if input_format == "index":
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)
val = compute_generalized_dice(preds, target, include_background=include_background)
if reduce:
Expand Down
1 change: 0 additions & 1 deletion tests/unittests/segmentation/test_mean_iou.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def _reference_mean_iou(
"""Calculate reference metric for `MeanIoU`."""
if input_format == "index":
preds = torch.nn.functional.one_hot(preds, num_classes=NUM_CLASSES).movedim(-1, 1)
if input_format == "index":
target = torch.nn.functional.one_hot(target, num_classes=NUM_CLASSES).movedim(-1, 1)

val = compute_iou(preds, target, include_background=include_background)
Expand Down

0 comments on commit c13055b

Please sign in to comment.