Skip to content

Commit cfedaaf

Browse files
committed
adding the working example. pending missing comments.
1 parent 114b4bb commit cfedaaf

File tree

1 file changed

+62
-24
lines changed

1 file changed

+62
-24
lines changed

examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py

+62-24
Original file line numberDiff line numberDiff line change
@@ -2,59 +2,51 @@
22
"""
33
# 2D Virtual Staining of A549 Cells
44
---
5+
## Prediction using the VSCyto2D to predict nuclei and membrane from phase.
56
This example shows how to virtually stain A549 cells using the _VSCyto2D_ model.
6-
7-
First we import the necessary libraries and set the random seed for reproducibility.
7+
The model is trained to predict the membrane and nuclei channels from the phase channel.
88
"""
99
# %% Imports and paths
1010
from pathlib import Path
1111

1212
import matplotlib.pyplot as plt
1313
import numpy as np
14-
import pandas as pd
15-
import torch
16-
import torchview
17-
import torchvision
1814
from iohub import open_ome_zarr
19-
from lightning.pytorch import seed_everything
20-
21-
# from rich.pretty import pprint #TODO: add pretty print(?)
22-
23-
from napari.utils.notebook_display import nbscreenshot
24-
import napari
2515

26-
# %% Imports and paths
2716
from viscy.data.hcs import HCSDataModule
2817

29-
# Trainer class and UNet.
18+
# %% Imports and paths
19+
# Viscy classes for the trainer and model
3020
from viscy.light.engine import FcmaeUNet
21+
from viscy.light.predict_writer import HCSPredictionWriter
3122
from viscy.light.trainer import VSTrainer
3223
from viscy.transforms import NormalizeSampled
33-
from viscy.light.predict_writer import HCSPredictionWriter
34-
from viscy.data.hcs import HCSDataModule
24+
from skimage.exposure import rescale_intensity
3525

3626
# %% [markdown]
37-
"""
38-
## Prediction using the 2D U-Net model to predict nuclei and membrane from phase.
39-
40-
### Construct a 2D U-Net
41-
See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details.
42-
"""
4327

4428
# %%
45-
input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr/0/0/0"
29+
input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr"
4630
model_ckpt_path = "/hpc/projects/comp.micro/virtual_staining/models/hek-a549-bj5a-20x/lightning_logs/tiny-2x2-finetune-e2e-amp-hek-a549-bj5a-nucleus-membrane-400ep/checkpoints/last.ckpt"
4731
output_path = "./test_a549_demo.zarr"
32+
fov = "0/0/0" # NOTE: FOV of interest
4833

34+
input_data_path = Path(input_data_path) / fov
4935
# %%
5036
# Create a the VSCyto2D
5137

38+
# NOTE: Change the following parameters as needed.
5239
GPU_ID = 0
5340
BATCH_SIZE = 10
5441
YX_PATCH_SIZE = (384, 384)
5542
phase_channel_name = "Phase3D"
5643

57-
44+
# %%[markdown]
45+
"""
46+
For this example we will use the following parameters:
47+
### For more information on the VSCyto2D model:
48+
See ``viscy.unet.networks.fcmae`` ([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/fcmae.py#L398)) for configuration details.
49+
"""
5850
# %%
5951
# Setup the data module.
6052
data_module = HCSDataModule(
@@ -98,6 +90,7 @@
9890
model_VSCyto2D.eval()
9991

10092
# %%
93+
# Setup the Trainer
10194
trainer = VSTrainer(
10295
accelerator="gpu",
10396
callbacks=[HCSPredictionWriter(output_path)],
@@ -111,3 +104,48 @@
111104
)
112105

113106
# %%
107+
# Open the output_zarr store and inspect the output
108+
colormap_1 = [0.1254902, 0.6784314, 0.972549] # bop blue
109+
colormap_2 = [0.972549, 0.6784314, 0.1254902] # bop orange
110+
111+
# Show the individual channels and the fused in a 1x3 plot
112+
output_path = Path(output_path) / fov
113+
# %%
114+
115+
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
116+
with open_ome_zarr(output_path, mode="r") as store:
117+
118+
# Get the 2D images
119+
vs_nucleus = store[0][0, 0, 0] # (t,c,z,y,x)
120+
vs_membrane = store[0][0, 1, 0] # (t,c,z,y,x)
121+
# Rescale the intensity
122+
vs_nucleus = rescale_intensity(vs_nucleus, out_range=(0, 1))
123+
vs_membrane = rescale_intensity(vs_membrane, out_range=(0, 1))
124+
# VS Nucleus RGB
125+
vs_nucleus_rgb = np.zeros((*store.data.shape[-2:], 3))
126+
vs_nucleus_rgb[:, :, 0] = vs_nucleus * colormap_1[0]
127+
vs_nucleus_rgb[:, :, 1] = vs_nucleus * colormap_1[1]
128+
vs_nucleus_rgb[:, :, 2] = vs_nucleus * colormap_1[2]
129+
# VS Membrane RGB
130+
vs_membrane_rgb = np.zeros((*store.data.shape[-2:], 3))
131+
vs_membrane_rgb[:, :, 0] = vs_membrane * colormap_2[0]
132+
vs_membrane_rgb[:, :, 1] = vs_membrane * colormap_2[1]
133+
vs_membrane_rgb[:, :, 2] = vs_membrane * colormap_2[2]
134+
# Merge the two channels
135+
merged_image = np.zeros((*store.data.shape[-2:], 3))
136+
merged_image[:, :, 0] = vs_nucleus * colormap_1[0] + vs_membrane * colormap_2[0]
137+
merged_image[:, :, 1] = vs_nucleus * colormap_1[1] + vs_membrane * colormap_2[1]
138+
merged_image[:, :, 2] = vs_nucleus * colormap_1[2] + vs_membrane * colormap_2[2]
139+
140+
# Plot
141+
ax[0].imshow(vs_nucleus_rgb)
142+
ax[0].set_title("VS Nucleus")
143+
ax[1].imshow(vs_membrane_rgb)
144+
ax[1].set_title("VS Membrane")
145+
ax[2].imshow(merged_image)
146+
ax[2].set_title("VS Nucleus+Membrane")
147+
for a in ax:
148+
a.axis("off")
149+
plt.margins(0, 0)
150+
plt.show()
151+
# %%

0 commit comments

Comments
 (0)