6
6
import torch .nn as nn
7
7
import lightning .pytorch as pl
8
8
import torch .nn .functional as F
9
- import torchview
9
+
10
+ # import torchview
10
11
from typing import Literal , Sequence
11
12
from skimage .exposure import rescale_intensity
12
13
from matplotlib .cm import get_cmap
13
14
14
15
# import napari
15
16
from pytorch_lightning .loggers import TensorBoardLogger
16
17
from torch import Tensor
17
- from monai .transforms import (
18
- RandRotate ,
19
- Resize ,
20
- Zoom ,
21
- Flip ,
22
- RandFlip ,
23
- RandZoom ,
24
- RandRotate90 ,
25
- RandRotate ,
26
- RandAffine ,
27
- Rand2DElastic ,
28
- Rand3DElastic ,
29
- RandGaussianNoise ,
30
- RandGaussianNoised ,
31
- )
32
18
from pytorch_lightning .callbacks import ModelCheckpoint
33
- from monai .losses import DiceLoss
34
- from viscy .light .engine import VSUNet
19
+
20
+ # from monai.losses import DiceLoss
21
+ # from viscy.light.engine import VSUNet
35
22
from viscy .unet .networks .Unet2D import Unet2d
36
23
from viscy .data .hcs import Sample
24
+ from viscy .transforms import RandWeightedCropd , RandGaussianNoised
25
+ from viscy .transforms import NormalizeSampled
37
26
38
27
# %% Create a dataloader and visualize the batches.
39
28
# Set the path to the dataset
40
- dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked .zarr"
29
+ dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined .zarr"
41
30
42
31
# Create an instance of HCSDataModule
43
32
data_module = HCSDataModule (
44
33
dataset_path ,
45
- source_channel = ["Sensor" ],
34
+ source_channel = ["Sensor" , "Phase" ],
46
35
target_channel = ["Inf_mask" ],
47
36
yx_patch_size = [128 , 128 ],
48
37
split_ratio = 0.8 ,
49
38
z_window_size = 1 ,
50
39
architecture = "2D" ,
51
40
num_workers = 1 ,
52
- batch_size = 12 ,
53
- augmentations = [],
41
+ batch_size = 64 ,
42
+ normalizations = [
43
+ NormalizeSampled (
44
+ keys = ["Sensor" , "Phase" ],
45
+ level = "fov_statistics" ,
46
+ subtrahend = "median" ,
47
+ divisor = "iqr" ,
48
+ )
49
+ ],
50
+ augmentations = [
51
+ RandWeightedCropd (
52
+ num_samples = 8 ,
53
+ spatial_size = [- 1 , 128 , 128 ],
54
+ keys = ["Sensor" , "Phase" , "Inf_mask" ],
55
+ w_key = "Inf_mask" ,
56
+ ),
57
+ RandGaussianNoised (keys = ["Sensor" , "Phase" ], mean = 0.0 , std = 1.0 , prob = 0.5 ),
58
+ ],
54
59
)
55
60
56
61
# Prepare the data
@@ -159,13 +164,14 @@ def validation_step(self, batch: Sample, batch_idx: int):
159
164
loss = self .loss_function (pred , target_one_hot )
160
165
if batch_idx < self .log_batches_per_epoch :
161
166
self .validation_step_outputs .extend (
162
- self ._detach_sample ((source , target , pred ))
167
+ self ._detach_sample ((source , target_one_hot , pred ))
163
168
)
164
169
self .log (
165
170
"loss/validate" ,
166
171
loss ,
167
172
sync_dist = True ,
168
173
add_dataloader_idx = False ,
174
+ logger = True ,
169
175
)
170
176
return loss
171
177
@@ -209,21 +215,21 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
209
215
210
216
# %% Define the logger
211
217
logger = TensorBoardLogger (
212
- "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs " ,
213
- name = "infection_classification_model " ,
218
+ "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/ " ,
219
+ name = "logs_wPhase " ,
214
220
)
215
221
216
222
# Pass the logger to the Trainer
217
223
trainer = pl .Trainer (
218
224
logger = logger ,
219
- max_epochs = 50 ,
220
- default_root_dir = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs " ,
225
+ max_epochs = 100 ,
226
+ default_root_dir = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase " ,
221
227
log_every_n_steps = 1 ,
222
228
)
223
229
224
230
# Define the checkpoint callback
225
231
checkpoint_callback = ModelCheckpoint (
226
- dirpath = "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints " ,
232
+ dirpath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/ " ,
227
233
filename = "checkpoint_{epoch:02d}" ,
228
234
save_top_k = - 1 ,
229
235
verbose = True ,
@@ -236,43 +242,10 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
236
242
237
243
# Fit the model
238
244
model = LightningUNet (
239
- in_channels = 1 ,
245
+ in_channels = 2 ,
240
246
out_channels = 4 ,
241
- loss_function = nn .CrossEntropyLoss (weight = torch .tensor ([0.1 , 0.4 , 0.4 , 0.1 ])),
247
+ loss_function = nn .CrossEntropyLoss (weight = torch .tensor ([0.1 , 0.3 , 0.3 , 0.3 ])),
242
248
)
243
249
trainer .fit (model , data_module )
244
250
245
-
246
- # %% test the model on the test set
247
- test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr"
248
-
249
- test_dm = HCSDataModule (
250
- test_datapath ,
251
- source_channel = ["Sensor" , "Nuclei_mask" ],
252
- )
253
- # Load the predict dataset
254
- test_dataloader = test_dm .test_dataloader ()
255
-
256
- # Set the model to evaluation mode
257
- unet_model .eval ()
258
-
259
- # Create a list to store the predictions
260
- predictions = []
261
-
262
- # Iterate over the test batches
263
- for batch in test_dataloader :
264
- # Extract the input from the batch
265
- input_data = batch ["source" ]
266
-
267
- # Forward pass through the model
268
- output = unet_model (input_data )
269
-
270
- # Append the predictions to the list
271
- predictions .append (output .detach ().cpu ().numpy ())
272
-
273
- # Convert the predictions to a numpy array
274
- predictions = np .stack (predictions )
275
-
276
- # Save the predictions as added channel in zarr format
277
- # use iohub or viscy to save the predictions!!!
278
- zarr .save ("predictions.zarr" , predictions )
251
+ # %%
0 commit comments