Cellular infection phenotyping using annotated viral sensor data & label-free images #70

Merged
merged 94 commits into from
Jul 10, 2024
Changes from 82 commits
c6692f1
refactor data loading into its own module
ziw-liu Jan 10, 2024
3d8e7e2
update type annotations
ziw-liu Jan 10, 2024
fdcbf55
move the logging module out
ziw-liu Jan 11, 2024
a291381
move old logging into utils
ziw-liu Jan 11, 2024
3cf8fa2
rename tests to match module name
ziw-liu Jan 11, 2024
d4cd41d
bump torch
ziw-liu Jan 11, 2024
e87d396
draft fcmae encoder
ziw-liu Jan 12, 2024
dccce5f
add stem to the encoder
ziw-liu Jan 12, 2024
5508731
wip: masked stem layernorm
ziw-liu Jan 12, 2024
3eec48e
wip: patchify masked features for linear
ziw-liu Jan 17, 2024
8c54feb
use mlp from timm
ziw-liu Jan 17, 2024
83ecf4a
hack: POC training script for FCMAE
ziw-liu Jan 17, 2024
2fffc99
fix mask for fitting
ziw-liu Jan 17, 2024
2a598b2
remove training script
ziw-liu Jan 17, 2024
b9b1880
default architecture
ziw-liu Jan 17, 2024
fd7700d
fine-tuning options
ziw-liu Jan 22, 2024
054249f
fix cli for finetuning
ziw-liu Jan 24, 2024
d867e10
draft combined data module
ziw-liu Jan 24, 2024
b06a300
fix import
ziw-liu Jan 25, 2024
39eafab
manual validation loss reduction
ziw-liu Jan 27, 2024
9fbf7a5
update linting
ziw-liu Feb 2, 2024
e00f5f3
update development guide
ziw-liu Feb 2, 2024
9e345b6
update type hints
ziw-liu Feb 13, 2024
96deca5
bump iohub
ziw-liu Feb 20, 2024
e06aa57
draft ctmc v1 dataset
ziw-liu Feb 24, 2024
ea8b300
Merge branch 'main' into fcmae
ziw-liu Feb 24, 2024
72de113
update tests
ziw-liu Feb 24, 2024
13d0aa0
move test_data
ziw-liu Feb 24, 2024
78aed97
remove path conversion
ziw-liu Feb 24, 2024
74e7db3
configurable normalizations (#68)
edyoshikun Feb 26, 2024
9b3b032
fix ctmc dataloading
ziw-liu Feb 28, 2024
a356936
add example ctmc v1 loading script
ziw-liu Feb 28, 2024
bac26be
changing the normalization and augmentations default from None to emp…
edyoshikun Feb 28, 2024
0b598c7
invert intensity transform
ziw-liu Feb 29, 2024
ddb30e9
concatenated data module
ziw-liu Feb 29, 2024
9504755
subsample videos
ziw-liu Feb 29, 2024
808e39c
livecell dataset
ziw-liu Feb 29, 2024
43d641d
all sample fields are optional
ziw-liu Feb 29, 2024
42f81cf
fix multi-dataloader validation
ziw-liu Feb 29, 2024
4546fc7
lint
ziw-liu Feb 29, 2024
306f3ef
fixing preprocessing for varying array shapes (i.e aics dataset)
edyoshikun Feb 29, 2024
1a0e3ce
update loading scripts
ziw-liu Mar 2, 2024
d3ec94d
fix CombineMode
ziw-liu Mar 2, 2024
dd34712
added model and annotation code draft
Soorya19Pradeep Mar 4, 2024
5fc9da2
changed to simple unet model
Soorya19Pradeep Mar 5, 2024
e627488
start with lesser augmentations
Soorya19Pradeep Mar 6, 2024
310ba70
added readme file
Soorya19Pradeep Mar 6, 2024
34b81b9
added tensorboard logging
Soorya19Pradeep Mar 6, 2024
a4e2f0d
added validation step
Soorya19Pradeep Mar 7, 2024
0ebb5df
changed to viscy 2d unet
Soorya19Pradeep Mar 11, 2024
a0e426a
used crossentropyloss with one-hot encoding
Soorya19Pradeep Mar 12, 2024
5ecbde0
added sample image logging
Soorya19Pradeep Mar 12, 2024
58b7fa5
attempt to build magicgui annotation
mattersoflight Mar 13, 2024
35ead0c
renamed infection annotation tool
Soorya19Pradeep Mar 13, 2024
802ebc3
added normalization and augmentations
Soorya19Pradeep Mar 23, 2024
908039a
added model testing code
Soorya19Pradeep Mar 25, 2024
88615d5
removed annotation refiner
Soorya19Pradeep Mar 25, 2024
82428ed
corrected conversion of class to int
Soorya19Pradeep Mar 26, 2024
b470ed1
corrected prediction module
Soorya19Pradeep Mar 26, 2024
f3746f8
cleaned up the code and comments for the LightningUNet
mattersoflight Mar 26, 2024
20655d6
removed confusion matrix code, finding runtime error with model
mattersoflight Mar 26, 2024
d022dae
moved scripts to viscy.scripts.infection_phenotyping module to enable…
mattersoflight Mar 26, 2024
901fd70
combine the lightning modules for training and prediction, fix the DD…
mattersoflight Mar 26, 2024
708a67a
all the stubs for computing and logging confusion matrix per cell
mattersoflight Mar 26, 2024
6bb9ca3
separated training and test scripts
Soorya19Pradeep Apr 1, 2024
99a3876
lightning module
Soorya19Pradeep Apr 1, 2024
000a966
corrected test cm compute
Soorya19Pradeep Apr 2, 2024
688336e
corrected test module
Soorya19Pradeep Apr 3, 2024
6b58f34
separated test and prediction scripts
Soorya19Pradeep Apr 3, 2024
b6ad254
changed confusion matrix compute
Soorya19Pradeep Apr 5, 2024
bd37f4b
Merge branch 'main' into infection_phenotyping
ziw-liu Apr 12, 2024
9c9ce41
fix merge error
ziw-liu Apr 12, 2024
6b0a42d
split 2D and 2.5D model scripts
Soorya19Pradeep May 23, 2024
2ea8892
added convnext script
Soorya19Pradeep May 27, 2024
220eba1
fix model input parameter
Soorya19Pradeep May 28, 2024
c4839da
update input file
Soorya19Pradeep May 29, 2024
c04b4ac
add augmentations
Soorya19Pradeep Jul 2, 2024
418d6d9
refactor infection_classification code to viscy/applications
mattersoflight Jul 3, 2024
67b330c
changes made for BJ5 classification
Soorya19Pradeep Jul 9, 2024
d420e80
format code
Soorya19Pradeep Jul 9, 2024
bd23f3b
add explicit packaging list
ziw-liu Jul 9, 2024
701ea77
rename testing script
Soorya19Pradeep Jul 9, 2024
8ddb58e
update readme
Soorya19Pradeep Jul 10, 2024
a49cfba
move function to preprocessing
Soorya19Pradeep Jul 10, 2024
00baf9d
format code
Soorya19Pradeep Jul 10, 2024
cd35d22
formatting
ziw-liu Jul 10, 2024
9d528ca
histogram with dask
ziw-liu Jul 10, 2024
7e477f4
fix index and test
ziw-liu Jul 10, 2024
7a007f2
fix import
ziw-liu Jul 10, 2024
9b46035
black
ziw-liu Jul 10, 2024
173a5db
fix float comp
ziw-liu Jul 10, 2024
19cf4e6
clean up headers
ziw-liu Jul 10, 2024
4b36875
clean up import
ziw-liu Jul 10, 2024
37ab0aa
add argument to change number of classes
ziw-liu Jul 10, 2024
@@ -0,0 +1,142 @@
# %%
import torch
import lightning.pytorch as pl
import torch.nn as nn

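# Import from the same `lightning.pytorch` namespace as the Trainer above;
# mixing it with the standalone `pytorch_lightning` package can fail at runtime.
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint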

from viscy.transforms import RandWeightedCropd
from viscy.transforms import NormalizeSampled
from viscy.data.hcs import HCSDataModule
from applications.infection_classification.classify_infection_25D import (
    SemanticSegUNet25D,
)

from iohub.ngff import open_ome_zarr

# %% Create a dataloader and visualize the batches.

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr"

# find ratio of background, uninfected and infected pixels
# The dataset is only read here, so open it read-only.
zarr_input = open_ome_zarr(
    dataset_path,
    layout="hcs",
    mode="r",
)
in_chan_names = zarr_input.channel_names

num_pixels_bkg = 0
num_pixels_uninf = 0
num_pixels_inf = 0
num_pixels = 0
for well_id, well_data in zarr_input.wells():
    well_name, well_no = well_id.split("/")

    for pos_name, pos_data in well_data.positions():
        data = pos_data.data
        T, C, Z, Y, X = data.shape
        out_data = data.numpy()
        for time in range(T):
            Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...]
            # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask'
            num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum()
            num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum()
            num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum()
            num_pixels = num_pixels + Z * X * Y

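# Inverse class frequencies (rarer classes get larger values)
pixel_ratio_1 = [
    num_pixels / num_pixels_bkg,
    num_pixels / num_pixels_uninf,
    num_pixels / num_pixels_inf,
]
# Normalize so that the class weights sum to 1
pixel_ratio_sum = sum(pixel_ratio_1)
pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1]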

# %% Create the data module

# Create an instance of HCSDataModule
data_module = HCSDataModule(
    dataset_path,
    source_channel=["Phase", "HSP90"],
    target_channel=["Inf_mask"],
    yx_patch_size=[512, 512],
    split_ratio=0.8,
    z_window_size=5,
    architecture="2.5D",
    num_workers=3,
    batch_size=32,
    normalizations=[
        NormalizeSampled(
            keys=["Phase", "HSP90"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=4,
            spatial_size=[-1, 512, 512],
            # Crop the mask together with the images so that source and
            # target patches stay aligned (as in the 2D script).
            keys=["Phase", "HSP90", "Inf_mask"],
            w_key="Inf_mask",
        )
    ],
)

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()


# %% Define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=200,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    log_every_n_steps=1,
    # Set the number of GPUs to use. This avoids run-time exception from
    # distributed training when the node has multiple GPUs
    devices=1,
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Fit the model
model = SemanticSegUNet25D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
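The class-weight computation in the script above divides the total pixel count by each class count and then normalizes the three ratios. The sketch below restates that arithmetic on a toy label list using only the standard library. The function name `inverse_frequency_weights`, the toy mask, and the choice to give an absent class a weight of 0 (the script itself would divide by zero in that case) are assumptions made for illustration, not part of this PR.

```python
from collections import Counter
from typing import Iterable, Sequence


def inverse_frequency_weights(
    labels: Iterable[int], classes: Sequence[int] = (0, 1, 2)
) -> list[float]:
    """Return normalized inverse-frequency weights, one per class."""
    counts = Counter(labels)
    total = sum(counts[c] for c in classes)
    # Assumption: a class with no pixels gets weight 0 instead of
    # raising ZeroDivisionError.
    raw = [total / counts[c] if counts[c] else 0.0 for c in classes]
    norm = sum(raw)
    if norm == 0:
        raise ValueError("no labeled pixels found")
    # Normalize so that the weights sum to 1.
    return [r / norm for r in raw]


# Toy mask: 6 background (0), 3 uninfected (1), 1 infected (2) pixel.
toy_mask = [0] * 6 + [1] * 3 + [2] * 1
weights = inverse_frequency_weights(toy_mask)
print(weights)
```

For the toy mask the raw ratios are 10/6, 10/3, and 10/1, so the normalized weights come out as 1/9, 2/9, and 6/9: the rarest class (infected) gets the largest weight.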
@@ -0,0 +1,158 @@
# %%
import torch
import lightning.pytorch as pl
import torch.nn as nn

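# Import from the same `lightning.pytorch` namespace as the Trainer above;
# mixing it with the standalone `pytorch_lightning` package can fail at runtime.
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint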

from viscy.transforms import (
    RandWeightedCropd,
    NormalizeSampled,
    RandScaleIntensityd,
    RandGaussianSmoothd,
)
from viscy.data.hcs import HCSDataModule
from applications.infection_classification.classify_infection_2D import (
    SemanticSegUNet2D,
)
from iohub.ngff import open_ome_zarr

# %% calculate the ratio of background, uninfected and infected pixels in the input dataset

# Set the path to the dataset
dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr"

# find ratio of background, uninfected and infected pixels
# The dataset is only read here, so open it read-only.
zarr_input = open_ome_zarr(
    dataset_path,
    layout="hcs",
    mode="r",
)
in_chan_names = zarr_input.channel_names

num_pixels_bkg = 0
num_pixels_uninf = 0
num_pixels_inf = 0
num_pixels = 0
for well_id, well_data in zarr_input.wells():
    well_name, well_no = well_id.split("/")

    for pos_name, pos_data in well_data.positions():
        data = pos_data.data
        T, C, Z, Y, X = data.shape
        out_data = data.numpy()
        for time in range(T):
            Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...]
            # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask'
            num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum()
            num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum()
            num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum()
            num_pixels = num_pixels + Z * X * Y

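# Inverse class frequencies (rarer classes get larger values)
pixel_ratio_1 = [
    num_pixels / num_pixels_bkg,
    num_pixels / num_pixels_uninf,
    num_pixels / num_pixels_inf,
]
# Normalize so that the class weights sum to 1
pixel_ratio_sum = sum(pixel_ratio_1)
pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1]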

# %% Create an instance of HCSDataModule

data_module = HCSDataModule(
    dataset_path,
    source_channel=["TXR_Density3D", "Phase3D"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.7,
    z_window_size=1,
    architecture="2D",
    num_workers=1,
    batch_size=256,
    normalizations=[
        NormalizeSampled(
            keys=["Phase3D", "TXR_Density3D"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=16,
            spatial_size=[-1, 128, 128],
            keys=["TXR_Density3D", "Phase3D", "Inf_mask"],
            w_key="Inf_mask",
        ),
        RandScaleIntensityd(
            keys=["TXR_Density3D", "Phase3D"],
            factors=[0.5, 0.5],
            prob=0.5,
        ),
        RandGaussianSmoothd(
            keys=["TXR_Density3D", "Phase3D"],
            prob=0.5,
            sigma_x=[0.5, 1.0],
            sigma_y=[0.5, 1.0],
            sigma_z=[0.5, 1.0],
        ),
    ],
)

# Prepare the data
data_module.prepare_data()

# Setup the data
data_module.setup(stage="fit")

# Create a dataloader
train_dm = data_module.train_dataloader()

val_dm = data_module.val_dataloader()

# %% Set up for training

# define the logger
logger = TensorBoardLogger(
    "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/",
    name="logs",
)

# Pass the logger to the Trainer
trainer = pl.Trainer(
    logger=logger,
    max_epochs=500,
    default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    log_every_n_steps=1,
    # Set the number of GPUs to use. This avoids run-time exception from
    # distributed training when the node has multiple GPUs
    devices=1,
)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/",
    filename="checkpoint_{epoch:02d}",
    save_top_k=-1,
    verbose=True,
    monitor="loss/validate",
    mode="min",
)

# Add the checkpoint callback to the trainer
trainer.callbacks.append(checkpoint_callback)

# Fit the model
model = SemanticSegUNet2D(
    in_channels=2,
    out_channels=3,
    loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)),
)

# visualize the model
print(model)

# %% Run training.

trainer.fit(model, data_module)

# %%
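Both scripts pass the normalized ratios as `weight` to `nn.CrossEntropyLoss`. As a rough illustration of what that weighting does, the sketch below computes a class-weighted cross-entropy in plain Python, assuming the weighted-mean reduction that PyTorch documents for `reduction="mean"` (the weighted losses are divided by the sum of the target weights). The function name `weighted_cross_entropy` and the toy logits are hypothetical; this is not the implementation used by the models in this PR.

```python
import math
from typing import Sequence


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    class_weights: Sequence[float],
) -> float:
    """Weighted mean of per-sample cross-entropy losses."""
    weighted_sum = 0.0
    weight_total = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp over the class logits.
        peak = max(row)
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        # Negative log-likelihood of the target class.
        nll = log_norm - row[target]
        weighted_sum += class_weights[target] * nll
        weight_total += class_weights[target]
    # Divide by the summed target weights, not by the sample count.
    return weighted_sum / weight_total


# Two toy pixels: one confident background prediction, one uniform prediction
# on an infected pixel. Class order: background, uninfected, infected.
toy_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
toy_targets = [0, 2]
loss_equal = weighted_cross_entropy(toy_logits, toy_targets, [1.0, 1.0, 1.0])
loss_weighted = weighted_cross_entropy(toy_logits, toy_targets, [1.0, 1.0, 3.0])
print(loss_equal, loss_weighted)
```

With the infected class weighted three times higher, the poorly predicted infected pixel dominates the average, so `loss_weighted` is larger than `loss_equal`. This is the effect the inverse-frequency weights are meant to have on the rare class.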