Skip to content

Commit 802ebc3

Browse files
added normalization and augmentations
1 parent 35ead0c commit 802ebc3

File tree

1 file changed

+37
-64
lines changed

1 file changed

+37
-64
lines changed

examples/infection_phenotyping/Infection_classification_model.py

+37-64
Original file line numberDiff line numberDiff line change
@@ -6,51 +6,56 @@
66
import torch.nn as nn
77
import lightning.pytorch as pl
88
import torch.nn.functional as F
9-
import torchview
9+
10+
# import torchview
1011
from typing import Literal, Sequence
1112
from skimage.exposure import rescale_intensity
1213
from matplotlib.cm import get_cmap
1314

1415
# import napari
1516
from pytorch_lightning.loggers import TensorBoardLogger
1617
from torch import Tensor
17-
from monai.transforms import (
18-
RandRotate,
19-
Resize,
20-
Zoom,
21-
Flip,
22-
RandFlip,
23-
RandZoom,
24-
RandRotate90,
25-
RandRotate,
26-
RandAffine,
27-
Rand2DElastic,
28-
Rand3DElastic,
29-
RandGaussianNoise,
30-
RandGaussianNoised,
31-
)
3218
from pytorch_lightning.callbacks import ModelCheckpoint
33-
from monai.losses import DiceLoss
34-
from viscy.light.engine import VSUNet
19+
20+
# from monai.losses import DiceLoss
21+
# from viscy.light.engine import VSUNet
3522
from viscy.unet.networks.Unet2D import Unet2d
3623
from viscy.data.hcs import Sample
24+
from viscy.transforms import RandWeightedCropd, RandGaussianNoised
25+
from viscy.transforms import NormalizeSampled
3726

3827
# %% Create a dataloader and visualize the batches.
3928
# Set the path to the dataset
40-
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr"
29+
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined.zarr"
4130

4231
# Create an instance of HCSDataModule
4332
data_module = HCSDataModule(
4433
dataset_path,
45-
source_channel=["Sensor"],
34+
source_channel=["Sensor", "Phase"],
4635
target_channel=["Inf_mask"],
4736
yx_patch_size=[128, 128],
4837
split_ratio=0.8,
4938
z_window_size=1,
5039
architecture="2D",
5140
num_workers=1,
52-
batch_size=12,
53-
augmentations=[],
41+
batch_size=64,
42+
normalizations=[
43+
NormalizeSampled(
44+
keys=["Sensor", "Phase"],
45+
level="fov_statistics",
46+
subtrahend="median",
47+
divisor="iqr",
48+
)
49+
],
50+
augmentations=[
51+
RandWeightedCropd(
52+
num_samples=8,
53+
spatial_size=[-1, 128, 128],
54+
keys=["Sensor", "Phase", "Inf_mask"],
55+
w_key="Inf_mask",
56+
),
57+
RandGaussianNoised(keys=["Sensor", "Phase"], mean=0.0, std=1.0, prob=0.5),
58+
],
5459
)
5560

5661
# Prepare the data
@@ -159,13 +164,14 @@ def validation_step(self, batch: Sample, batch_idx: int):
159164
loss = self.loss_function(pred, target_one_hot)
160165
if batch_idx < self.log_batches_per_epoch:
161166
self.validation_step_outputs.extend(
162-
self._detach_sample((source, target, pred))
167+
self._detach_sample((source, target_one_hot, pred))
163168
)
164169
self.log(
165170
"loss/validate",
166171
loss,
167172
sync_dist=True,
168173
add_dataloader_idx=False,
174+
logger=True,
169175
)
170176
return loss
171177

@@ -209,21 +215,21 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
209215

210216
# %% Define the logger
211217
logger = TensorBoardLogger(
212-
"/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs",
213-
name="infection_classification_model",
218+
"/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/",
219+
name="logs_wPhase",
214220
)
215221

216222
# Pass the logger to the Trainer
217223
trainer = pl.Trainer(
218224
logger=logger,
219-
max_epochs=50,
220-
default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs",
225+
max_epochs=100,
226+
default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase",
221227
log_every_n_steps=1,
222228
)
223229

224230
# Define the checkpoint callback
225231
checkpoint_callback = ModelCheckpoint(
226-
dirpath="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints",
232+
dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/",
227233
filename="checkpoint_{epoch:02d}",
228234
save_top_k=-1,
229235
verbose=True,
@@ -236,43 +242,10 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
236242

237243
# Fit the model
238244
model = LightningUNet(
239-
in_channels=1,
245+
in_channels=2,
240246
out_channels=4,
241-
loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.4, 0.4, 0.1])),
247+
loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.3, 0.3, 0.3])),
242248
)
243249
trainer.fit(model, data_module)
244250

245-
246-
# %% test the model on the test set
247-
test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr"
248-
249-
test_dm = HCSDataModule(
250-
test_datapath,
251-
source_channel=["Sensor", "Nuclei_mask"],
252-
)
253-
# Load the predict dataset
254-
test_dataloader = test_dm.test_dataloader()
255-
256-
# Set the model to evaluation mode
257-
unet_model.eval()
258-
259-
# Create a list to store the predictions
260-
predictions = []
261-
262-
# Iterate over the test batches
263-
for batch in test_dataloader:
264-
# Extract the input from the batch
265-
input_data = batch["source"]
266-
267-
# Forward pass through the model
268-
output = unet_model(input_data)
269-
270-
# Append the predictions to the list
271-
predictions.append(output.detach().cpu().numpy())
272-
273-
# Convert the predictions to a numpy array
274-
predictions = np.stack(predictions)
275-
276-
# Save the predictions as added channel in zarr format
277-
# use iohub or viscy to save the predictions!!!
278-
zarr.save("predictions.zarr", predictions)
251+
# %%

0 commit comments

Comments
 (0)