-
Notifications
You must be signed in to change notification settings - Fork 408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add PixelwiseRegressionTask #1241
Changes from 1 commit
c21b847
6a2c519
20dd757
0942fd2
39700a8
d6fefb6
12180b9
0b81934
ba565e4
ffda89e
e05a3b5
eff0917
298d7bb
7cf0e26
381a9d9
a36d74d
2ae4d8b
e6fbbdb
8741d6b
74a14d6
e8b001f
d71a252
fb5c642
e9ca047
e57eb49
5bc1f5d
cdcd3c5
b135bdc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -131,16 +131,14 @@ def training_step(self, *args: Any, **kwargs: Any) -> Tensor: | |
batch = args[0] | ||
x = batch["image"] | ||
y = batch[self.target_key] | ||
|
||
if y.ndim == 1: | ||
y = y.unsqueeze(dim=1) | ||
|
||
y_hat = self(x) | ||
|
||
loss = self.loss(y_hat, y) | ||
if y_hat.ndim != y.ndim: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If
while the output of the models will be:
|
||
y = y.unsqueeze(dim=1) | ||
|
||
loss = self.loss(y_hat, y.to(torch.float)) | ||
self.log("train_loss", loss) # logging to TensorBoard | ||
self.train_metrics(y_hat, y) | ||
self.train_metrics(y_hat, y.to(torch.float)) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Cast to float only for loss and metrics in case the plotting expects a different dtype |
||
|
||
return loss | ||
|
||
|
@@ -160,15 +158,14 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: | |
batch_idx = args[1] | ||
x = batch["image"] | ||
y = batch[self.target_key] | ||
y_hat = self(x) | ||
|
||
if y.ndim == 1: | ||
if y_hat.ndim != y.ndim: | ||
y = y.unsqueeze(dim=1) | ||
|
||
y_hat = self(x) | ||
|
||
loss = self.loss(y_hat, y) | ||
loss = self.loss(y_hat, y.to(torch.float)) | ||
self.log("val_loss", loss) | ||
self.val_metrics(y_hat, y) | ||
self.val_metrics(y_hat, y.to(torch.float)) | ||
|
||
if ( | ||
batch_idx < 10 | ||
|
@@ -179,6 +176,9 @@ def validation_step(self, *args: Any, **kwargs: Any) -> None: | |
): | ||
try: | ||
datamodule = self.trainer.datamodule | ||
if self.target_key == "mask": | ||
y = y.squeeze(dim=1) | ||
y_hat = y_hat.squeeze(dim=1) | ||
batch["prediction"] = y_hat | ||
for key in ["image", self.target_key, "prediction"]: | ||
batch[key] = batch[key].cpu() | ||
|
@@ -206,15 +206,14 @@ def test_step(self, *args: Any, **kwargs: Any) -> None: | |
batch = args[0] | ||
x = batch["image"] | ||
y = batch[self.target_key] | ||
y_hat = self(x) | ||
|
||
if y.ndim == 1: | ||
if y_hat.ndim != y.ndim: | ||
y = y.unsqueeze(dim=1) | ||
|
||
y_hat = self(x) | ||
|
||
loss = self.loss(y_hat, y) | ||
loss = self.loss(y_hat, y.to(torch.float)) | ||
self.log("test_loss", loss) | ||
self.test_metrics(y_hat, y) | ||
self.test_metrics(y_hat, y.to(torch.float)) | ||
|
||
def on_test_epoch_end(self) -> None: | ||
"""Logs epoch level test metrics.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Testing regression on Inria binary [0, 1] masks for now since we don't have a readily available pixelwise regression datamodule.