Skip to content

Commit b470ed1

Browse files
corrected prediction module
1 parent 82428ed commit b470ed1

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

examples/infection_phenotyping/test_infection_classifier.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from viscy.data.hcs import Sample
77
import lightning.pytorch as pl
88
import torch
9-
9+
import torchmetrics
1010
from viscy.light.predict_writer import HCSPredictionWriter
1111
from monai.transforms import DivisiblePad
1212

@@ -16,7 +16,7 @@
1616
data_module = HCSDataModule(
1717
test_datapath,
1818
source_channel=["Sensor", "Phase"],
19-
target_channel=[],
19+
target_channel=["inf_mask"],
2020
split_ratio=0.8,
2121
z_window_size=1,
2222
architecture="2D",
@@ -36,8 +36,6 @@
3636
data_module.prepare_data()
3737

3838
data_module.setup(stage="predict")
39-
test_dm = data_module.test_dataloader()
40-
sample = next(iter(test_dm))
4139

4240
# %%
4341
class LightningUNet(pl.LightningModule):
@@ -49,6 +47,9 @@ def __init__(
4947
):
5048
super(LightningUNet, self).__init__()
5149
self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels)
50+
# self.pred_cm = torchmetrics.classification.ConfusionMatrix(
51+
# task="multiclass", num_classes=self.n_classes
52+
# )
5253
if ckpt_path is not None:
5354
state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[
5455
"state_dict"
@@ -62,8 +63,8 @@ def forward(self, x):
6263
def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
6364
source = self._predict_pad(batch["source"])
6465
pred_class = self.forward(source)
65-
pred_int = torch.argmax(pred_class, dim=4, keepdim=True)
66-
return self._predict_pad.inverse(pred_int)
66+
pred_int = torch.argmax(pred_class, dim=1, keepdim=True)
67+
return pred_int
6768

6869
def on_predict_start(self):
6970
"""Pad the input shape to be divisible by the downsampling factor.
@@ -79,7 +80,7 @@ def on_predict_start(self):
7980

8081
trainer = pl.Trainer(
8182
default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase",
82-
callbacks=[HCSPredictionWriter(output_path, write_input=True)],
83+
callbacks=[HCSPredictionWriter(output_path, write_input=False)],
8384
)
8485
model = LightningUNet(
8586
in_channels=2,

0 commit comments

Comments
 (0)