|
18 | 18 | # %% Create a dataloader and visualize the batches.
|
19 | 19 |
|
20 | 20 | # Set the path to the dataset
|
21 |
| -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" |
| 21 | +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" |
22 | 22 |
|
23 | 23 | # find ratio of background, uninfected and infected pixels
|
24 | 24 | zarr_input = open_ome_zarr(
|
|
56 | 56 | # Create an instance of HCSDataModule
|
57 | 57 | data_module = HCSDataModule(
|
58 | 58 | dataset_path,
|
59 |
| - source_channel=["Phase", "HSP90"], |
| 59 | + source_channel=["Phase", "HSP90", "phase_nucl_iqr","hsp90_skew"], |
60 | 60 | target_channel=["Inf_mask"],
|
61 | 61 | yx_patch_size=[256, 256],
|
62 | 62 | split_ratio=0.8,
|
|
66 | 66 | batch_size=16,
|
67 | 67 | normalizations=[
|
68 | 68 | NormalizeSampled(
|
69 |
| - keys=["Phase","HSP90"], |
| 69 | + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], |
70 | 70 | level="fov_statistics",
|
71 | 71 | subtrahend="median",
|
72 | 72 | divisor="iqr",
|
|
76 | 76 | RandWeightedCropd(
|
77 | 77 | num_samples=4,
|
78 | 78 | spatial_size=[-1, 256, 256],
|
79 |
| - keys=["Phase","HSP90"], |
| 79 | + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], |
80 | 80 | w_key="Inf_mask",
|
81 | 81 | )
|
82 | 82 | ],
|
|
141 | 141 |
|
142 | 142 | # Fit the model
|
143 | 143 | model = SemanticSegUNet25D(
|
144 |
| - in_channels=2, |
| 144 | + in_channels=4, |
145 | 145 | out_channels=3,
|
146 | 146 | loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
|
147 | 147 | )
|
|
0 commit comments