6
6
from viscy .data .hcs import Sample
7
7
import lightning .pytorch as pl
8
8
import torch
9
-
9
+ import torchmetrics
10
10
from viscy .light .predict_writer import HCSPredictionWriter
11
11
from monai .transforms import DivisiblePad
12
12
16
16
data_module = HCSDataModule (
17
17
test_datapath ,
18
18
source_channel = ["Sensor" , "Phase" ],
19
- target_channel = [],
19
+ target_channel = ["inf_mask" ],
20
20
split_ratio = 0.8 ,
21
21
z_window_size = 1 ,
22
22
architecture = "2D" ,
36
36
data_module .prepare_data ()
37
37
38
38
data_module .setup (stage = "predict" )
39
- test_dm = data_module .test_dataloader ()
40
- sample = next (iter (test_dm ))
41
39
42
40
# %%
43
41
class LightningUNet (pl .LightningModule ):
@@ -49,6 +47,9 @@ def __init__(
49
47
):
50
48
super (LightningUNet , self ).__init__ ()
51
49
self .unet_model = Unet2d (in_channels = in_channels , out_channels = out_channels )
50
+ # self.pred_cm = torchmetrics.classification.ConfusionMatrix(
51
+ # task="multiclass", num_classes=self.n_classes
52
+ # )
52
53
if ckpt_path is not None :
53
54
state_dict = torch .load (ckpt_path , map_location = torch .device ("cpu" ))[
54
55
"state_dict"
@@ -62,8 +63,8 @@ def forward(self, x):
62
63
def predict_step (self , batch : Sample , batch_idx : int , dataloader_idx : int = 0 ):
63
64
source = self ._predict_pad (batch ["source" ])
64
65
pred_class = self .forward (source )
65
- pred_int = torch .argmax (pred_class , dim = 4 , keepdim = True )
66
- return self . _predict_pad . inverse ( pred_int )
66
+ pred_int = torch .argmax (pred_class , dim = 1 , keepdim = True )
67
+ return pred_int
67
68
68
69
def on_predict_start (self ):
69
70
"""Pad the input shape to be divisible by the downsampling factor.
@@ -79,7 +80,7 @@ def on_predict_start(self):
79
80
80
81
trainer = pl .Trainer (
81
82
default_root_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase" ,
82
- callbacks = [HCSPredictionWriter (output_path , write_input = True )],
83
+ callbacks = [HCSPredictionWriter (output_path , write_input = False )],
83
84
)
84
85
model = LightningUNet (
85
86
in_channels = 2 ,
0 commit comments