Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Segmentation Pretrained Weights #1046

Prev Previous commit
Next Next commit
Merge branch 'main' into trainers/segmentation-pretrained-weights
  • Loading branch information
isaaccorley authored May 3, 2023
commit ac69a7c492ce113d0e7c4338c3ff39d0878148b4
14 changes: 14 additions & 0 deletions torchgeo/trainers/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,20 @@ def config_task(self) -> None:
state_dict = get_weight(weights).get_state_dict(progress=True)
self.model.encoder = utils.load_state_dict(self.model, state_dict)

# Freeze backbone
if self.hyperparams.get("freeze_backbone", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.encoder.parameters():
param.requires_grad = False

# Freeze decoder
if self.hyperparams.get("freeze_decoder", False) and self.hyperparams[
"model"
] in ["unet", "deeplabv3+"]:
for param in self.model.decoder.parameters():
param.requires_grad = False

def __init__(self, **kwargs: Any) -> None:
"""Initialize the LightningModule with a model and loss function.

Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.