Skip to content

Commit

Permalink
added class_weights for cross entropy loss to segmentation.py (#1221)
Browse files Browse the repository at this point in the history
* added class_weights for cross entropy loss to segmentation.py

* added class_weights for cross entropy loss to segmentation.py

* added class_weights for cross entropy loss to segmentation.py and fixed formatting

* added class_weights for cross entropy loss to segmentation.py, fixed formatting, added deleted loss

* Made class_weights argument optional

* fixed black formatting

* included versionadded: 0.5 and parameter to docstring

* added newline between sections, removed manual type checking, moved class_weights parameter after loss parameter

* Deleted duplicated line

---------

Co-authored-by: Caleb Robinson <[email protected]>
  • Loading branch information
nsutezo and calebrob6 authored Apr 25, 2023
1 parent 7678627 commit 8a2e9b4
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,13 @@ def config_task(self) -> None:

if self.hyperparams["loss"] == "ce":
ignore_value = -1000 if self.ignore_index is None else self.ignore_index
self.loss = nn.CrossEntropyLoss(ignore_index=ignore_value)

class_weights = (
torch.FloatTensor(self.class_weights) if self.class_weights else None
)
self.loss = nn.CrossEntropyLoss(
ignore_index=ignore_value, weight=class_weights
)
elif self.hyperparams["loss"] == "jaccard":
self.loss = smp.losses.JaccardLoss(
mode="multiclass", classes=self.hyperparams["num_classes"]
Expand All @@ -86,6 +92,8 @@ def __init__(self, **kwargs: Any) -> None:
num_classes: Number of semantic classes to predict
loss: Name of the loss function, currently supports
'ce', 'jaccard' or 'focal' loss
class_weights: Optional rescaling weight given to each
class and used with 'ce' loss
ignore_index: Optional integer class index to ignore in the loss and metrics
learning_rate: Learning rate for optimizer
learning_rate_schedule_patience: Patience for learning rate scheduler
Expand All @@ -100,6 +108,9 @@ def __init__(self, **kwargs: Any) -> None:
The *segmentation_model* parameter was renamed to *model*,
*encoder_name* renamed to *backbone*, and
*encoder_weights* to *weights*.
.. versionadded:: 0.5
The *class_weights* parameter.
"""
super().__init__()

Expand All @@ -115,6 +126,8 @@ def __init__(self, **kwargs: Any) -> None:
UserWarning,
)
self.ignore_index = kwargs["ignore_index"]
self.class_weights = kwargs.get("class_weights", None)

self.config_task()

self.train_metrics = MetricCollection(
Expand Down

0 comments on commit 8a2e9b4

Please sign in to comment.