Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
isaaccorley committed Apr 13, 2023
1 parent 43c5db8 commit dd2f9a0
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion torchgeo/trainers/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def config_task(self) -> None:
f"Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
)

self.loss: nn.Module
if self.hyperparams["loss"] == "mse":
self.loss = nn.MSELoss()
elif self.hyperparams["loss"] == "mae":
Expand Down Expand Up @@ -295,7 +296,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor:
self.log("train_loss", loss) # logging to TensorBoard
self.train_metrics(y_hat, y)

return loss
return cast(Tensor, loss)

def validation_step(self, *args: Any, **kwargs: Any) -> None:
"""Compute validation loss and log example predictions.
Expand Down

0 comments on commit dd2f9a0

Please sign in to comment.