|
2 | 2 | """
|
3 | 3 | # 2D Virtual Staining of A549 Cells
|
4 | 4 | ---
|
| 5 | +## Prediction using the VSCyto2D to predict nuclei and membrane from phase. |
5 | 6 | This example shows how to virtually stain A549 cells using the _VSCyto2D_ model.
|
6 |
| -
|
7 |
| -First we import the necessary libraries and set the random seed for reproducibility. |
| 7 | +The model is trained to predict the membrane and nuclei channels from the phase channel. |
8 | 8 | """
|
9 | 9 | # %% Imports and paths
|
10 | 10 | from pathlib import Path
|
11 | 11 |
|
12 | 12 | import matplotlib.pyplot as plt
|
13 | 13 | import numpy as np
|
14 |
| -import pandas as pd |
15 |
| -import torch |
16 |
| -import torchview |
17 |
| -import torchvision |
18 | 14 | from iohub import open_ome_zarr
|
19 |
| -from lightning.pytorch import seed_everything |
20 |
| - |
21 |
| -# from rich.pretty import pprint #TODO: add pretty print(?) |
22 |
| - |
23 |
| -from napari.utils.notebook_display import nbscreenshot |
24 |
| -import napari |
25 | 15 |
|
26 |
| -# %% Imports and paths |
27 | 16 | from viscy.data.hcs import HCSDataModule
|
28 | 17 |
|
29 |
| -# Trainer class and UNet. |
| 18 | +# %% Imports and paths |
| 19 | +# Viscy classes for the trainer and model |
30 | 20 | from viscy.light.engine import FcmaeUNet
|
| 21 | +from viscy.light.predict_writer import HCSPredictionWriter |
31 | 22 | from viscy.light.trainer import VSTrainer
|
32 | 23 | from viscy.transforms import NormalizeSampled
|
33 |
| -from viscy.light.predict_writer import HCSPredictionWriter |
34 |
| -from viscy.data.hcs import HCSDataModule |
| 24 | +from skimage.exposure import rescale_intensity |
35 | 25 |
|
36 | 26 | # %% [markdown]
|
37 |
| -""" |
38 |
| -## Prediction using the 2D U-Net model to predict nuclei and membrane from phase. |
39 |
| -
|
40 |
| -### Construct a 2D U-Net |
41 |
| -See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details. |
42 |
| -""" |
43 | 27 |
|
44 | 28 | # %%
|
45 |
| -input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr/0/0/0" |
| 29 | +input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr" |
46 | 30 | model_ckpt_path = "/hpc/projects/comp.micro/virtual_staining/models/hek-a549-bj5a-20x/lightning_logs/tiny-2x2-finetune-e2e-amp-hek-a549-bj5a-nucleus-membrane-400ep/checkpoints/last.ckpt"
|
47 | 31 | output_path = "./test_a549_demo.zarr"
|
| 32 | +fov = "0/0/0" # NOTE: FOV of interest |
48 | 33 |
|
| 34 | +input_data_path = Path(input_data_path) / fov |
49 | 35 | # %%
|
50 | 36 | # Create a the VSCyto2D
|
51 | 37 |
|
| 38 | +# NOTE: Change the following parameters as needed. |
52 | 39 | GPU_ID = 0
|
53 | 40 | BATCH_SIZE = 10
|
54 | 41 | YX_PATCH_SIZE = (384, 384)
|
55 | 42 | phase_channel_name = "Phase3D"
|
56 | 43 |
|
57 |
| - |
| 44 | +# %%[markdown] |
| 45 | +""" |
| 46 | +For this example we will use the following parameters: |
| 47 | +### For more information on the VSCyto2D model: |
| 48 | +See ``viscy.unet.networks.fcmae`` ([source code](https://github.com/mehta-lab/VisCy/blob/6a3457ec8f43ecdc51b1760092f1a678ed73244d/viscy/unet/networks/fcmae.py#L398)) for configuration details. |
| 49 | +""" |
58 | 50 | # %%
|
59 | 51 | # Setup the data module.
|
60 | 52 | data_module = HCSDataModule(
|
|
98 | 90 | model_VSCyto2D.eval()
|
99 | 91 |
|
100 | 92 | # %%
|
| 93 | +# Setup the Trainer |
101 | 94 | trainer = VSTrainer(
|
102 | 95 | accelerator="gpu",
|
103 | 96 | callbacks=[HCSPredictionWriter(output_path)],
|
|
111 | 104 | )
|
112 | 105 |
|
113 | 106 | # %%
|
| 107 | +# Open the output_zarr store and inspect the output |
| 108 | +colormap_1 = [0.1254902, 0.6784314, 0.972549] # bop blue |
| 109 | +colormap_2 = [0.972549, 0.6784314, 0.1254902] # bop orange |
| 110 | + |
| 111 | +# Show the individual channels and the fused in a 1x3 plot |
| 112 | +output_path = Path(output_path) / fov |
| 113 | +# %% |
| 114 | + |
| 115 | +fig, ax = plt.subplots(1, 3, figsize=(15, 5)) |
| 116 | +with open_ome_zarr(output_path, mode="r") as store: |
| 117 | + |
| 118 | + # Get the 2D images |
| 119 | + vs_nucleus = store[0][0, 0, 0] # (t,c,z,y,x) |
| 120 | + vs_membrane = store[0][0, 1, 0] # (t,c,z,y,x) |
| 121 | + # Rescale the intensity |
| 122 | + vs_nucleus = rescale_intensity(vs_nucleus, out_range=(0, 1)) |
| 123 | + vs_membrane = rescale_intensity(vs_membrane, out_range=(0, 1)) |
| 124 | + # VS Nucleus RGB |
| 125 | + vs_nucleus_rgb = np.zeros((*store.data.shape[-2:], 3)) |
| 126 | + vs_nucleus_rgb[:, :, 0] = vs_nucleus * colormap_1[0] |
| 127 | + vs_nucleus_rgb[:, :, 1] = vs_nucleus * colormap_1[1] |
| 128 | + vs_nucleus_rgb[:, :, 2] = vs_nucleus * colormap_1[2] |
| 129 | + # VS Membrane RGB |
| 130 | + vs_membrane_rgb = np.zeros((*store.data.shape[-2:], 3)) |
| 131 | + vs_membrane_rgb[:, :, 0] = vs_membrane * colormap_2[0] |
| 132 | + vs_membrane_rgb[:, :, 1] = vs_membrane * colormap_2[1] |
| 133 | + vs_membrane_rgb[:, :, 2] = vs_membrane * colormap_2[2] |
| 134 | + # Merge the two channels |
| 135 | + merged_image = np.zeros((*store.data.shape[-2:], 3)) |
| 136 | + merged_image[:, :, 0] = vs_nucleus * colormap_1[0] + vs_membrane * colormap_2[0] |
| 137 | + merged_image[:, :, 1] = vs_nucleus * colormap_1[1] + vs_membrane * colormap_2[1] |
| 138 | + merged_image[:, :, 2] = vs_nucleus * colormap_1[2] + vs_membrane * colormap_2[2] |
| 139 | + |
| 140 | + # Plot |
| 141 | + ax[0].imshow(vs_nucleus_rgb) |
| 142 | + ax[0].set_title("VS Nucleus") |
| 143 | + ax[1].imshow(vs_membrane_rgb) |
| 144 | + ax[1].set_title("VS Membrane") |
| 145 | + ax[2].imshow(merged_image) |
| 146 | + ax[2].set_title("VS Nucleus+Membrane") |
| 147 | + for a in ax: |
| 148 | + a.axis("off") |
| 149 | + plt.margins(0, 0) |
| 150 | + plt.show() |
| 151 | +# %% |
0 commit comments