From c6692f1419808b26ca205aded069e9f62436e908 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:19:18 -0800 Subject: [PATCH 01/92] refactor data loading into its own module --- examples/demo_dlmbl/debug_log_graph.py | 2 +- examples/demo_dlmbl/solution.py | 2 +- tests/light/test_data.py | 2 +- viscy/cli/cli.py | 2 +- viscy/data/__init__.py | 0 viscy/{light/data.py => data/hcs.py} | 0 viscy/light/engine.py | 2 +- viscy/light/predict_writer.py | 2 +- viscy/scripts/profiling.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) create mode 100644 viscy/data/__init__.py rename viscy/{light/data.py => data/hcs.py} (100%) diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py index 1819b02f..ec987118 100644 --- a/examples/demo_dlmbl/debug_log_graph.py +++ b/examples/demo_dlmbl/debug_log_graph.py @@ -19,7 +19,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # Trainer class and UNet. from viscy.light.engine import VSUNet diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 933f939d..2c81aa6f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -83,7 +83,7 @@ from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard # HCSDataModule makes it easy to load data during training. -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule # training augmentations from viscy.transforms import ( diff --git a/tests/light/test_data.py b/tests/light/test_data.py index 263f8f90..153f175f 100644 --- a/tests/light/test_data.py +++ b/tests/light/test_data.py @@ -4,7 +4,7 @@ from iohub import open_ome_zarr from pytest import mark -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.trainer import VSTrainer diff --git a/viscy/cli/cli.py b/viscy/cli/cli.py index 0946bb0f..f9a55f12 100644 --- a/viscy/cli/cli.py +++ b/viscy/cli/cli.py @@ -9,7 +9,7 @@ from lightning.pytorch.cli import LightningCLI from lightning.pytorch.loggers import TensorBoardLogger -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule from viscy.light.engine import VSUNet from viscy.light.trainer import VSTrainer diff --git a/viscy/data/__init__.py b/viscy/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/viscy/light/data.py b/viscy/data/hcs.py similarity index 100% rename from viscy/light/data.py rename to viscy/data/hcs.py diff --git a/viscy/light/engine.py b/viscy/light/engine.py index f165a056..74f14aaa 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -25,8 +25,8 @@ structural_similarity_index_measure, ) +from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d -from viscy.light.data import Sample from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index a6ae88cb..7a58009c 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -9,7 +9,7 @@ from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import DTypeLike, NDArray -from viscy.light.data import HCSDataModule, Sample +from viscy.data.hcs import HCSDataModule, Sample __all__ = ["HCSPredictionWriter"] _logger = logging.getLogger("lightning.pytorch") diff --git a/viscy/scripts/profiling.py b/viscy/scripts/profiling.py index 0c947f45..a0c3ca6d 100644 --- a/viscy/scripts/profiling.py +++ b/viscy/scripts/profiling.py @@ -2,7 +2,7 @@ from profilehooks import profile -from viscy.light.data import HCSDataModule +from viscy.data.hcs import HCSDataModule dataset = "/path/to/dataset.zarr" From 3d8e7e2646a10e9483120ad4e12be736342cf621 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 15:26:59 -0800 Subject: [PATCH 02/92] update type annotations --- viscy/unet/networks/Unet21D.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 7c32e34b..51ed9839 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -1,11 +1,11 @@ -from typing import Callable, Literal, Optional, Sequence, Union +from typing import Callable, Literal, Sequence import timm import torch from monai.networks.blocks import Convolution, ResidualUnit, UpSample from monai.networks.blocks.dynunet_block import get_conv_layer from monai.networks.utils import normal_init -from torch import nn +from torch import Tensor, nn def icnr_init( @@ -45,7 +45,7 @@ def _get_convnext_stage( in_channels: int, out_channels: int, depth: int, - upsample_factor: Optional[int] = None, + upsample_factor: int | None = None, ) -> nn.Module: stage = timm.models.convnext.ConvNeXtStage( in_chs=in_channels, @@ -83,7 +83,7 @@ def __init__( stride=kernel_size, ) - def forward(self, x: torch.Tensor): + def forward(self, x: Tensor): x = self.conv(x) b, c, d, h, w = x.shape # project Z/depth into channels @@ -101,7 +101,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, norm_name: str, - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() spatial_dims = 2 @@ -145,11 +145,11 @@ def __init__( upsample_factor=conv_weight_init_factor, ) - def forward(self, inp: torch.Tensor, skip: torch.Tensor) -> torch.Tensor: + def forward(self, inp: Tensor, skip: Tensor) -> Tensor: """ - :param torch.Tensor inp: Low resolution features - :param torch.Tensor skip: High resolution skip connection features - :return torch.Tensor: High resolution features + :param Tensor inp: Low resolution features + :param Tensor skip: High resolution skip connection features + :return Tensor: High resolution features """ inp = self.upsample(inp) inp = torch.cat([inp, skip], dim=1) @@ -192,7 +192,7 @@ def __init__( self.out = nn.PixelShuffle(2) self.out_stack_depth = out_stack_depth - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.upsample(x) d = self.out_stack_depth + 2 b, c, h, w = x.shape @@ -209,7 +209,7 @@ class UnsqueezeHead(nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = x.unsqueeze(2) return x @@ -222,7 +222,7 @@ def __init__( mode: Literal["deconv", "pixelshuffle"], conv_blocks: int, strides: list[int], - upsample_pre_conv: Optional[Union[Literal["default"], Callable]], + upsample_pre_conv: Literal["default"] | Callable | None, ) -> None: super().__init__() self.decoder_stages = nn.ModuleList([]) @@ -240,7 +240,7 @@ def __init__( ) self.decoder_stages.append(stage) - def forward(self, features: Sequence[torch.Tensor]) -> torch.Tensor: + def forward(self, features: Sequence[Tensor]) -> Tensor: feat = features[0] # padding features.append(None) @@ -328,7 +328,7 @@ def num_blocks(self) -> int: """2-times downscaling factor of the smallest feature map""" return 6 - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: x = self.stem(x) x: list = self.encoder_stages(x) x.reverse() From fdcbf5536133291cee298c654fac5645ca4acfab Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:01:28 -0800 Subject: [PATCH 03/92] move the logging module out --- viscy/unet/{utils => }/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/unet/{utils => }/logging.py (100%) diff --git a/viscy/unet/utils/logging.py b/viscy/unet/logging.py similarity index 100% rename from viscy/unet/utils/logging.py rename to viscy/unet/logging.py From a2913817e0c432933ba5c81c3662993c579d7e66 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 10 Jan 2024 16:03:10 -0800 Subject: [PATCH 04/92] move old logging into utils --- viscy/{unet => utils}/logging.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename viscy/{unet => utils}/logging.py (100%) diff --git a/viscy/unet/logging.py b/viscy/utils/logging.py similarity index 100% rename from viscy/unet/logging.py rename to viscy/utils/logging.py From 3cf8fa23c73ce754e27498a07879ccb23db7d170 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:21 -0800 Subject: [PATCH 05/92] rename tests to match module name --- tests/{torch_unet => unet}/networks/Unet25D_tests.py | 0 tests/{torch_unet => unet}/networks/Unet2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py | 0 tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py | 0 4 files changed, 0 insertions(+), 0 deletions(-) rename tests/{torch_unet => unet}/networks/Unet25D_tests.py (100%) rename tests/{torch_unet => unet}/networks/Unet2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock2D_tests.py (100%) rename tests/{torch_unet => unet}/networks/layers/ConvBlock3D_tests.py (100%) diff --git a/tests/torch_unet/networks/Unet25D_tests.py b/tests/unet/networks/Unet25D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet25D_tests.py rename to tests/unet/networks/Unet25D_tests.py diff --git a/tests/torch_unet/networks/Unet2D_tests.py b/tests/unet/networks/Unet2D_tests.py similarity index 100% rename from tests/torch_unet/networks/Unet2D_tests.py rename to tests/unet/networks/Unet2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock2D_tests.py b/tests/unet/networks/layers/ConvBlock2D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock2D_tests.py rename to tests/unet/networks/layers/ConvBlock2D_tests.py diff --git a/tests/torch_unet/networks/layers/ConvBlock3D_tests.py b/tests/unet/networks/layers/ConvBlock3D_tests.py similarity index 100% rename from tests/torch_unet/networks/layers/ConvBlock3D_tests.py rename to tests/unet/networks/layers/ConvBlock3D_tests.py From d4cd41db42ecf62b94ab26e5bbc9a4d7feecfcac Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 09:31:30 -0800 Subject: [PATCH 06/92] bump torch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8d60ee1d..b60cd534 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ "iohub==0.1.0rc0", - "torch>=2.0.0", + "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", "lightning>=2.0.1", From e87d3969617de3bc7a0b47e136b8e1270dad1ea6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 11 Jan 2024 16:35:30 -0800 Subject: [PATCH 07/92] draft fcmae encoder --- tests/unet/__init__.py | 0 tests/unet/test_fcmae.py | 43 ++++++ viscy/unet/networks/Unet21D.py | 2 +- viscy/unet/networks/fcmae.py | 235 +++++++++++++++++++++++++++++++++ 4 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 tests/unet/__init__.py create mode 100644 tests/unet/test_fcmae.py create mode 100644 viscy/unet/networks/fcmae.py diff --git a/tests/unet/__init__.py b/tests/unet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py new file mode 100644 index 00000000..ae8e0ec6 --- /dev/null +++ b/tests/unet/test_fcmae.py @@ -0,0 +1,43 @@ +import torch + +from viscy.unet.networks.fcmae import ( + MaskedConvNeXtV2Block, + MaskedConvNeXtV2Stage, + MaskedGlobalResponseNorm, +) + + +def test_masked_grn() -> None: + x = torch.rand(2, 3, 4, 5) + grn = MaskedGlobalResponseNorm(3, channels_last=False) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) + mask[:, :, 2:, 2:] = False + normalized = grn(x) + assert not torch.allclose(normalized, x) + assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) + grn = MaskedGlobalResponseNorm(5, channels_last=True) + grn.gamma.data = torch.ones_like(grn.gamma.data) + mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) + mask[:, 1:, 2:, :] = False + assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) + + +def test_masked_convnextv2_block() -> None: + x = torch.rand(2, 3, 4, 5) + mask = x[0, 0] > 0.5 + block = MaskedConvNeXtV2Block(3, 3 * 2) + assert len(block(x).unique()) == x.numel() * 2 + block = MaskedConvNeXtV2Block(3, 3) + masked_out = block(x, mask) + assert len(masked_out[:, :, mask].unique()) == x.shape[1] + + +def test_masked_convnextv2_stage() -> None: + x = torch.rand(2, 3, 16, 16) + mask = torch.rand(4, 4) > 0.5 + stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) + out = stage(x) + assert out.shape == (2, 3, 8, 8) + masked_out = stage(x, mask) + assert not torch.allclose(masked_out, out) diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet21D.py index 51ed9839..c4320240 100644 --- a/viscy/unet/networks/Unet21D.py +++ b/viscy/unet/networks/Unet21D.py @@ -12,7 +12,7 @@ def icnr_init( conv: nn.Module, upsample_factor: int, upsample_dims: int, - init=nn.init.kaiming_normal_, + init: Callable = nn.init.kaiming_normal_, ): """ ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py new file mode 100644 index 00000000..818e8f88 --- /dev/null +++ b/viscy/unet/networks/fcmae.py @@ -0,0 +1,235 @@ +""" +Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 +based on the official JAX example in +https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax +also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +""" + + +from typing import Callable, Literal, Sequence + +import torch +from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.models.convnext import Downsample +from torch import BoolTensor, Tensor, nn + + +def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: + mask = mask[..., :, :][None, None] + if features.shape[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): + raise ValueError( + f"feature map shape {features.shape} must be divisible by " + f"mask shape {mask.shape}." + ) + mask = mask.repeat_interleave( + features.shape[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + return mask + + +class MaskedGlobalResponseNorm(nn.Module): + """ + Masked Global Response Normalization. + + :param int dim: number of input channels + :param float eps: small value added for numerical stability, + defaults to 1e-6 + :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, + defaults to False + """ + + def __init__( + self, dim: int, eps: float = 1e-6, channels_last: bool = False + ) -> None: + super().__init__() + if channels_last: + self.spatial_dim = (1, 2) + self.channel_dim = -1 + weights_shape = (1, 1, 1, dim) + else: + self.spatial_dim = (2, 3) + self.channel_dim = 1 + weights_shape = (1, dim, 1, 1) + self.gamma = nn.Parameter(torch.zeros(weights_shape)) + self.beta = nn.Parameter(torch.zeros(weights_shape)) + self.eps = eps + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor, BHWC or BCHW + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: normalized tensor + """ + samples = x if mask is None else x * ~mask + g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) + n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) + return x + torch.addcmul(self.beta, self.gamma, x * n_x) + + +class MaskedConvNeXtV2Block(nn.Module): + """Masked ConvNeXt V2 Block. + + :param int in_channels: input channels + :param int | None out_channels: output channels, defaults to None + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsample stride, defaults to 1 + :param int mlp_ratio: MLP expansion ratio, defaults to 4 + :param float drop_path: drop path rate, defaults to 0.0 + """ + + def __init__( + self, + in_channels: int, + out_channels: int | None = None, + kernel_size: int = 7, + stride: int = 1, + mlp_ratio: int = 4, + drop_path: float = 0.0, + ) -> None: + super().__init__() + out_channels = out_channels or in_channels + self.dwconv = create_conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + depthwise=True, + ) + self.layernorm = LayerNorm2d(out_channels) + self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) + self.act = nn.GELU() + self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) + self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + if in_channels != out_channels or stride > 1: + self.shortcut = Downsample(in_channels, out_channels, stride=stride) + else: + self.shortcut = nn.Identity() + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + shortcut = self.shortcut(x) + if mask is not None: + x *= ~mask + x = self.dwconv(x) + if mask is not None: + x *= ~mask + x = self.layernorm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x, mask) + x = self.pwconv2(x) + x = self.drop_path(x) + shortcut + return x + + +class MaskedConvNeXtV2Stage(nn.Module): + """Masked ConvNeXt V2 Stage. + + :param int in_channels: input channels + :param int out_channels: output channels + :param int kernel_size: depth-wise convolution kernel size, defaults to 7 + :param int stride: downsampling factor of this stage, defaults to 2 + :param int num_blocks: number of residual blocks, defaults to 2 + :param Sequence[float] | None drop_path_rates: drop path rates of each block, + defaults to None + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 7, + stride: int = 2, + num_blocks: int = 2, + drop_path_rates: Sequence[float] | None = None, + ) -> None: + super().__init__() + if drop_path_rates is None: + drop_path_rates = [0.0] * num_blocks + elif len(drop_path_rates) != num_blocks: + raise ValueError( + "length of drop_path_rates must be equal to " + f"the number of blocks {num_blocks}, got {len(drop_path_rates)}." + ) + if in_channels != out_channels or stride > 1: + downsample_kernel_size = stride if stride > 1 else 1 + self.downsample = nn.Sequential( + LayerNorm2d(in_channels), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=downsample_kernel_size, + stride=stride, + padding=0, + ), + ) + in_channels = out_channels + else: + self.downsample = nn.Identity() + self.blocks = nn.ModuleList() + for i in range(num_blocks): + self.blocks.append( + MaskedConvNeXtV2Block( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + drop_path=drop_path_rates[i], + ) + ) + in_channels = out_channels + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + x = self.downsample(x) + if mask is not None: + mask = _upsample_mask(mask, x) + for block in self.blocks: + x = block(x, mask) + return x + + +class MaskedMultiscaleEncoder(nn.Module): + def __init__( + self, + in_channels: int, + stage_blocks: Sequence[int] = (3, 3, 9, 3), + dims: Sequence[int] = (96, 192, 384, 768), + drop_path_rate: float = 0.0, + ) -> None: + super().__init__() + self.stages = nn.ModuleList() + chs = [in_channels, *dims] + for i, num_blocks in enumerate(stage_blocks): + self.stages.append( + MaskedConvNeXtV2Stage( + chs[i], + chs[i + 1], + kernel_size=7, + stride=2, + num_blocks=num_blocks, + drop_path_rates=[drop_path_rate] * num_blocks, + ) + ) + + def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + """ + :param Tensor x: input tensor (BCHW) + :param BoolTensor | None mask: boolean mask, defaults to None + :return Tensor: output tensor (BCHW) + """ + features = [] + for stage in self.stages: + x = stage(x, mask) + features.append(x) + return features From dccce5f785581300dd4387f2d5f0548be50af5bf Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:14:36 -0800 Subject: [PATCH 08/92] add stem to the encoder --- tests/unet/test_fcmae.py | 37 +++++++++++++ viscy/unet/networks/fcmae.py | 101 ++++++++++++++++++++++++++++++----- 2 files changed, 125 insertions(+), 13 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ae8e0ec6..73dc5920 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,9 +1,12 @@ import torch from viscy.unet.networks.fcmae import ( + AdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, + MaskedMultiscaleEncoder, + upsample_mask, ) @@ -41,3 +44,37 @@ def test_masked_convnextv2_stage() -> None: assert out.shape == (2, 3, 8, 8) masked_out = stage(x, mask) assert not torch.allclose(masked_out, out) + + +def test_adaptive_projection() -> None: + proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) + assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + proj = AdaptiveProjection( + 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 + ) + assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) + + +def test_masked_multiscale_encoder() -> None: + xy_size = 64 + dims = [12, 24, 48, 96] + x = torch.rand(2, 3, 5, xy_size, xy_size) + encoder = MaskedMultiscaleEncoder(3, dims=dims) + # auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features = encoder(x) + target_shape = list(x.shape) + target_shape.pop(1) + pre_masked_features = encoder(x) #encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) + assert len(auto_masked_features) == len(pre_masked_features) == 4 + for i, (dim, afeat, pfeat) in enumerate( + zip(dims, auto_masked_features, pre_masked_features) + ): + assert afeat.shape[0] == x.shape[0] + assert afeat.shape[1] == dim + stride = 2 * 2 ** (i + 1) + assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + i, + (afeat - pfeat).abs().max(), + ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 818e8f88..71644955 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -11,20 +11,19 @@ import torch from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ from timm.models.convnext import Downsample -from torch import BoolTensor, Tensor, nn +from torch import BoolTensor, Size, Tensor, nn -def _upsample_mask(mask: BoolTensor, features: Tensor) -> BoolTensor: - mask = mask[..., :, :][None, None] - if features.shape[-2:] != mask.shape[-2:]: - if not all(i % j == 0 for i, j in zip(features.shape[-2:], mask.shape[-2:])): +def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + if target[-2:] != mask.shape[-2:]: + if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( - f"feature map shape {features.shape} must be divisible by " + f"feature map shape {target} must be divisible by " f"mask shape {mask.shape}." ) mask = mask.repeat_interleave( - features.shape[-2] // mask.shape[-2], dim=-2 - ).repeat_interleave(features.shape[-1] // mask.shape[-1], dim=-1) + target[-2] // mask.shape[-2], dim=-2 + ).repeat_interleave(target[-1] // mask.shape[-1], dim=-1) return mask @@ -193,12 +192,64 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: """ x = self.downsample(x) if mask is not None: - mask = _upsample_mask(mask, x) + mask = upsample_mask(mask, x.shape) for block in self.blocks: x = block(x, mask) return x +class AdaptiveProjection(nn.Module): + """ + Patchifying layer for projecting 2D or 3D input into 2D feature maps. + Masking is not needed because the mask will cover entire patches. + + :param int in_channels: input channels + :param int out_channels: output channels + :param Sequence[int, int] | int kernel_size_2d: kernel width and height + :param int kernel_depth: kernel depth for 3D input + :param int in_stack_depth: input stack depth for 3D input + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size_2d: tuple[int, int] | int = 4, + kernel_depth: int = 5, + in_stack_depth: int = 5, + ) -> None: + super().__init__() + ratio = in_stack_depth // kernel_depth + if isinstance(kernel_size_2d, int): + kernel_size_2d = [kernel_size_2d] * 2 + kernel_size_3d = [kernel_depth, *kernel_size_2d] + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels // ratio, + kernel_size=kernel_size_3d, + stride=kernel_size_3d, + ) + self.conv2d = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size_2d, + stride=kernel_size_2d, + ) + + def forward(self, x: Tensor) -> Tensor: + """ + :param Tensor x: input tensor (BCDHW) + :return Tensor: output tensor (BCHW) + """ + if x.shape[2] > 1: + x = self.conv3d(x) + b, c, d, h, w = x.shape + # project Z/depth into channels + # return a view when possible (contiguous) + return x.reshape(b, c * d, h, w) + return self.conv2d(x.squeeze(2)) + + class MaskedMultiscaleEncoder(nn.Module): def __init__( self, @@ -208,28 +259,52 @@ def __init__( drop_path_rate: float = 0.0, ) -> None: super().__init__() + stem_kernel_size_2d = 4 + self.stem = nn.Sequential( + AdaptiveProjection( + in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 + ), + LayerNorm2d(dims[0]), + ) self.stages = nn.ModuleList() - chs = [in_channels, *dims] + chs = [dims[0], *dims] for i, num_blocks in enumerate(stage_blocks): + stride = 1 if i == 0 else 2 self.stages.append( MaskedConvNeXtV2Stage( chs[i], chs[i + 1], kernel_size=7, - stride=2, + stride=stride, num_blocks=num_blocks, drop_path_rates=[drop_path_rate] * num_blocks, ) ) + self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param float mask_ratio: ratio of the feature maps to mask, + defaults to 0.0 (no masking) :return Tensor: output tensor (BCHW) """ + if mask_ratio > 0.0: + noise = torch.rand( + x.shape[0], + 1, + x.shape[-2] // self.total_stride, + x.shape[-1] // self.total_stride, + device=x.device, + ) + mask = noise > mask_ratio + else: + mask = None + x = self.stem(x) features = [] for stage in self.stages: x = stage(x, mask) features.append(x) + if mask is not None: + return features, mask return features From 55087315f6783417acad17500e0fe3b47899b125 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Jan 2024 15:56:23 -0800 Subject: [PATCH 09/92] wip: masked stem layernorm --- tests/unet/test_fcmae.py | 15 +++++++------- viscy/unet/networks/fcmae.py | 38 ++++++++++++++++++++++++++++-------- 2 files changed, 38 insertions(+), 15 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 73dc5920..b9a3d389 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,7 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( - AdaptiveProjection, + MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, MaskedGlobalResponseNorm, @@ -47,10 +47,12 @@ def test_masked_convnextv2_stage() -> None: def test_adaptive_projection() -> None: - proj = AdaptiveProjection(3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5) + proj = MaskedAdaptiveProjection( + 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 + ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - proj = AdaptiveProjection( + proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) @@ -61,11 +63,10 @@ def test_masked_multiscale_encoder() -> None: dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - # auto_masked_features, mask = encoder(x, mask_ratio=0.5) - auto_masked_features = encoder(x) + auto_masked_features, mask = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x) #encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) + pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) assert len(auto_masked_features) == len(pre_masked_features) == 4 for i, (dim, afeat, pfeat) in enumerate( zip(dims, auto_masked_features, pre_masked_features) @@ -74,7 +75,7 @@ def test_masked_multiscale_encoder() -> None: assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=1e-1, atol=5e-2), ( + assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( i, (afeat - pfeat).abs().max(), ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 71644955..416c50ad 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,7 +9,14 @@ from typing import Callable, Literal, Sequence import torch -from timm.layers import DropPath, LayerNorm2d, create_conv2d, trunc_normal_ +from timm.layers import ( + DropPath, + GlobalResponseNormMlp, + LayerNorm2d, + LayerNorm, + create_conv2d, + trunc_normal_, +) from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn @@ -198,10 +205,9 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: return x -class AdaptiveProjection(nn.Module): +class MaskedAdaptiveProjection(nn.Module): """ - Patchifying layer for projecting 2D or 3D input into 2D feature maps. - Masking is not needed because the mask will cover entire patches. + Masked patchifying layer for projecting 2D or 3D input into 2D feature maps. :param int in_channels: input channels :param int out_channels: output channels @@ -235,19 +241,35 @@ def __init__( kernel_size=kernel_size_2d, stride=kernel_size_2d, ) + self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) + :param BoolTensor mask: boolean mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ + # no need to mask before convolutions since patches do not spill over if x.shape[2] > 1: x = self.conv3d(x) b, c, d, h, w = x.shape # project Z/depth into channels # return a view when possible (contiguous) - return x.reshape(b, c * d, h, w) - return self.conv2d(x.squeeze(2)) + x = x.reshape(b, c * d, h, w) + else: + x = self.conv2d(x.squeeze(2)) + out_shape = x.shape + if mask is not None: + mask = upsample_mask(mask, x.shape) + x = x[mask] + else: + x = x.flatten(2) + x = x.permute(0, 2, 1) + x = self.norm(x) + x = x.permute(0, 2, 1) + if mask is not None: + out = torch.zeros(out_shape, device=x.device) + out[mask] = x class MaskedMultiscaleEncoder(nn.Module): @@ -261,7 +283,7 @@ def __init__( super().__init__() stem_kernel_size_2d = 4 self.stem = nn.Sequential( - AdaptiveProjection( + MaskedAdaptiveProjection( in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 ), LayerNorm2d(dims[0]), From 3eec48ed78908eb44edf8cd96991da2b79c8cece Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 20:23:32 -0800 Subject: [PATCH 10/92] wip: patchify masked features for linear --- tests/unet/test_fcmae.py | 51 +++++++++++++-- viscy/unet/networks/fcmae.py | 122 +++++++++++++++++++++-------------- 2 files changed, 119 insertions(+), 54 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index b9a3d389..fc534981 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,13 +4,52 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - MaskedGlobalResponseNorm, + # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, + generate_mask, + masked_patchify, + masked_unpatchify, upsample_mask, ) -def test_masked_grn() -> None: +def test_generate_mask(): + w = 64 + s = 16 + m = 0.75 + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + assert mask.shape == (2, 1, w // s, w // s) + assert mask.dtype == torch.bool + ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] + assert torch.allclose(ratio, torch.ones_like(ratio) * m) + + +def test_masked_patchify(): + b, c, h, w = 2, 3, 4, 8 + x = torch.rand(b, c, h, w) + mask_ratio = 0.75 + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = upsample_mask(mask, x.shape) + feat = masked_patchify(x, mask) + assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) + + +def test_unmasked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + y = masked_unpatchify(masked_patchify(x), out_shape=x.shape) + assert torch.allclose(x, y) + + +def test_masked_patchify_roundtrip(): + x = torch.rand(2, 3, 4, 8) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = upsample_mask(mask, x.shape) + y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + assert torch.all((y == 0) ^ (x == y)) + assert torch.all((y == 0)[:, 0:1] == mask) + + +def test_masked_grn(): x = torch.rand(2, 3, 4, 5) grn = MaskedGlobalResponseNorm(3, channels_last=False) grn.gamma.data = torch.ones_like(grn.gamma.data) @@ -36,7 +75,7 @@ def test_masked_convnextv2_block() -> None: assert len(masked_out[:, :, mask].unique()) == x.shape[1] -def test_masked_convnextv2_stage() -> None: +def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) mask = torch.rand(4, 4) > 0.5 stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) @@ -46,19 +85,21 @@ def test_masked_convnextv2_stage() -> None: assert not torch.allclose(masked_out, out) -def test_adaptive_projection() -> None: +def test_adaptive_projection(): proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=4, kernel_depth=5, in_stack_depth=5 ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) + mask = torch.rand(2, 1, 2, 2) > 0.5 + masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) assert proj(torch.rand(2, 3, 15, 6, 8)).shape == (2, 12, 3, 2) -def test_masked_multiscale_encoder() -> None: +def test_masked_multiscale_encoder(): xy_size = 64 dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 416c50ad..d852f780 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -2,18 +2,18 @@ Fully Convolutional Masked Autoencoder as described in ConvNeXt V2 based on the official JAX example in https://github.com/facebookresearch/ConvNeXt-V2/blob/main/TRAINING.md#implementing-fcmae-with-masked-convolution-in-jax -also referring to timm's dense implementation of the encoder in ``timm.models.convnext`` +and timm's dense implementation of the encoder in ``timm.models.convnext`` """ from typing import Callable, Literal, Sequence import torch +import torch.nn.functional as F from timm.layers import ( DropPath, GlobalResponseNormMlp, LayerNorm2d, - LayerNorm, create_conv2d, trunc_normal_, ) @@ -21,7 +21,27 @@ from torch import BoolTensor, Size, Tensor, nn +def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: + """ + :param Size target: target shape + :param int stride: total stride + :param float mask_ratio: ratio of the pixels to mask + :return BoolTensor: boolean mask (N, H*W) + """ + m_height = target[-2] // stride + m_width = target[-1] // stride + mask_numel = m_height * m_width + masked_elements = int(mask_numel * mask_ratio) + mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + return mask.reshape(target[0], 1, m_height, m_width) + + def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: + """ + :param BoolTensor mask: low-resolution boolean mask (B1HW) + :param Size target: target size (BCHW) + :return BoolTensor: upsampled boolean mask (B1HW) + """ if target[-2:] != mask.shape[-2:]: if not all(i % j == 0 for i, j in zip(target, mask.shape)): raise ValueError( @@ -34,43 +54,48 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -class MaskedGlobalResponseNorm(nn.Module): +def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: """ - Masked Global Response Normalization. - - :param int dim: number of input channels - :param float eps: small value added for numerical stability, - defaults to 1e-6 - :param bool channels_last: BHWC (True) or BCHW (False) dimension ordering, - defaults to False + :param Tensor features: input image features (BCHW) + :param BoolTensor mask: boolean mask (B1HW) + :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ + if mask is None: + return features.flatten(2).permute(0, 2, 1) + b, c = features.shape[:2] + # (B, C, H, W) -> (B, H, W, C) + features = features.permute(0, 2, 3, 1) + # (B, H, W, C) -> (B * L, C) -> (B, L, C) + features = features[~mask[:, 0]].reshape(b, -1, c) - def __init__( - self, dim: int, eps: float = 1e-6, channels_last: bool = False - ) -> None: - super().__init__() - if channels_last: - self.spatial_dim = (1, 2) - self.channel_dim = -1 - weights_shape = (1, 1, 1, dim) - else: - self.spatial_dim = (2, 3) - self.channel_dim = 1 - weights_shape = (1, dim, 1, 1) - self.gamma = nn.Parameter(torch.zeros(weights_shape)) - self.beta = nn.Parameter(torch.zeros(weights_shape)) - self.eps = eps + # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) + # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) + # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) + # patch_size = kernel_size[0] * kernel_size[1] + # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) + # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) + # # (B, 1, Hg, Wg) -> (B, Hg*Wg) + # idx = ~mask.flatten(1) + # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) + # features = features[idx].view(b, -1, c, patch_size) + # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) + # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + return features - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: - """ - :param Tensor x: input tensor, BHWC or BCHW - :param BoolTensor | None mask: boolean mask, defaults to None - :return Tensor: normalized tensor - """ - samples = x if mask is None else x * ~mask - g_x = samples.norm(p=2, dim=self.spatial_dim, keepdim=True) - n_x = g_x / (g_x.mean(dim=self.channel_dim, keepdim=True) + self.eps) - return x + torch.addcmul(self.beta, self.gamma, x * n_x) + +def masked_unpatchify( + features: Tensor, out_shape: Size, mask: BoolTensor | None = None +) -> Tensor: + if mask is None: + # (B, L, C) -> (B, C, L) -> (B, C, H, W) + return features.permute(0, 2, 1).reshape(out_shape) + b, c, w, h = out_shape + out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) + # (B, L, C) -> (B * L, C) + features = features.reshape(-1, c) + out[~mask[:, 0]] = features + # (B, H, W, C) -> (B, C, H, W) + return out.permute(0, 3, 1, 2) class MaskedConvNeXtV2Block(nn.Module): @@ -102,11 +127,13 @@ def __init__( stride=stride, depthwise=True, ) - self.layernorm = LayerNorm2d(out_channels) - self.pwconv1 = nn.Conv2d(out_channels, mlp_ratio * out_channels, kernel_size=1) - self.act = nn.GELU() - self.grn = MaskedGlobalResponseNorm(mlp_ratio * out_channels) - self.pwconv2 = nn.Conv2d(mlp_ratio * out_channels, out_channels, kernel_size=1) + self.layernorm = nn.LayerNorm(out_channels) + mid_channels = mlp_ratio * out_channels + self.mlp = GlobalResponseNormMlp( + in_features=out_channels, + hidden_features=mid_channels, + out_features=out_channels, + ) self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() if in_channels != out_channels or stride > 1: self.shortcut = Downsample(in_channels, out_channels, stride=stride) @@ -125,6 +152,8 @@ def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: x = self.dwconv(x) if mask is not None: x *= ~mask + out_shape = x.shape + x = masked_project(x, mask) x = self.layernorm(x) x = self.pwconv1(x) x = self.act(x) @@ -268,8 +297,10 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: x = self.norm(x) x = x.permute(0, 2, 1) if mask is not None: - out = torch.zeros(out_shape, device=x.device) + out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) out[mask] = x + return out + return x.reshape(out_shape) class MaskedMultiscaleEncoder(nn.Module): @@ -312,14 +343,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - noise = torch.rand( - x.shape[0], - 1, - x.shape[-2] // self.total_stride, - x.shape[-1] // self.total_stride, - device=x.device, - ) - mask = noise > mask_ratio + mask = generate_mask(x.shape, self.total_stride, mask_ratio) else: mask = None x = self.stem(x) From 8c54febcf71f8074fe6e7c40198ac1a673cf7678 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 16 Jan 2024 21:37:39 -0800 Subject: [PATCH 11/92] use mlp from timm --- tests/unet/test_fcmae.py | 51 ++++++------------- viscy/unet/networks/fcmae.py | 95 +++++++++++++++--------------------- 2 files changed, 55 insertions(+), 91 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index fc534981..ba0d7a24 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -4,7 +4,6 @@ MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, - # MaskedGlobalResponseNorm, MaskedMultiscaleEncoder, generate_mask, masked_patchify, @@ -30,7 +29,7 @@ def test_masked_patchify(): mask_ratio = 0.75 mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) mask = upsample_mask(mask, x.shape) - feat = masked_patchify(x, mask) + feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -44,40 +43,28 @@ def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) mask = upsample_mask(mask, x.shape) - y = masked_unpatchify(masked_patchify(x, mask), out_shape=x.shape, mask=mask) + y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) assert torch.all((y == 0)[:, 0:1] == mask) -def test_masked_grn(): - x = torch.rand(2, 3, 4, 5) - grn = MaskedGlobalResponseNorm(3, channels_last=False) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 1, 4, 5), dtype=torch.bool) - mask[:, :, 2:, 2:] = False - normalized = grn(x) - assert not torch.allclose(normalized, x) - assert torch.allclose(grn(x, mask)[:, :, 2:, 2:], grn(x[:, :, 2:, 2:])) - grn = MaskedGlobalResponseNorm(5, channels_last=True) - grn.gamma.data = torch.ones_like(grn.gamma.data) - mask = torch.ones((1, 3, 4, 1), dtype=torch.bool) - mask[:, 1:, 2:, :] = False - assert torch.allclose(grn(x, mask)[:, 1:, 2:, :], grn(x[:, 1:, 2:, :])) - - def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = x[0, 0] > 0.5 + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) block = MaskedConvNeXtV2Block(3, 3 * 2) - assert len(block(x).unique()) == x.numel() * 2 + unmasked_out = block(x) + assert len(unmasked_out.unique()) == x.numel() * 2 + all_unmasked = torch.ones_like(mask) + empty_masked_out = block(x, all_unmasked) + assert torch.allclose(unmasked_out, empty_masked_out) block = MaskedConvNeXtV2Block(3, 3) masked_out = block(x, mask) - assert len(masked_out[:, :, mask].unique()) == x.shape[1] + assert len(masked_out.unique()) == mask.sum() * x.shape[1] + 1 def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = torch.rand(4, 4) > 0.5 + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -91,8 +78,9 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = torch.rand(2, 1, 2, 2) > 0.5 - masked_out = proj(torch.rand(2, 3, 5, 16, 16), mask) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) + assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( 3, 12, kernel_size_2d=(2, 4), kernel_depth=5, in_stack_depth=15 ) @@ -104,19 +92,12 @@ def test_masked_multiscale_encoder(): dims = [12, 24, 48, 96] x = torch.rand(2, 3, 5, xy_size, xy_size) encoder = MaskedMultiscaleEncoder(3, dims=dims) - auto_masked_features, mask = encoder(x, mask_ratio=0.5) + auto_masked_features, _ = encoder(x, mask_ratio=0.5) target_shape = list(x.shape) target_shape.pop(1) - pre_masked_features = encoder(x * ~upsample_mask(mask, target_shape).unsqueeze(1)) - assert len(auto_masked_features) == len(pre_masked_features) == 4 - for i, (dim, afeat, pfeat) in enumerate( - zip(dims, auto_masked_features, pre_masked_features) - ): + assert len(auto_masked_features) == 4 + for i, (dim, afeat) in enumerate(zip(dims, auto_masked_features)): assert afeat.shape[0] == x.shape[0] assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride - assert torch.allclose(afeat, pfeat, rtol=5e-2, atol=5e-2), ( - i, - (afeat - pfeat).abs().max(), - ) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index d852f780..a2e6849e 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -54,46 +54,38 @@ def upsample_mask(mask: BoolTensor, target: Size) -> BoolTensor: return mask -def masked_patchify(features: Tensor, mask: BoolTensor | None = None) -> Tensor: +def masked_patchify(features: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor features: input image features (BCHW) - :param BoolTensor mask: boolean mask (B1HW) + :param BoolTensor unmasked: boolean foreground mask (B1HW) :return Tensor: masked channel-last features (BLC, L = H * W * mask_ratio) """ - if mask is None: + if unmasked is None: return features.flatten(2).permute(0, 2, 1) b, c = features.shape[:2] # (B, C, H, W) -> (B, H, W, C) features = features.permute(0, 2, 3, 1) # (B, H, W, C) -> (B * L, C) -> (B, L, C) - features = features[~mask[:, 0]].reshape(b, -1, c) - - # kernel_size = tuple(features.shape[-i] // mask.shape[-i] for i in (2, 1)) - # # (B, C, H, W) -> (B, C * H_patch * Wp, H_grid * Wg) - # features = F.unfold(features, kernel_size=kernel_size, stride=kernel_size) - # patch_size = kernel_size[0] * kernel_size[1] - # # (B, C * Hp * Wp, Hg * Wg) -> (B, C, Hp * Wp, Hg * Wg) -> (B, Hg * Wg, C, Hp * Wp) - # features = features.view(b, c, patch_size, -1).permute(0, 3, 1, 2) - # # (B, 1, Hg, Wg) -> (B, Hg*Wg) - # idx = ~mask.flatten(1) - # # (B, Hg * Wg, C, Hp * Wp) -> (B * L, C, Hp * Wp) -> (B, L, C, Hp * Wp) - # features = features[idx].view(b, -1, c, patch_size) - # # (B, L, C, Hp * Wp) -> (B, L, Hp * Wp, C) -> (B, L * Hp * Wp, C) - # features = features.permute(0, 1, 3, 2).reshape(b, -1, c) + features = features[unmasked[:, 0]].reshape(b, -1, c) return features def masked_unpatchify( - features: Tensor, out_shape: Size, mask: BoolTensor | None = None + features: Tensor, out_shape: Size, unmasked: BoolTensor | None = None ) -> Tensor: - if mask is None: - # (B, L, C) -> (B, C, L) -> (B, C, H, W) + """ + :param Tensor features: dense channel-last features (BLC) + :param Size out_shape: output shape (BCHW) + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None + :return Tensor: masked features (BCHW) + """ + if unmasked is None: return features.permute(0, 2, 1).reshape(out_shape) b, c, w, h = out_shape out = torch.zeros((b, w, h, c), device=features.device, dtype=features.dtype) # (B, L, C) -> (B * L, C) features = features.reshape(-1, c) - out[~mask[:, 0]] = features + out[unmasked[:, 0]] = features # (B, H, W, C) -> (B, C, H, W) return out.permute(0, 3, 1, 2) @@ -140,25 +132,23 @@ def __init__( else: self.shortcut = nn.Identity() - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ shortcut = self.shortcut(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked x = self.dwconv(x) - if mask is not None: - x *= ~mask + if unmasked is not None: + x *= unmasked out_shape = x.shape - x = masked_project(x, mask) + x = masked_patchify(x, unmasked=unmasked) x = self.layernorm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.grn(x, mask) - x = self.pwconv2(x) + x = self.mlp(x.unsqueeze(1)).squeeze(1) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) x = self.drop_path(x) + shortcut return x @@ -220,17 +210,17 @@ def __init__( ) in_channels = out_channels - def forward(self, x: Tensor, mask: BoolTensor | None = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor | None = None) -> Tensor: """ :param Tensor x: input tensor (BCHW) - :param BoolTensor | None mask: boolean mask, defaults to None + :param BoolTensor | None unmasked: boolean foreground mask, defaults to None :return Tensor: output tensor (BCHW) """ x = self.downsample(x) - if mask is not None: - mask = upsample_mask(mask, x.shape) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) for block in self.blocks: - x = block(x, mask) + x = block(x, unmasked) return x @@ -272,10 +262,10 @@ def __init__( ) self.norm = nn.LayerNorm(out_channels) - def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: + def forward(self, x: Tensor, unmasked: BoolTensor = None) -> Tensor: """ :param Tensor x: input tensor (BCDHW) - :param BoolTensor mask: boolean mask (B1HW), defaults to None + :param BoolTensor unmasked: boolean foreground mask (B1HW), defaults to None :return Tensor: output tensor (BCHW) """ # no need to mask before convolutions since patches do not spill over @@ -288,19 +278,12 @@ def forward(self, x: Tensor, mask: BoolTensor = None) -> Tensor: else: x = self.conv2d(x.squeeze(2)) out_shape = x.shape - if mask is not None: - mask = upsample_mask(mask, x.shape) - x = x[mask] - else: - x = x.flatten(2) - x = x.permute(0, 2, 1) + if unmasked is not None: + unmasked = upsample_mask(unmasked, x.shape) + x = masked_patchify(x, unmasked=unmasked) x = self.norm(x) - x = x.permute(0, 2, 1) - if mask is not None: - out = torch.zeros(out_shape, device=x.device, dtype=x.dtype) - out[mask] = x - return out - return x.reshape(out_shape) + x = masked_unpatchify(x, out_shape=out_shape, unmasked=unmasked) + return x class MaskedMultiscaleEncoder(nn.Module): @@ -343,14 +326,14 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: :return Tensor: output tensor (BCHW) """ if mask_ratio > 0.0: - mask = generate_mask(x.shape, self.total_stride, mask_ratio) + unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) else: - mask = None + unmasked = None x = self.stem(x) features = [] for stage in self.stages: - x = stage(x, mask) + x = stage(x, unmasked=unmasked) features.append(x) - if mask is not None: - return features, mask + if unmasked is not None: + return features, unmasked return features From 83ecf4a7fcc138fcc9cab7f6b4c1ab6c5ce149a0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 00:14:58 -0800 Subject: [PATCH 12/92] hack: POC training script for FCMAE --- tests/light/test_engine.py | 10 +++ tests/unet/test_fcmae.py | 8 +++ viscy/light/engine.py | 44 +++++++++++++ viscy/scripts/train_fcmae.py | 66 ++++++++++++++++++++ viscy/unet/networks/fcmae.py | 117 +++++++++++++++++++++++++++++------ 5 files changed, 225 insertions(+), 20 deletions(-) create mode 100644 tests/light/test_engine.py create mode 100644 viscy/scripts/train_fcmae.py diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py new file mode 100644 index 00000000..c6013365 --- /dev/null +++ b/tests/light/test_engine.py @@ -0,0 +1,10 @@ +from viscy.light.engine import FcmaeUNet + + +def test_fcmae_vsunet() -> None: + model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=3), + train_mask_ratio=0.6, + ) + diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index ba0d7a24..870f1138 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -1,6 +1,7 @@ import torch from viscy.unet.networks.fcmae import ( + FullyConvolutionalMAE, MaskedAdaptiveProjection, MaskedConvNeXtV2Block, MaskedConvNeXtV2Stage, @@ -101,3 +102,10 @@ def test_masked_multiscale_encoder(): assert afeat.shape[1] == dim stride = 2 * 2 ** (i + 1) assert afeat.shape[2] == afeat.shape[3] == xy_size // stride + + +def test_fcmae(): + x = torch.rand(2, 3, 5, 128, 128) + model = FullyConvolutionalMAE(3) + assert model(x).shape == x.shape + assert model(x, mask_ratio=0.6).shape == x.shape diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 74f14aaa..0262cc7d 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -27,6 +27,7 @@ from viscy.data.hcs import Sample from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d +from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d from viscy.unet.networks.Unet21D import Unet21d from viscy.unet.networks.Unet25D import Unet25d @@ -43,6 +44,7 @@ # same class with out_stack_depth > 1 "2.2D": Unet21d, "2.5D": Unet25d, + "fcmae": FullyConvolutionalMAE, } @@ -367,3 +369,45 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + +class FcmaeUNet(VSUNet): + def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + super().__init__(**kwargs) + self.train_mask_ratio = train_mask_ratio + + def forward(self, x, mask_ratio: float = 0.0): + return self.model(x, mask_ratio) + + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss * mask).sum() / mask.sum() + self.log( + "loss/train", + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + return loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] + target = batch["target"] + pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + loss = F.mse_loss(pred, target, reduction="none") + loss = (loss.mean(2) * mask).sum() / mask.sum() + self.log("loss/validate", loss, sync_dist=True) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py new file mode 100644 index 00000000..692bef6d --- /dev/null +++ b/viscy/scripts/train_fcmae.py @@ -0,0 +1,66 @@ +# %% +from lightning.pytorch.loggers import TensorBoardLogger +from torch import set_float32_matmul_precision + +from viscy.data.hcs import HCSDataModule +from viscy.light.engine import FcmaeUNet +from viscy.light.trainer import VSTrainer +from viscy.transforms import ( + RandAdjustContrastd, + RandAffined, + RandGaussianNoised, + RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, +) + +# %% +model = FcmaeUNet( + architecture="fcmae", + model_config=dict(in_channels=1), + train_mask_ratio=0.6, +) + +# %% +ch = "reconstructed-labelfree" + +data = HCSDataModule( + data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", + source_channel=ch, + target_channel=ch, + z_window_size=5, + batch_size=64, + num_workers=12, + architecture="3D", + augmentations=[ + RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandAffined( + ch, + prob=0.5, + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.05, 0.05], + scale_range=[0.2, 0.3, 0.3], + ), + RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), + RandScaleIntensityd(ch, prob=0.3, factors=0.5), + RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), + RandGaussianSmoothd( + ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] + ), + ], +) + + +# %% +set_float32_matmul_precision("high") + +trainer = VSTrainer( + fast_dev_run=False, + max_epochs=50, + logger=TensorBoardLogger( + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + ), +) +trainer.fit(model, data) + +# %% diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index a2e6849e..ad9d9559 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -9,19 +9,36 @@ from typing import Callable, Literal, Sequence import torch -import torch.nn.functional as F -from timm.layers import ( +from timm.models.convnext import ( + Downsample, DropPath, GlobalResponseNormMlp, LayerNorm2d, create_conv2d, trunc_normal_, ) -from timm.models.convnext import Downsample from torch import BoolTensor, Size, Tensor, nn +from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead -def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: + +def _init_weights(module: nn.Module) -> None: + """Initialize weights of the given module.""" + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + + +def generate_mask( + target: Size, stride: int, mask_ratio: float, device: str +) -> BoolTensor: """ :param Size target: target shape :param int stride: total stride @@ -32,7 +49,7 @@ def generate_mask(target: Size, stride: int, mask_ratio: float) -> BoolTensor: m_width = target[-1] // stride mask_numel = m_height * m_width masked_elements = int(mask_numel * mask_ratio) - mask = torch.rand(target[0], mask_numel).argsort(1) < masked_elements + mask = torch.rand(target[0], mask_numel, device=device).argsort(1) < masked_elements return mask.reshape(target[0], 1, m_height, m_width) @@ -293,14 +310,16 @@ def __init__( stage_blocks: Sequence[int] = (3, 3, 9, 3), dims: Sequence[int] = (96, 192, 384, 768), drop_path_rate: float = 0.0, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, ) -> None: super().__init__() - stem_kernel_size_2d = 4 - self.stem = nn.Sequential( - MaskedAdaptiveProjection( - in_channels, dims[0], kernel_size_2d=stem_kernel_size_2d, kernel_depth=5 - ), - LayerNorm2d(dims[0]), + self.stem = MaskedAdaptiveProjection( + in_channels, + dims[0], + kernel_size_2d=stem_kernel_size[1:], + kernel_depth=stem_kernel_size[0], + in_stack_depth=in_stack_depth, ) self.stages = nn.ModuleList() chs = [dims[0], *dims] @@ -316,24 +335,82 @@ def __init__( drop_path_rates=[drop_path_rate] * num_blocks, ) ) - self.total_stride = stem_kernel_size_2d * 2 ** (len(self.stages) - 1) + self.total_stride = stem_kernel_size[1] * 2 ** (len(self.stages) - 1) + self.apply(_init_weights) - def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: """ - :param Tensor x: input tensor (BCHW) + :param Tensor x: input tensor (BCDHW) :param float mask_ratio: ratio of the feature maps to mask, defaults to 0.0 (no masking) - :return Tensor: output tensor (BCHW) + :return list[Tensor]: output tensors (list of BCHW) + :return BoolTensor | None: boolean foreground mask, None if no masking """ if mask_ratio > 0.0: - unmasked = ~generate_mask(x.shape, self.total_stride, mask_ratio) + mask = generate_mask( + x.shape, self.total_stride, mask_ratio, device=x.device + ) + b, c, d, h, w = x.shape + unmasked = ~mask + mask = upsample_mask(mask, (b, d, h, w)) else: - unmasked = None + mask = unmasked = None x = self.stem(x) features = [] for stage in self.stages: x = stage(x, unmasked=unmasked) features.append(x) - if unmasked is not None: - return features, unmasked - return features + return features, mask + + +class FullyConvolutionalMAE(nn.Module): + def __init__( + self, + in_channels: int, + encoder_blocks: Sequence[int] = [3, 3, 9, 3], + dims: Sequence[int] = [96, 192, 384, 768], + encoder_drop_path_rate: float = 0.0, + head_expansion_ratio: int = 4, + stem_kernel_size: Sequence[int] = (5, 4, 4), + in_stack_depth: int = 5, + ) -> None: + super().__init__() + self.encoder = MaskedMultiscaleEncoder( + in_channels=in_channels, + stage_blocks=encoder_blocks, + dims=dims, + drop_path_rate=encoder_drop_path_rate, + stem_kernel_size=stem_kernel_size, + in_stack_depth=in_stack_depth, + ) + decoder_channels = list(dims) + decoder_channels.reverse() + decoder_channels[-1] = ( + (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio + ) + self.decoder = Unet2dDecoder( + decoder_channels, + norm_name="instance", + mode="pixelshuffle", + conv_blocks=1, + strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], + upsample_pre_conv=None, + ) + if in_stack_depth == 1: + self.head = UnsqueezeHead() + else: + self.head = PixelToVoxelHead( + in_channels=decoder_channels[-1], + out_channels=in_channels, + out_stack_depth=in_stack_depth, + expansion_ratio=head_expansion_ratio, + pool=True, + ) + self.out_stack_depth = in_stack_depth + self.num_blocks = 6 + + def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: + x, mask = self.encoder(x, mask_ratio=mask_ratio) + x.reverse() + x = self.decoder(x) + return self.head(x), mask From 2fffc9928ae6499d6a4850b59f7cfdd1f6994fe5 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:25:08 -0800 Subject: [PATCH 13/92] fix mask for fitting --- tests/unet/test_fcmae.py | 8 ++++++-- viscy/light/engine.py | 24 ++++++++++++------------ viscy/scripts/train_fcmae.py | 23 +++++++++++++++++------ viscy/unet/networks/fcmae.py | 6 +++--- 4 files changed, 38 insertions(+), 23 deletions(-) diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 870f1138..36fb673e 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -107,5 +107,9 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) model = FullyConvolutionalMAE(3) - assert model(x).shape == x.shape - assert model(x, mask_ratio=0.6).shape == x.shape + y, m = model(x) + assert y.shape == x.shape + assert m is None + y, m = model(x, mask_ratio=0.6) + assert y.shape == x.shape + assert m.shape == (2, 1, 128, 128) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 0262cc7d..85254077 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -372,19 +372,23 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): - def __init__(self, train_mask_ratio: float = 0.0, **kwargs): + def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(**kwargs) - self.train_mask_ratio = train_mask_ratio + self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def training_step(self, batch: Sample, batch_idx: int): + def forward_fit(self, batch: Sample): source = batch["source"] target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) + pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) loss = F.mse_loss(pred, target, reduction="none") - loss = (loss * mask).sum() / mask.sum() + loss = (loss.mean(2) * mask).sum() / mask.sum() + return source, target, pred, mask, loss + + def training_step(self, batch: Sample, batch_idx: int): + source, target, pred, mask, loss = self.forward_fit(batch) self.log( "loss/train", loss, @@ -396,18 +400,14 @@ def training_step(self, batch: Sample, batch_idx: int): ) if batch_idx < self.log_batches_per_epoch: self.training_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) return loss def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] - target = batch["target"] - pred, mask = self.forward(source, mask_ratio=self.train_mask_ratio) - loss = F.mse_loss(pred, target, reduction="none") - loss = (loss.mean(2) * mask).sum() / mask.sum() + source, target, pred, mask, loss = self.forward_fit(batch) self.log("loss/validate", loss, sync_dist=True) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py index 692bef6d..0c098454 100644 --- a/viscy/scripts/train_fcmae.py +++ b/viscy/scripts/train_fcmae.py @@ -1,4 +1,5 @@ # %% +from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from lightning.pytorch.loggers import TensorBoardLogger from torch import set_float32_matmul_precision @@ -17,8 +18,11 @@ # %% model = FcmaeUNet( architecture="fcmae", - model_config=dict(in_channels=1), - train_mask_ratio=0.6, + model_config=dict( + in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] + ), + fit_mask_ratio=0.6, + schedule="WarmupCosine", ) # %% @@ -32,8 +36,10 @@ batch_size=64, num_workers=12, architecture="3D", + yx_patch_size=[384, 384], + normalize_source=True, augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 512, 512], num_samples=2), + RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), RandAffined( ch, prob=0.5, @@ -55,11 +61,16 @@ set_float32_matmul_precision("high") trainer = VSTrainer( - fast_dev_run=False, - max_epochs=50, + fast_dev_run=True, + precision="16-mixed", + max_epochs=100, logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_0", log_graph=False + save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False ), + callbacks=[ + LearningRateMonitor(logging_interval="step"), + ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), + ], ) trainer.fit(model, data) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index ad9d9559..7f69cf8f 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -6,7 +6,7 @@ """ -from typing import Callable, Literal, Sequence +from typing import Sequence import torch from timm.models.convnext import ( @@ -43,7 +43,7 @@ def generate_mask( :param Size target: target shape :param int stride: total stride :param float mask_ratio: ratio of the pixels to mask - :return BoolTensor: boolean mask (N, H*W) + :return BoolTensor: boolean mask (B1HW) """ m_height = target[-2] // stride m_width = target[-1] // stride @@ -352,7 +352,7 @@ def forward(self, x: Tensor, mask_ratio: float = 0.0) -> list[Tensor]: ) b, c, d, h, w = x.shape unmasked = ~mask - mask = upsample_mask(mask, (b, d, h, w)) + mask = upsample_mask(mask, (b, 1, h, w)) else: mask = unmasked = None x = self.stem(x) From 2a598b28a38acc14fa185026936b757b7695acc9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 10:29:58 -0800 Subject: [PATCH 14/92] remove training script --- viscy/scripts/train_fcmae.py | 77 ------------------------------------ 1 file changed, 77 deletions(-) delete mode 100644 viscy/scripts/train_fcmae.py diff --git a/viscy/scripts/train_fcmae.py b/viscy/scripts/train_fcmae.py deleted file mode 100644 index 0c098454..00000000 --- a/viscy/scripts/train_fcmae.py +++ /dev/null @@ -1,77 +0,0 @@ -# %% -from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint -from lightning.pytorch.loggers import TensorBoardLogger -from torch import set_float32_matmul_precision - -from viscy.data.hcs import HCSDataModule -from viscy.light.engine import FcmaeUNet -from viscy.light.trainer import VSTrainer -from viscy.transforms import ( - RandAdjustContrastd, - RandAffined, - RandGaussianNoised, - RandGaussianSmoothd, - RandScaleIntensityd, - RandWeightedCropd, -) - -# %% -model = FcmaeUNet( - architecture="fcmae", - model_config=dict( - in_channels=1, encoder_blocks=[3, 3, 27, 3], dims=[128, 256, 512, 1024] - ), - fit_mask_ratio=0.6, - schedule="WarmupCosine", -) - -# %% -ch = "reconstructed-labelfree" - -data = HCSDataModule( - data_path="/hpc/projects/comp.micro/virtual_staining/datasets/training/raw-and-reconstructed.zarr", - source_channel=ch, - target_channel=ch, - z_window_size=5, - batch_size=64, - num_workers=12, - architecture="3D", - yx_patch_size=[384, 384], - normalize_source=True, - augmentations=[ - RandWeightedCropd(ch, ch, spatial_size=[-1, 768, 768], num_samples=2), - RandAffined( - ch, - prob=0.5, - rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.05, 0.05], - scale_range=[0.2, 0.3, 0.3], - ), - RandAdjustContrastd(ch, prob=0.3, gamma=[0.75, 1.5]), - RandScaleIntensityd(ch, prob=0.3, factors=0.5), - RandGaussianNoised(ch, prob=0.5, mean=0.0, std=5.0), - RandGaussianSmoothd( - ch, prob=0.5, sigma_z=[0.25, 1.5], sigma_y=[0.25, 1.5], sigma_x=[0.25, 1.5] - ), - ], -) - - -# %% -set_float32_matmul_precision("high") - -trainer = VSTrainer( - fast_dev_run=True, - precision="16-mixed", - max_epochs=100, - logger=TensorBoardLogger( - save_dir="/hpc/mydata/ziwen.liu/fcmae", version="test_1", log_graph=False - ), - callbacks=[ - LearningRateMonitor(logging_interval="step"), - ModelCheckpoint(monitor="loss/validate", save_top_k=5, every_n_epochs=1), - ], -) -trainer.fit(model, data) - -# %% From b9b188067221c8b156627cf537c7e2496510ec67 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jan 2024 14:11:54 -0800 Subject: [PATCH 15/92] default architecture --- viscy/light/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 85254077..e1f699eb 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -373,7 +373,7 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): - super().__init__(**kwargs) + super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio def forward(self, x, mask_ratio: float = 0.0): From fd7700d0ea70339f467c0c431eca4f0c78201f5b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Mon, 22 Jan 2024 15:04:03 -0800 Subject: [PATCH 16/92] fine-tuning options --- viscy/unet/networks/fcmae.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 7f69cf8f..0799d8fb 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -367,12 +367,15 @@ class FullyConvolutionalMAE(nn.Module): def __init__( self, in_channels: int, + out_channels: int, encoder_blocks: Sequence[int] = [3, 3, 9, 3], dims: Sequence[int] = [96, 192, 384, 768], encoder_drop_path_rate: float = 0.0, head_expansion_ratio: int = 4, stem_kernel_size: Sequence[int] = (5, 4, 4), in_stack_depth: int = 5, + decoder_conv_blocks: int = 1, + pretraining: bool = True, ) -> None: super().__init__() self.encoder = MaskedMultiscaleEncoder( @@ -392,7 +395,7 @@ def __init__( decoder_channels, norm_name="instance", mode="pixelshuffle", - conv_blocks=1, + conv_blocks=decoder_conv_blocks, strides=[2] * (len(dims) - 1) + [stem_kernel_size[-1]], upsample_pre_conv=None, ) @@ -401,16 +404,20 @@ def __init__( else: self.head = PixelToVoxelHead( in_channels=decoder_channels[-1], - out_channels=in_channels, + out_channels=out_channels, out_stack_depth=in_stack_depth, expansion_ratio=head_expansion_ratio, pool=True, ) self.out_stack_depth = in_stack_depth self.num_blocks = 6 + self.pretraining = pretraining def forward(self, x: Tensor, mask_ratio: float = 0.0) -> Tensor: x, mask = self.encoder(x, mask_ratio=mask_ratio) x.reverse() x = self.decoder(x) - return self.head(x), mask + x = self.head(x) + if self.pretraining: + return x, mask + return x From 054249f14e7dac4e4040edf53d55232831ef3fe6 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:12:19 -0800 Subject: [PATCH 17/92] fix cli for finetuning --- viscy/data/hcs.py | 4 ++-- viscy/light/engine.py | 7 ++++++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 01191db1..f8bb6a22 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -334,7 +334,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] = "2.5D", + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), augmentations: Optional[list[MapTransform]] = None, caching: bool = False, @@ -348,7 +348,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["2.2D", "3D"] else True + self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e1f699eb..e6a2dfa4 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -118,11 +118,12 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D"], + architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, schedule: Literal["WarmupCosine", "Constant"] = "Constant", + freeze_encoder: bool = False, log_batches_per_epoch: int = 8, log_samples_per_batch: int = 1, example_input_yx_shape: Sequence[int] = (256, 256), @@ -162,6 +163,7 @@ def __init__( self.test_cellpose_model_path = test_cellpose_model_path self.test_cellpose_diameter = test_cellpose_diameter self.test_evaluate_cellpose = test_evaluate_cellpose + self.freeze_encoder = freeze_encoder def forward(self, x) -> torch.Tensor: return self.model(x) @@ -331,6 +333,9 @@ def on_predict_start(self): self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def configure_optimizers(self): + if self.freeze_encoder: + self.model: FullyConvolutionalMAE + self.model.encoder.requires_grad_(False) optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr) if self.schedule == "WarmupCosine": scheduler = WarmupCosineSchedule( From d867e101b3e006ed9dc819722280b2bae8ea5560 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 24 Jan 2024 14:56:10 -0800 Subject: [PATCH 18/92] draft combined data module --- viscy/data/combined.py | 62 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 viscy/data/combined.py diff --git a/viscy/data/combined.py b/viscy/data/combined.py new file mode 100644 index 00000000..6b8dd63c --- /dev/null +++ b/viscy/data/combined.py @@ -0,0 +1,62 @@ +from typing import Literal, Sequence + +from lightning.pytorch import LightningDataModule +from lightning.pytorch.utilities import combined_loader + +_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] + + +class CombinedDataModule(LightningDataModule): + """Wrapper for combining multiple data modules. + For supported modes, see ``lightning.pytorch.utilities.combined_loader``. + + :param Sequence[LightningDataModule] data_modules: data modules to combine + :param str train_mode: mode in training stage, defaults to "max_size_cycle" + :param str val_mode: mode in validation stage, defaults to "sequential" + :param str test_mode: mode in testing stage, defaults to "sequential" + :param str predict_mode: mode in prediction stage, defaults to "sequential" + """ + + def __init__( + self, + data_modules: Sequence[LightningDataModule], + train_mode: _MODES = "max_size_cycle", + val_mode: _MODES = "sequential", + test_mode: _MODES = "sequential", + predict_mode: _MODES = "sequential", + ): + super().__init__() + self.data_modules = data_modules + self.train_mode = train_mode + self.val_mode = val_mode + self.test_mode = test_mode + self.predict_mode = predict_mode + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + for dm in self.data_modules: + dm.setup(stage) + + def train_dataloader(self): + return combined_loader( + [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode + ) + + def val_dataloader(self): + return combined_loader( + [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode + ) + + def test_dataloader(self): + return combined_loader( + [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode + ) + + def predict_dataloader(self): + return combined_loader( + [dm.predict_dataloader() for dm in self.data_modules], + mode=self.predict_mode, + ) From b06a30077c83402b71390de160f1ce404ca98240 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 25 Jan 2024 15:52:42 -0800 Subject: [PATCH 19/92] fix import --- viscy/data/combined.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 6b8dd63c..5da700dd 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,7 +1,7 @@ from typing import Literal, Sequence from lightning.pytorch import LightningDataModule -from lightning.pytorch.utilities import combined_loader +from lightning.pytorch.utilities.combined_loader import CombinedLoader _MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] @@ -41,22 +41,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): dm.setup(stage) def train_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.train_dataloader() for dm in self.data_modules], mode=self.train_mode ) def val_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.val_dataloader() for dm in self.data_modules], mode=self.val_mode ) def test_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.test_dataloader() for dm in self.data_modules], mode=self.test_mode ) def predict_dataloader(self): - return combined_loader( + return CombinedLoader( [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, ) From 39eafab77f97a046a20c5bc4944bf9f24dc11ca1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 26 Jan 2024 21:35:49 -0800 Subject: [PATCH 20/92] manual validation loss reduction --- viscy/light/engine.py | 48 ++++++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 17 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index e6a2dfa4..ebd2fd60 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -10,7 +10,7 @@ from monai.optimizers import WarmupCosineSchedule from monai.transforms import DivisiblePad from skimage.exposure import rescale_intensity -from torch import nn +from torch import Tensor, nn from torch.nn import functional as F from torch.optim.lr_scheduler import ConstantLR from torchmetrics.functional import ( @@ -165,7 +165,7 @@ def __init__( self.test_evaluate_cellpose = test_evaluate_cellpose self.freeze_encoder = freeze_encoder - def forward(self, x) -> torch.Tensor: + def forward(self, x: Tensor) -> Tensor: return self.model(x) def training_step(self, batch: Sample, batch_idx: int): @@ -230,7 +230,7 @@ def test_step(self, batch: Sample, batch_idx: int): else: self._log_segmentation_metrics(None, None) - def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): + def _log_regression_metrics(self, pred: Tensor, target: Tensor): # paired image translation metrics self.log_dict( { @@ -253,7 +253,7 @@ def _log_regression_metrics(self, pred: torch.Tensor, target: torch.Tensor): on_epoch=True, ) - def _cellpose_predict(self, pred: torch.Tensor, name: str) -> torch.ShortTensor: + def _cellpose_predict(self, pred: Tensor, name: str) -> torch.ShortTensor: pred_labels_np = self.cellpose_model.eval( pred.cpu().numpy(), channels=[0, 0], diameter=self.test_cellpose_diameter )[0].astype(np.int16) @@ -350,7 +350,7 @@ def configure_optimizers(self): ) return [optimizer], [scheduler] - def _detach_sample(self, imgs: Sequence[torch.Tensor]): + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = min(imgs[0].shape[0], self.log_samples_per_batch) return [ [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] @@ -380,11 +380,12 @@ class FcmaeUNet(VSUNet): def __init__(self, fit_mask_ratio: float = 0.0, **kwargs): super().__init__(architecture="fcmae", **kwargs) self.fit_mask_ratio = fit_mask_ratio + self.validation_losses = [] - def forward(self, x, mask_ratio: float = 0.0): + def forward(self, x: Tensor, mask_ratio: float = 0.0): return self.model(x, mask_ratio) - def forward_fit(self, batch: Sample): + def forward_fit(self, batch: Sample) -> tuple[Tensor]: source = batch["source"] target = batch["target"] pred, mask = self.forward(source, mask_ratio=self.fit_mask_ratio) @@ -392,27 +393,40 @@ def forward_fit(self, batch: Sample): loss = (loss.mean(2) * mask).sum() / mask.sum() return source, target, pred, mask, loss - def training_step(self, batch: Sample, batch_idx: int): - source, target, pred, mask, loss = self.forward_fit(batch) + def training_step(self, batch: Sequence[Sample], batch_idx: int): + losses = [] + batch_size = 0 + for b in batch: + source, target, pred, mask, loss = self.forward_fit(b) + losses.append(loss) + batch_size += source.shape[0] + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target * mask.unsqueeze(2), pred)) + ) + loss_step = torch.stack(losses).mean() self.log( "loss/train", - loss, + loss_step, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, + batch_size=batch_size, ) - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target * mask.unsqueeze(2), pred)) - ) - return loss + return loss_step - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.log("loss/validate", loss, sync_dist=True) + self.validation_losses.append(loss.detach()) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) ) + + def on_validation_epoch_end(self): + super().on_validation_epoch_end() + self.log( + "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True + ) From 9fbf7a551e0613e0173d7de05ba6f9dfd911d709 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 09:55:29 -0800 Subject: [PATCH 21/92] update linting new black version has different rules --- pyproject.toml | 17 +++++++++++------ viscy/evaluation/evaluation_metrics.py | 1 + viscy/light/engine.py | 26 +++++++++++++------------- viscy/preprocessing/generate_masks.py | 1 + viscy/unet/networks/fcmae.py | 1 - viscy/utils/image_utils.py | 4 +--- viscy/utils/normalize.py | 1 + 7 files changed, 28 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b60cd534..67142b4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,15 @@ metrics = [ "ptflops>=0.7", ] visual = ["ipykernel", "graphviz", "torchview"] -dev = ["pytest", "pytest-cov", "hypothesis", "profilehooks", "onnxruntime"] +dev = [ + "pytest", + "pytest-cov", + "hypothesis", + "ruff", + "black", + "profilehooks", + "onnxruntime", +] [project.scripts] viscy = "viscy.cli.cli:main" @@ -39,12 +47,9 @@ viscy = "viscy.cli.cli:main" write_to = "viscy/_version.py" [tool.black] -src = ["viscy"] line-length = 88 [tool.ruff] src = ["viscy", "tests"] -extend-select = ["I001"] - -[tool.ruff.isort] -known-first-party = ["viscy"] +lint.extend-select = ["I001"] +lint.isort.known-first-party = ["viscy"] diff --git a/viscy/evaluation/evaluation_metrics.py b/viscy/evaluation/evaluation_metrics.py index 589370bd..fb83c06b 100644 --- a/viscy/evaluation/evaluation_metrics.py +++ b/viscy/evaluation/evaluation_metrics.py @@ -1,4 +1,5 @@ """Metrics for model evaluation""" + from typing import Sequence, Union from warnings import warn diff --git a/viscy/light/engine.py b/viscy/light/engine.py index ebd2fd60..c6197a9a 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -272,19 +272,19 @@ def _log_segmentation_metrics( self.log_dict( { # semantic segmentation - "test_metrics/accuracy": accuracy( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, - "test_metrics/dice": dice(pred_binary, target_binary) - if compute - else -1, - "test_metrics/jaccard": jaccard_index( - pred_binary, target_binary, task="binary" - ) - if compute - else -1, + "test_metrics/accuracy": ( + accuracy(pred_binary, target_binary, task="binary") + if compute + else -1 + ), + "test_metrics/dice": ( + dice(pred_binary, target_binary) if compute else -1 + ), + "test_metrics/jaccard": ( + jaccard_index(pred_binary, target_binary, task="binary") + if compute + else -1 + ), "test_metrics/mAP": coco_metrics["map"] if compute else -1, "test_metrics/mAP_50": coco_metrics["map_50"] if compute else -1, "test_metrics/mAP_75": coco_metrics["map_75"] if compute else -1, diff --git a/viscy/preprocessing/generate_masks.py b/viscy/preprocessing/generate_masks.py index f88f8fbe..491bc406 100644 --- a/viscy/preprocessing/generate_masks.py +++ b/viscy/preprocessing/generate_masks.py @@ -1,4 +1,5 @@ """Generate masks from sum of flurophore channels""" + import iohub.ngff as ngff import viscy.utils.aux_utils as aux_utils diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 0799d8fb..97771365 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -5,7 +5,6 @@ and timm's dense implementation of the encoder in ``timm.models.convnext`` """ - from typing import Sequence import torch diff --git a/viscy/utils/image_utils.py b/viscy/utils/image_utils.py index f9020dc9..a9569116 100644 --- a/viscy/utils/image_utils.py +++ b/viscy/utils/image_utils.py @@ -21,9 +21,7 @@ def im_bit_convert(im, bit=16, norm=False, limit=[]): / (limit[1] - limit[0] + sys.float_info.epsilon) * (2**bit - 1) ) - im = np.clip( - im, 0, 2**bit - 1 - ) # clip the values to avoid wrap-around by np.astype + im = np.clip(im, 0, 2**bit - 1) # clip the values to avoid wrap-around by np.astype if bit == 8: im = im.astype(np.uint8, copy=False) # convert to 8 bit else: diff --git a/viscy/utils/normalize.py b/viscy/utils/normalize.py index 93c11713..73753acb 100644 --- a/viscy/utils/normalize.py +++ b/viscy/utils/normalize.py @@ -1,4 +1,5 @@ """Image normalization related functions""" + import sys import numpy as np From e00f5f3bd0415c1c3b8db924bd600cd2354e4cb7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 2 Feb 2024 10:01:36 -0800 Subject: [PATCH 22/92] update development guide --- CONTRIBUTING.md | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3b40b075..44db5bbc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,19 @@ then make an editable installation with all the optional dependencies: pip install -e ".[dev,visual,metrics]" ``` -## Testing +## CI requirements + +Lint with Ruff: + +```sh +ruff check viscy +``` + +Format the code with Black: + +```sh +black viscy +``` Run tests with `pytest`: From 9e345b6c3b59a70a3b7c0bcde8bce184e46c3833 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 13 Feb 2024 15:27:26 -0800 Subject: [PATCH 23/92] update type hints --- viscy/data/hcs.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index f8bb6a22..218ea414 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -23,10 +23,11 @@ MultiSampleTrait, RandAffined, ) +from torch import Tensor from torch.utils.data import DataLoader, Dataset -def _ensure_channel_list(str_or_seq: Union[str, Sequence[str]]): +def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ Ensure channel argument is a list of strings. @@ -67,9 +68,9 @@ class Sample(TypedDict, total=False): index: tuple[str, int, int] # optional - source: Union[torch.Tensor, Sequence[torch.Tensor]] - target: Union[torch.Tensor, Sequence[torch.Tensor]] - labels: Union[torch.Tensor, Sequence[torch.Tensor]] + source: Union[Tensor, Sequence[Tensor]] + target: Union[Tensor, Sequence[Tensor]] + labels: Union[Tensor, Sequence[Tensor]] def _collate_samples(batch: Sequence[Sample]) -> Sample: @@ -83,7 +84,7 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: elemment = batch[0] collated = {} for key in elemment.keys(): - data: list[list[torch.Tensor]] = [sample[key] for sample in batch] + data: list[list[Tensor]] = [sample[key] for sample in batch] collated[key] = collate_meta_tensor([im for imgs in data for im in imgs]) return collated @@ -108,13 +109,13 @@ def _stat(self, key: str) -> dict: # FIXME: hard-coded key return self.norm_meta[key]["dataset_statistics"] - def __call__(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] return d - def inverse(self, data: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: d = dict(data) for key in self.keys: d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] @@ -128,7 +129,7 @@ class SlidingWindowDataset(Dataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -137,7 +138,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ) -> None: super().__init__() self.positions = positions @@ -178,14 +179,14 @@ def _find_window(self, index: int) -> tuple[int, int]: def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int - ) -> tuple[tuple[torch.Tensor], tuple[str, int, int]]: + ) -> tuple[tuple[Tensor], tuple[str, int, int]]: """Read image window as tensor. :param ImageArray img: NGFF image array :param list[int] channels: list of channel indices to read, output channel ordering will reflect the sequence :param int tz: window index within the FOV, counted Z-first - :return tuple[torch.Tensor], tuple[str, int, int]: + :return tuple[Tensor], tuple[str, int, int]: tuple of (C=1, Z, Y, X) image tensors, tuple of image name, time index, and Z index """ @@ -203,8 +204,8 @@ def __len__(self) -> int: return self._max_window def _stack_channels( - self, sample_images: list[dict[str, torch.Tensor]], key: str - ) -> torch.Tensor: + self, sample_images: list[dict[str, Tensor]], key: str + ) -> Tensor: """Stack single-channel images into a multi-channel tensor.""" if not isinstance(sample_images, list): return torch.stack([sample_images[ch][0] for ch in self.channels[key]]) @@ -258,7 +259,7 @@ class MaskTestDataset(SlidingWindowDataset): :param ChannelMap channels: source and target channel names, e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}`` :param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D - :param Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] transform: + :param Callable[[dict[str, Tensor]], dict[str, Tensor]] transform: a callable that transforms data, defaults to None """ @@ -267,7 +268,7 @@ def __init__( positions: list[Position], channels: ChannelMap, z_window_size: int, - transform: Callable[[dict[str, torch.Tensor]], dict[str, torch.Tensor]] = None, + transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] = None, ground_truth_masks: str = None, ) -> None: super().__init__(positions, channels, z_window_size, transform) @@ -527,7 +528,7 @@ def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample if self.trainer: if self.trainer.predicting: predicting = True - if predicting or isinstance(batch, torch.Tensor): + if predicting or isinstance(batch, Tensor): # skipping example input array return batch if self.target_2d: From 96deca5f0020fb9a99388f37d0640e656253145e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 20 Feb 2024 14:44:02 -0800 Subject: [PATCH 24/92] bump iohub --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 67142b4f..8f6978de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ requires-python = ">=3.10" license = { file = "LICENSE" } authors = [{ name = "CZ Biohub SF", email = "compmicro@czbiohub.org" }] dependencies = [ - "iohub==0.1.0rc0", + "iohub==0.1.0", "torch>=2.1.2", "timm>=0.9.5", "tensorboard>=2.13.0", From e06aa574634dd504755dd21e40346071ea7a6b00 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 21:42:29 -0800 Subject: [PATCH 25/92] draft ctmc v1 dataset --- viscy/data/ctmc_v1.py | 67 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 viscy/data/ctmc_v1.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py new file mode 100644 index 00000000..8c42f85d --- /dev/null +++ b/viscy/data/ctmc_v1.py @@ -0,0 +1,67 @@ +import logging +from pathlib import Path + +import numpy as np +from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, MapTransform +from torch import Tensor +from torch.utils.data import DataLoader + +from viscy.data.hcs import ChannelMap, SlidingWindowDataset + + +class CTMCv1DataModule(LightningDataModule): + """ + Autoregression data module for the CTMCv1 dataset. + Training and validation datasets are stored in separate HCS OME-Zarr stores. + """ + + def __init__( + self, + train_data_path: str | Path, + val_data_path: str | Path, + train_transforms: list[MapTransform], + val_transforms: list[MapTransform], + batch_size: int = 16, + num_workers: int = 8, + channel_name: str = "DIC", + ) -> None: + super().__init__() + self.train_data_path = Path(train_data_path) + self.val_data_path = Path(val_data_path) + self.train_transforms = train_transforms + self.val_transforms = val_transforms + self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + train_plate = open_ome_zarr(self.train_data_path, mode="r") + val_plate = open_ome_zarr(self.val_data_path, mode="r") + train_positions = [p for _, p in train_plate.positions()] + val_positions = [p for _, p in val_plate.positions()] + self.train_dataset = SlidingWindowDataset( + train_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.train_transform), + ) + self.val_dataset = SlidingWindowDataset( + val_positions, + channels=self.channel_map, + z_window_size=1, + transform=Compose(self.val_transform), + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers + ) From 72de113f8c5a678f4383d374da925517d1cace6b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 22:33:41 -0800 Subject: [PATCH 26/92] update tests --- tests/light/test_engine.py | 5 +---- tests/unet/test_fcmae.py | 14 +++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/tests/light/test_engine.py b/tests/light/test_engine.py index c6013365..9ce182f5 100644 --- a/tests/light/test_engine.py +++ b/tests/light/test_engine.py @@ -3,8 +3,5 @@ def test_fcmae_vsunet() -> None: model = FcmaeUNet( - architecture="fcmae", - model_config=dict(in_channels=3), - train_mask_ratio=0.6, + model_config=dict(in_channels=3, out_channels=1), fit_mask_ratio=0.6 ) - diff --git a/tests/unet/test_fcmae.py b/tests/unet/test_fcmae.py index 36fb673e..4ed441b4 100644 --- a/tests/unet/test_fcmae.py +++ b/tests/unet/test_fcmae.py @@ -17,7 +17,7 @@ def test_generate_mask(): w = 64 s = 16 m = 0.75 - mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m) + mask = generate_mask((2, 3, w, w), stride=s, mask_ratio=m, device="cpu") assert mask.shape == (2, 1, w // s, w // s) assert mask.dtype == torch.bool ratio = mask.sum((2, 3)) / mask.numel() * mask.shape[0] @@ -28,7 +28,7 @@ def test_masked_patchify(): b, c, h, w = 2, 3, 4, 8 x = torch.rand(b, c, h, w) mask_ratio = 0.75 - mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio) + mask = generate_mask(x.shape, stride=2, mask_ratio=mask_ratio, device=x.device) mask = upsample_mask(mask, x.shape) feat = masked_patchify(x, ~mask) assert feat.shape == (b, int(h * w * (1 - mask_ratio)), c) @@ -42,7 +42,7 @@ def test_unmasked_patchify_roundtrip(): def test_masked_patchify_roundtrip(): x = torch.rand(2, 3, 4, 8) - mask = generate_mask(x.shape, stride=2, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=2, mask_ratio=0.5, device=x.device) mask = upsample_mask(mask, x.shape) y = masked_unpatchify(masked_patchify(x, ~mask), out_shape=x.shape, unmasked=~mask) assert torch.all((y == 0) ^ (x == y)) @@ -51,7 +51,7 @@ def test_masked_patchify_roundtrip(): def test_masked_convnextv2_block() -> None: x = torch.rand(2, 3, 4, 5) - mask = generate_mask(x.shape, stride=1, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=1, mask_ratio=0.5, device=x.device) block = MaskedConvNeXtV2Block(3, 3 * 2) unmasked_out = block(x) assert len(unmasked_out.unique()) == x.numel() * 2 @@ -65,7 +65,7 @@ def test_masked_convnextv2_block() -> None: def test_masked_convnextv2_stage(): x = torch.rand(2, 3, 16, 16) - mask = generate_mask(x.shape, stride=4, mask_ratio=0.5) + mask = generate_mask(x.shape, stride=4, mask_ratio=0.5, device=x.device) stage = MaskedConvNeXtV2Stage(3, 3, kernel_size=7, stride=2, num_blocks=2) out = stage(x) assert out.shape == (2, 3, 8, 8) @@ -79,7 +79,7 @@ def test_adaptive_projection(): ) assert proj(torch.rand(2, 3, 5, 8, 8)).shape == (2, 12, 2, 2) assert proj(torch.rand(2, 3, 1, 12, 16)).shape == (2, 12, 3, 4) - mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6) + mask = generate_mask((1, 3, 5, 8, 8), stride=4, mask_ratio=0.6, device="cpu") masked_out = proj(torch.rand(1, 3, 5, 16, 16), mask) assert masked_out.shape == (1, 12, 4, 4) proj = MaskedAdaptiveProjection( @@ -106,7 +106,7 @@ def test_masked_multiscale_encoder(): def test_fcmae(): x = torch.rand(2, 3, 5, 128, 128) - model = FullyConvolutionalMAE(3) + model = FullyConvolutionalMAE(3, 3) y, m = model(x) assert y.shape == x.shape assert m is None From 13d0aa0574665d0da4f17407033fc964da00e602 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 23:47:56 -0800 Subject: [PATCH 27/92] move test_data --- tests/data/__init__.py | 0 tests/{light => data}/test_data.py | 0 viscy/data/ctmc_v1.py | 24 ++++++++++++++++-------- 3 files changed, 16 insertions(+), 8 deletions(-) create mode 100644 tests/data/__init__.py rename tests/{light => data}/test_data.py (100%) diff --git a/tests/data/__init__.py b/tests/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/light/test_data.py b/tests/data/test_data.py similarity index 100% rename from tests/light/test_data.py rename to tests/data/test_data.py diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 8c42f85d..df1d3223 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -1,11 +1,8 @@ -import logging from pathlib import Path -import numpy as np -from iohub.ngff import ImageArray, Plate, Position, TransformationMeta, open_ome_zarr +from iohub.ngff import open_ome_zarr from lightning.pytorch import LightningDataModule from monai.transforms import Compose, MapTransform -from torch import Tensor from torch.utils.data import DataLoader from viscy.data.hcs import ChannelMap, SlidingWindowDataset @@ -39,8 +36,11 @@ def __init__( def setup(self, stage: str) -> None: if stage != "fit": raise NotImplementedError("Only fit stage is supported") - train_plate = open_ome_zarr(self.train_data_path, mode="r") - val_plate = open_ome_zarr(self.val_data_path, mode="r") + self._setup_fit() + + def _setup_fit(self) -> None: + train_plate = open_ome_zarr(self.train_data_path) + val_plate = open_ome_zarr(self.val_data_path) train_positions = [p for _, p in train_plate.positions()] val_positions = [p for _, p in val_plate.positions()] self.train_dataset = SlidingWindowDataset( @@ -58,10 +58,18 @@ def setup(self, stage: str) -> None: def train_dataloader(self) -> DataLoader: return DataLoader( - self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, ) def val_dataloader(self) -> DataLoader: return DataLoader( - self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=False, ) From 78aed971aa2bea34e89c0db20bde883ebb98a06e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 23 Feb 2024 23:53:15 -0800 Subject: [PATCH 28/92] remove path conversion --- viscy/data/ctmc_v1.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index df1d3223..0d65a36a 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -25,8 +25,8 @@ def __init__( channel_name: str = "DIC", ) -> None: super().__init__() - self.train_data_path = Path(train_data_path) - self.val_data_path = Path(val_data_path) + self.train_data_path = train_data_path + self.val_data_path = val_data_path self.train_transforms = train_transforms self.val_transforms = val_transforms self.channel_map = ChannelMap(source=channel_name, target=channel_name) From 74e7db3633aed6d04c0995aee6f1db70abb51045 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 26 Feb 2024 09:31:49 -0800 Subject: [PATCH 29/92] configurable normalizations (#68) * inital commit adding the normalization. * adding dataset_statistics to each fov to facilitate the configurable augmentations * fix indentation * ruff * test preprocessing * remove redundant field * cleanup --------- Co-authored-by: Ziwen Liu --- examples/configs/fit_example.yml | 13 +++ tests/conftest.py | 2 + tests/data/test_data.py | 33 ++---- viscy/data/hcs.py | 146 +++++++-------------------- viscy/data/typing.py | 22 ++++ viscy/preprocessing/preprocessing.md | 16 ++- viscy/transforms.py | 39 +++++++ 7 files changed, 139 insertions(+), 132 deletions(-) create mode 100644 viscy/data/typing.py diff --git a/examples/configs/fit_example.yml b/examples/configs/fit_example.yml index 017c57f0..fd17071e 100644 --- a/examples/configs/fit_example.yml +++ b/examples/configs/fit_example.yml @@ -37,6 +37,19 @@ data: batch_size: 32 num_workers: 16 yx_patch_size: [256, 256] + normalizations: + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [source] + level: 'fov_statistics', + subtrahend: 'mean' + divisor: 'std' + - class_path: viscy.transforms.NormalizeSampled + init_args: + keys: [target_1] + level: 'fov_statistics', + subtrahend: 'median' + divisor: 'iqr' augmentations: - class_path: viscy.transforms.RandWeightedCropd init_args: diff --git a/tests/conftest.py b/tests/conftest.py index 9ad6630c..198e51ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,8 @@ def preprocessed_hcs_dataset(tmp_path_factory: TempPathFactory) -> Path: norm_meta = {channel: {"dataset_statistics": expected} for channel in channel_names} with open_ome_zarr(dataset_path, mode="r+") as dataset: dataset.zattrs["normalization"] = norm_meta + for _, fov in dataset.positions(): + fov.zattrs["normalization"] = norm_meta return dataset_path diff --git a/tests/data/test_data.py b/tests/data/test_data.py index 153f175f..fb3d8620 100644 --- a/tests/data/test_data.py +++ b/tests/data/test_data.py @@ -18,6 +18,16 @@ def test_preprocess(small_hcs_dataset: Path, default_channels: bool): channel_names = dataset.channel_names trainer = VSTrainer(accelerator="cpu") trainer.preprocess(data_path, channel_names=channel_names, num_workers=2) + with open_ome_zarr(data_path) as dataset: + channel_names = dataset.channel_names + for channel in channel_names: + assert "dataset_statistics" in dataset.zattrs["normalization"][channel] + for _, fov in dataset.positions(): + norm_metadata = fov.zattrs["normalization"] + for channel in channel_names: + assert channel in norm_metadata + assert "dataset_statistics" in norm_metadata[channel] + assert "fov_statistics" in norm_metadata[channel] def test_datamodule_setup_predict(preprocessed_hcs_dataset): @@ -45,26 +55,3 @@ def test_datamodule_setup_predict(preprocessed_hcs_dataset): img.height, img.width, ) - - -def test_datamodule_predict_scales(preprocessed_hcs_dataset): - data_path = preprocessed_hcs_dataset - with open_ome_zarr(data_path) as dataset: - channel_names = dataset.channel_names - - def get_normalized_stack(predict_scale_source): - factor = 1 if predict_scale_source is None else predict_scale_source - dm = HCSDataModule( - data_path=data_path, - source_channel=channel_names[:2], - target_channel=channel_names[2:], - z_window_size=5, - batch_size=2, - num_workers=0, - predict_scale_source=predict_scale_source, - normalize_source=True, - ) - dm.setup(stage="predict") - return dm.predict_dataset[0]["source"] / factor - - assert torch.allclose(get_normalized_stack(None), get_normalized_stack(2)) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 218ea414..bb0be09c 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -5,7 +5,7 @@ import tempfile from glob import glob from pathlib import Path -from typing import Callable, Iterable, Literal, Optional, Sequence, TypedDict, Union +from typing import Callable, Literal, Optional, Sequence, Union import numpy as np import torch @@ -18,7 +18,6 @@ from monai.transforms import ( CenterSpatialCropd, Compose, - InvertibleTransform, MapTransform, MultiSampleTrait, RandAffined, @@ -26,6 +25,8 @@ from torch import Tensor from torch.utils.data import DataLoader, Dataset +from viscy.data.typing import ChannelMap, Sample + def _ensure_channel_list(str_or_seq: str | Sequence[str]) -> list[str]: """ @@ -55,24 +56,6 @@ def _search_int_in_str(pattern: str, file_name: str) -> str: raise ValueError(f"Cannot find pattern {pattern} in {file_name}.") -class ChannelMap(TypedDict, total=False): - """Source and target channel names.""" - - source: Union[str, Sequence[str]] - # optional - target: Union[str, Sequence[str]] - - -class Sample(TypedDict, total=False): - """Image sample type for mini-batches.""" - - index: tuple[str, int, int] - # optional - source: Union[Tensor, Sequence[Tensor]] - target: Union[Tensor, Sequence[Tensor]] - labels: Union[Tensor, Sequence[Tensor]] - - def _collate_samples(batch: Sequence[Sample]) -> Sample: """Collate samples into a batch sample. @@ -89,38 +72,6 @@ def _collate_samples(batch: Sequence[Sample]) -> Sample: return collated -class NormalizeSampled(MapTransform, InvertibleTransform): - """Dictionary transform to only normalize target (fluorescence) channel. - - :param Union[str, Iterable[str]] keys: keys to normalize - :param dict[str, dict] norm_meta: Plate normalization metadata - written in preprocessing - """ - - def __init__( - self, keys: Union[str, Iterable[str]], norm_meta: dict[str, dict] - ) -> None: - if set(keys) > set(norm_meta.keys()): - raise KeyError(f"{keys} is not a subset of {norm_meta.keys()}") - super().__init__(keys, allow_missing_keys=False) - self.norm_meta = norm_meta - - def _stat(self, key: str) -> dict: - # FIXME: hard-coded key - return self.norm_meta[key]["dataset_statistics"] - - def __call__(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] - self._stat(key)["median"]) / self._stat(key)["iqr"] - return d - - def inverse(self, data: dict[str, Tensor]) -> dict[str, Tensor]: - d = dict(data) - for key in self.keys: - d[key] = (d[key] * self._stat(key)["iqr"]) + self._stat(key)["median"] - - class SlidingWindowDataset(Dataset): """Torch dataset where each element is a window of (C, Z, Y, X) where C=2 (source and target) and Z is ``z_window_size``. @@ -161,6 +112,7 @@ def _get_windows(self) -> None: w = 0 self.window_keys = [] self.window_arrays = [] + self.window_norm_meta = [] for fov in self.positions: img_arr = fov["0"] ts = img_arr.frames @@ -168,6 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) + self.window_norm_meta.append(fov.zattrs["normalization"]) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: @@ -175,7 +128,8 @@ def _find_window(self, index: int) -> tuple[int, int]: window_idx = sorted(self.window_keys + [index + 1]).index(index + 1) w = self.window_keys[window_idx] tz = index - self.window_keys[window_idx - 1] if window_idx > 0 else index - return self.window_arrays[self.window_keys.index(w)], tz + norm_meta = self.window_norm_meta[self.window_keys.index(w)] + return (self.window_arrays[self.window_keys.index(w)], tz, norm_meta) def _read_img_window( self, img: ImageArray, ch_idx: list[str], tz: int @@ -216,7 +170,7 @@ def _stack_channels( ] def __getitem__(self, index: int) -> Sample: - img, tz = self._find_window(index) + img, tz, norm_meta = self._find_window(index) ch_names = self.channels["source"].copy() ch_idx = self.source_ch_idx.copy() if self.target_ch_idx is not None: @@ -229,6 +183,7 @@ def __getitem__(self, index: int) -> Sample: # since adding a reference to a tensor does not copy # maybe write a weight map in preprocessing to use more information? sample_images["weight"] = sample_images[self.channels["target"][0]] + sample_images["norm_meta"] = norm_meta if self.transform: sample_images = self.transform(sample_images) # if isinstance(sample_images, list): @@ -238,6 +193,7 @@ def __getitem__(self, index: int) -> Sample: sample = { "index": sample_index, "source": self._stack_channels(sample_images, "source"), + "norm_meta": norm_meta, } if self.target_ch_idx is not None: sample["target"] = self._stack_channels(sample_images, "target") @@ -312,18 +268,16 @@ class HCSDataModule(LightningDataModule): defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) + :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms + applied to selected channels, defaults to None (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms applied to the training set, defaults to None (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False - :param bool normalize_source: whether to normalize the source channel, - defaults to False :param Optional[Path] ground_truth_masks: path to the ground truth masks, used in the test stage to compute segmentation metrics, defaults to None - :param Optional[float] predict_scale_source: scale the source channel intensity, - defaults to None (no scaling) """ def __init__( @@ -337,11 +291,10 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), + normalizations: Optional[list[MapTransform]] = None, augmentations: Optional[list[MapTransform]] = None, caching: bool = False, - normalize_source: bool = False, ground_truth_masks: Optional[Path] = None, - predict_scale_source: Optional[float] = None, ): super().__init__() self.data_path = Path(data_path) @@ -353,21 +306,11 @@ def __init__( self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size + self.normalizations = normalizations self.augmentations = augmentations self.caching = caching - self.normalize_source = normalize_source self.ground_truth_masks = ground_truth_masks self.tmp_zarr = None - if predict_scale_source is not None: - if not normalize_source: - raise ValueError( - "Intensity scaling must be applied to normalized source channels." - ) - if predict_scale_source <= 0: - raise ValueError( - f"Intensity scaling {predict_scale_source} should be positive." - ) - self.predict_scale_source = predict_scale_source def prepare_data(self): if not self.caching: @@ -419,31 +362,22 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): else: raise NotImplementedError(f"{stage} stage") - def _setup_eval(self, dataset_settings: dict) -> tuple[Plate, MapTransform]: - """Setup stages where the target is available (evaluating performance).""" - dataset_settings["channels"]["target"] = self.target_channel - data_path = self.tmp_zarr if self.tmp_zarr else self.data_path - plate = open_ome_zarr(data_path, mode="r") - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # define training stage transforms - norm_keys = self.target_channel.copy() - if self.normalize_source: - norm_keys += self.source_channel - normalize_transform = NormalizeSampled( - norm_keys, - plate.zattrs["normalization"], - ) - return plate, normalize_transform - def _setup_fit(self, dataset_settings: dict): """Set up the training and validation datasets.""" - plate, normalize_transform = self._setup_eval(dataset_settings) + # Setup the transformations + # TODO: These have a fixed order for now... (normalization->augmentation->fit_transform) fit_transform = self._fit_transform() train_transform = Compose( - [normalize_transform] + self._train_transform() + fit_transform + self.normalizations + self._train_transform() + fit_transform ) - val_transform = Compose([normalize_transform] + fit_transform) + val_transform = Compose(self.normalizations + fit_transform) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + plate = open_ome_zarr(data_path, mode="r") + + # disable metadata tracking in MONAI for performance + set_track_meta(False) # shuffle positions, randomness is handled globally positions = [pos for _, pos in plate.positions()] shuffled_indices = torch.randperm(len(positions)) @@ -465,26 +399,31 @@ def _setup_fit(self, dataset_settings: dict): **train_dataset_settings, ) self.val_dataset = SlidingWindowDataset( - positions[num_train_fovs:], transform=val_transform, **dataset_settings + positions[num_train_fovs:], + transform=val_transform, + **dataset_settings, ) def _setup_test(self, dataset_settings: dict): """Set up the test stage.""" if self.batch_size != 1: logging.warning(f"Ignoring batch size {self.batch_size} in test stage.") - plate, normalize_transform = self._setup_eval(dataset_settings) + + dataset_settings["channels"]["target"] = self.target_channel + data_path = self.tmp_zarr if self.tmp_zarr else self.data_path + plate = open_ome_zarr(data_path, mode="r") if self.ground_truth_masks: self.test_dataset = MaskTestDataset( [p for _, p in plate.positions()], - transform=normalize_transform, + transform=self.normalizations, ground_truth_masks=self.ground_truth_masks, - **dataset_settings, + norm_meta=plate.zattrs["normalization"] ** dataset_settings, ) else: self.test_dataset = SlidingWindowDataset( [p for _, p in plate.positions()], - transform=normalize_transform, - **dataset_settings, + transform=self.normalizations, + norm_meta=plate.zattrs["normalization"] ** dataset_settings, ) def _setup_predict(self, dataset_settings: dict): @@ -506,16 +445,9 @@ def _setup_predict(self, dataset_settings: dict): positions = [plate[fov_name]] elif isinstance(dataset, Plate): positions = [p for _, p in dataset.positions()] - norm_meta = dataset.zattrs["normalization"].copy() - if self.predict_scale_source is not None: - for ch in self.source_channel: - # FIXME: hard-coded key - norm_meta[ch]["dataset_statistics"]["iqr"] /= self.predict_scale_source - predict_transform = ( - NormalizeSampled(self.source_channel, norm_meta) - if self.normalize_source - else None - ) + + predict_transform = self.normalizations + self.predict_dataset = SlidingWindowDataset( positions=positions, transform=predict_transform, diff --git a/viscy/data/typing.py b/viscy/data/typing.py new file mode 100644 index 00000000..c6b7c32f --- /dev/null +++ b/viscy/data/typing.py @@ -0,0 +1,22 @@ +from typing import Sequence, TypedDict, Union + +from torch import Tensor + + +class Sample(TypedDict, total=False): + """Image sample type for mini-batches.""" + + index: tuple[str, int, int] + # optional + source: Union[Tensor, Sequence[Tensor]] + target: Union[Tensor, Sequence[Tensor]] + labels: Union[Tensor, Sequence[Tensor]] + norm_meta: dict[str, dict] + + +class ChannelMap(TypedDict, total=False): + """Source and target channel names.""" + + source: Union[str, Sequence[str]] + # optional + target: Union[str, Sequence[str]] diff --git a/viscy/preprocessing/preprocessing.md b/viscy/preprocessing/preprocessing.md index 76d508c5..809b456f 100644 --- a/viscy/preprocessing/preprocessing.md +++ b/viscy/preprocessing/preprocessing.md @@ -87,11 +87,17 @@ The statistics are added as dictionaries into the .zattrs file. An example of pl } ``` -FOV level statistics added to every position: +FOV level statistics added to every position as well as the dataset_statistics to read dataset statistics: ```json "normalization": { "Deconvolved-Nuc": { + "dataset_statistics": { + "iqr": 149.7620086669922, + "mean": 262.2070617675781, + "median": 65.5246353149414, + "std": 890.0471801757812 + }, "fov_statistics": { "iqr": 450.4745788574219, "mean": 486.3854064941406, @@ -99,7 +105,13 @@ FOV level statistics added to every position: "std": 976.02392578125 } }, - "Phase3D": { + "Phase3D": { + "dataset_statistics": { + "iqr": 0.0011349652777425945, + "mean": -1.9603044165705796e-06, + "median": 3.388232289580628e-05, + "std": 0.005480962339788675 + }, "fov_statistics": { "iqr": 0.006403466919437051, "mean": 0.0010083537781611085, diff --git a/viscy/transforms.py b/viscy/transforms.py index cb3d2622..7ce192af 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -3,6 +3,7 @@ from typing import Sequence, Union from monai.transforms import ( + MapTransform, RandAdjustContrastd, RandAffined, RandGaussianNoised, @@ -10,6 +11,9 @@ RandScaleIntensityd, RandWeightedCropd, ) +from typing_extensions import Iterable, Literal + +from viscy.data.typing import Sample class RandWeightedCropd(RandWeightedCropd): @@ -118,3 +122,38 @@ def __init__( sigma_z=sigma_z, **kwargs, ) + + +class NormalizeSampled(MapTransform): + """ + Normalize the sample + :param Union[str, Iterable[str]] keys: keys to normalize + :param str fov: fov path with respect to Plate + :param str subtrahend: subtrahend for normalization, defaults to "mean" + :param str divisor: divisor for normalization, defaults to "std" + """ + + def __init__( + self, + keys: Union[str, Iterable[str]], + level: Literal["fov_statistics", "dataset_statistics"], + subtrahend="mean", + divisor="std", + ) -> None: + super().__init__(keys, allow_missing_keys=False) + self.subtrahend = subtrahend + self.divisor = divisor + self.level = level + + # TODO: need to implement the case where the preprocessing already exists + def __call__(self, sample: Sample) -> Sample: + for key in self.keys: + if key in self.keys: + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val + return sample + + def _normalize(): + NotImplementedError("_normalization() not implemented") From 9b3b032100b480f9340b8aa8b124e8116f232820 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:33:53 -0800 Subject: [PATCH 30/92] fix ctmc dataloading --- viscy/data/ctmc_v1.py | 6 +++--- viscy/data/hcs.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 0d65a36a..47844d68 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -29,7 +29,7 @@ def __init__( self.val_data_path = val_data_path self.train_transforms = train_transforms self.val_transforms = val_transforms - self.channel_map = ChannelMap(source=channel_name, target=channel_name) + self.channel_map = ChannelMap(source=[channel_name], target=[channel_name]) self.batch_size = batch_size self.num_workers = num_workers @@ -47,13 +47,13 @@ def _setup_fit(self) -> None: train_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.train_transform), + transform=Compose(self.train_transforms), ) self.val_dataset = SlidingWindowDataset( val_positions, channels=self.channel_map, z_window_size=1, - transform=Compose(self.val_transform), + transform=Compose(self.val_transforms), ) def train_dataloader(self) -> DataLoader: diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index bb0be09c..2c7397c0 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -120,7 +120,7 @@ def _get_windows(self) -> None: w += ts * zs self.window_keys.append(w) self.window_arrays.append(img_arr) - self.window_norm_meta.append(fov.zattrs["normalization"]) + self.window_norm_meta.append(fov.zattrs.get("normalization", 0)) self._max_window = w def _find_window(self, index: int) -> tuple[int, int]: From a3569364ac18897858471c78d4a4c6f3381c6d1c Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 27 Feb 2024 17:34:30 -0800 Subject: [PATCH 31/92] add example ctmc v1 loading script --- viscy/scripts/load_ctmc_v1.py | 68 +++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 viscy/scripts/load_ctmc_v1.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py new file mode 100644 index 00000000..e5c19094 --- /dev/null +++ b/viscy/scripts/load_ctmc_v1.py @@ -0,0 +1,68 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCropd, + NormalizeIntensityd, + RandAffined, + RandScaleIntensityd, +) +from tqdm import tqdm + +from viscy.data.ctmc_v1 import CTMCv1DataModule + +# %% +data_path = Path("") + +normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) + +data = CTMCv1DataModule( + train_data_path=data_path / "CTMCV1_test.zarr", + val_data_path=data_path / "CTMCV1_train.zarr", + train_transforms=[ + normalize_transform, + RandAffined( + keys=["DIC"], + rotate_range=[3.14, 0.0, 0.0], + shear_range=[0.0, 0.3, 0.3], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + ), + RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=4, + num_workers=0, + channel_name="DIC", +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From bac26bedeb1037bf0eec44fb7f1b65fd3da7b653 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 28 Feb 2024 15:52:41 -0800 Subject: [PATCH 32/92] changing the normalization and augmentations default from None to empty list. --- viscy/data/hcs.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 2c7397c0..af9a03a8 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -269,9 +269,9 @@ class HCSDataModule(LightningDataModule): :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) :param Optional[list[MapTransform]] normalizations: MONAI dictionary transforms - applied to selected channels, defaults to None (no normalization) + applied to selected channels, defaults to [] (no normalization) :param Optional[list[MapTransform]] augmentations: MONAI dictionary transforms - applied to the training set, defaults to None (no augmentation) + applied to the training set, defaults to [] (no augmentation) :param bool caching: whether to decompress all the images and cache the result, will store in ``/tmp/$SLURM_JOB_ID/`` if available, defaults to False @@ -291,8 +291,8 @@ def __init__( num_workers: int = 8, architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), - normalizations: Optional[list[MapTransform]] = None, - augmentations: Optional[list[MapTransform]] = None, + normalizations: Optional[list[MapTransform]] = [], + augmentations: Optional[list[MapTransform]] = [], caching: bool = False, ground_truth_masks: Optional[Path] = None, ): From 0b598c7e1b9fc0bbbc2aa08307964300f66b7f8a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:03:58 -0800 Subject: [PATCH 33/92] invert intensity transform --- viscy/transforms.py | 35 ++++++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/viscy/transforms.py b/viscy/transforms.py index 7ce192af..88e7f738 100644 --- a/viscy/transforms.py +++ b/viscy/transforms.py @@ -8,9 +8,12 @@ RandAffined, RandGaussianNoised, RandGaussianSmoothd, + RandomizableTransform, RandScaleIntensityd, RandWeightedCropd, ) +from monai.transforms.transform import Randomizable +from numpy.random.mtrand import RandomState as RandomState from typing_extensions import Iterable, Literal from viscy.data.typing import Sample @@ -148,12 +151,34 @@ def __init__( # TODO: need to implement the case where the preprocessing already exists def __call__(self, sample: Sample) -> Sample: for key in self.keys: - if key in self.keys: - level_meta = sample["norm_meta"][key][self.level] - subtrahend_val = level_meta[self.subtrahend] - divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero - sample[key] = (sample[key] - subtrahend_val) / divisor_val + level_meta = sample["norm_meta"][key][self.level] + subtrahend_val = level_meta[self.subtrahend] + divisor_val = level_meta[self.divisor] + 1e-8 # avoid div by zero + sample[key] = (sample[key] - subtrahend_val) / divisor_val return sample def _normalize(): NotImplementedError("_normalization() not implemented") + + +class RandInvertIntensityd(MapTransform, RandomizableTransform): + """ + Randomly invert the intensity of the image. + """ + + def __init__(self, keys: Union[str, Iterable[str]], prob: float = 0.1) -> None: + MapTransform.__init__(self, keys) + RandomizableTransform.__init__(self, prob) + + def __call__(self, sample: Sample) -> Sample: + self.randomize(None) + for key in self.keys: + if key in sample: + sample[key] = -sample[key] + return sample + + def set_random_state( + self, seed: int | None = None, state: RandomState | None = None + ) -> Randomizable: + super().set_random_state(seed, state) + return self From ddb30e9d05ebb7378a0ec1ee29acab6a33b32a14 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:17 -0800 Subject: [PATCH 34/92] concatenated data module --- viscy/data/combined.py | 72 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 67 insertions(+), 5 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 5da700dd..45072909 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -1,9 +1,19 @@ +from enum import Enum from typing import Literal, Sequence from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader +from torch import Tensor +from torch.utils.data import ConcatDataset, DataLoader -_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"] +from viscy.data.hcs import _collate_samples + + +class CombineMode(Enum): + MIN_SIZE = "min_size" + MAX_SIZE_CYCLE = "max_size_cycle" + MAX_SIZE = "max_size" + SEQUENTIAL = "sequential" class CombinedDataModule(LightningDataModule): @@ -20,10 +30,10 @@ class CombinedDataModule(LightningDataModule): def __init__( self, data_modules: Sequence[LightningDataModule], - train_mode: _MODES = "max_size_cycle", - val_mode: _MODES = "sequential", - test_mode: _MODES = "sequential", - predict_mode: _MODES = "sequential", + train_mode: CombineMode = CombineMode.MAX_SIZE_CYCLE, + val_mode: CombineMode = CombineMode.SEQUENTIAL, + test_mode: CombineMode = CombineMode.SEQUENTIAL, + predict_mode: CombineMode = CombineMode.SEQUENTIAL, ): super().__init__() self.data_modules = data_modules @@ -60,3 +70,55 @@ def predict_dataloader(self): [dm.predict_dataloader() for dm in self.data_modules], mode=self.predict_mode, ) + + +class ConcatDataModule(LightningDataModule): + def __init__(self, data_modules: Sequence[LightningDataModule]): + super().__init__() + self.data_modules = data_modules + self.num_workers = data_modules[0].num_workers + self.batch_size = data_modules[0].batch_size + for dm in data_modules: + if dm.num_workers != self.num_workers: + raise ValueError("Inconsistent number of workers") + if dm.batch_size != self.batch_size: + raise ValueError("Inconsistent batch size") + + def prepare_data(self): + for dm in self.data_modules: + dm.prepare_data() + + def setup(self, stage: Literal["fit", "validate", "test", "predict"]): + self.train_patches_per_stack = 0 + for dm in self.data_modules: + dm.setup(stage) + if patches := getattr(dm, "train_patches_per_stack", 0): + if self.train_patches_per_stack == 0: + self.train_patches_per_stack = patches + elif self.train_patches_per_stack != patches: + raise ValueError("Inconsistent patches per stack") + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self.train_dataset = ConcatDataset( + [dm.train_dataset for dm in self.data_modules] + ) + self.val_dataset = ConcatDataset([dm.val_dataset for dm in self.data_modules]) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size // self.train_patches_per_stack, + num_workers=self.num_workers, + shuffle=True, + persistent_workers=bool(self.num_workers), + collate_fn=_collate_samples, + ) + + def val_dataloader(self): + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + persistent_workers=bool(self.num_workers), + ) From 950475584f15534638c4c83d6e3fcf21314bb1e7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:37 -0800 Subject: [PATCH 35/92] subsample videos --- viscy/data/ctmc_v1.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/viscy/data/ctmc_v1.py b/viscy/data/ctmc_v1.py index 47844d68..d666fdcb 100644 --- a/viscy/data/ctmc_v1.py +++ b/viscy/data/ctmc_v1.py @@ -6,12 +6,33 @@ from torch.utils.data import DataLoader from viscy.data.hcs import ChannelMap, SlidingWindowDataset +from viscy.data.typing import Sample + + +class CTMCv1ValidationDataset(SlidingWindowDataset): + subsample_rate: int = 30 + + def __len__(self) -> int: + # sample every 30th frame in the videos + return super().__len__() // self.subsample_rate + + def __getitem__(self, index: int) -> Sample: + index = index * self.subsample_rate + return super().__getitem__(index) class CTMCv1DataModule(LightningDataModule): """ Autoregression data module for the CTMCv1 dataset. Training and validation datasets are stored in separate HCS OME-Zarr stores. + + :param str | Path train_data_path: Path to the training dataset + :param str | Path val_data_path: Path to the validation dataset + :param list[MapTransform] train_transforms: List of transforms for training + :param list[MapTransform] val_transforms: List of transforms for validation + :param int batch_size: Batch size, defaults to 16 + :param int num_workers: Number of workers, defaults to 8 + :param str channel_name: Name of the DIC channel, defaults to "DIC" """ def __init__( @@ -49,7 +70,7 @@ def _setup_fit(self) -> None: z_window_size=1, transform=Compose(self.train_transforms), ) - self.val_dataset = SlidingWindowDataset( + self.val_dataset = CTMCv1ValidationDataset( val_positions, channels=self.channel_map, z_window_size=1, From 808e39c02763f21c6cb07b79d2c60c2501220021 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:04:48 -0800 Subject: [PATCH 36/92] livecell dataset --- viscy/data/livecell.py | 98 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 viscy/data/livecell.py diff --git a/viscy/data/livecell.py b/viscy/data/livecell.py new file mode 100644 index 00000000..5d83f099 --- /dev/null +++ b/viscy/data/livecell.py @@ -0,0 +1,98 @@ +import json +from pathlib import Path + +import torch +from lightning.pytorch import LightningDataModule +from monai.transforms import Compose, Transform +from tifffile import imread +from torch.utils.data import DataLoader, Dataset + +from viscy.data.typing import Sample + + +class LiveCellDataset(Dataset): + """ + LiveCell dataset. + + :param list[Path] images: List of paths to single-page, single-channel TIFF files. + :param Transform | Compose transform: Transform to apply to the dataset + """ + + def __init__(self, images: list[Path], transform: Transform | Compose) -> None: + self.images = images + self.transform = transform + + def __len__(self) -> int: + return len(self.images) + + def __getitem__(self, idx: int) -> Sample: + image = imread(self.images[idx])[None, None] + image = torch.from_numpy(image).to(torch.float32) + image = self.transform(image) + return {"source": image, "target": image} + + +class LiveCellDataModule(LightningDataModule): + def __init__( + self, + train_val_images: Path, + train_annotations: Path, + val_annotations: Path, + train_transforms: list[Transform], + val_transforms: list[Transform], + batch_size: int = 16, + num_workers: int = 8, + ) -> None: + super().__init__() + self.train_val_images = Path(train_val_images) + if not self.train_val_images.is_dir(): + raise NotADirectoryError(str(train_val_images)) + self.train_annotations = Path(train_annotations) + if not self.train_annotations.is_file(): + raise FileNotFoundError(str(train_annotations)) + self.val_annotations = Path(val_annotations) + if not self.val_annotations.is_file(): + raise FileNotFoundError(str(val_annotations)) + self.train_transforms = Compose(train_transforms) + self.val_transforms = Compose(val_transforms) + self.batch_size = batch_size + self.num_workers = num_workers + + def setup(self, stage: str) -> None: + if stage != "fit": + raise NotImplementedError("Only fit stage is supported") + self._setup_fit() + + def _parse_image_names(self, annotations: Path) -> list[Path]: + with open(annotations) as f: + images = [f["file_name"] for f in json.load(f)["images"]] + return sorted(images) + + def _setup_fit(self) -> None: + train_images = self._parse_image_names(self.train_annotations) + val_images = self._parse_image_names(self.val_annotations) + self.train_dataset = LiveCellDataset( + [self.train_val_images / f for f in train_images], + transform=self.train_transforms, + ) + self.val_dataset = LiveCellDataset( + [self.train_val_images / f for f in val_images], + transform=self.val_transforms, + ) + + def train_dataloader(self) -> DataLoader: + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + shuffle=True, + ) + + def val_dataloader(self) -> DataLoader: + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=bool(self.num_workers), + ) From 43d641db2e448336be64a6ccd17ecd4a8c218b95 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:04 -0800 Subject: [PATCH 37/92] all sample fields are optional --- viscy/data/typing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/typing.py b/viscy/data/typing.py index c6b7c32f..aef7dea7 100644 --- a/viscy/data/typing.py +++ b/viscy/data/typing.py @@ -6,8 +6,8 @@ class Sample(TypedDict, total=False): """Image sample type for mini-batches.""" + # all optional index: tuple[str, int, int] - # optional source: Union[Tensor, Sequence[Tensor]] target: Union[Tensor, Sequence[Tensor]] labels: Union[Tensor, Sequence[Tensor]] From 42f81cfd2093e020e97db1238cc897e623fedcb1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:05:19 -0800 Subject: [PATCH 38/92] fix multi-dataloader validation --- viscy/light/engine.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 4d18e9c4..6c284954 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -194,12 +194,12 @@ def training_step(self, batch: Sample, batch_idx: int): ) return loss - def validation_step(self, batch: Sample, batch_idx: int): + def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = batch["source"] target = batch["target"] pred = self.forward(source) loss = self.loss_function(pred, target) - self.log("loss/validate", loss, sync_dist=True) + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target, pred)) @@ -425,7 +425,15 @@ def training_step(self, batch: Sequence[Sample], batch_idx: int): def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source, target, pred, mask, loss = self.forward_fit(batch) - self.validation_losses.append(loss.detach()) + if dataloader_idx + 1 > len(self.validation_losses): + self.validation_losses.append([]) + self.validation_losses[dataloader_idx].append(loss.detach()) + self.log( + f"loss/val/{dataloader_idx}", + loss, + sync_dist=True, + batch_size=source.shape[0], + ) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target * mask.unsqueeze(2), pred)) @@ -433,6 +441,6 @@ def validation_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0 def on_validation_epoch_end(self): super().on_validation_epoch_end() - self.log( - "loss/validate", torch.stack(self.validation_losses).mean(), sync_dist=True - ) + # average within each dataloader + loss_means = [torch.tensor(losses).mean() for losses in self.validation_losses] + self.log("loss/validate", torch.tensor(loss_means).mean(), sync_dist=True) From 4546fc77b8ee469b1a93f9689404b3fc47cc622d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 29 Feb 2024 11:08:26 -0800 Subject: [PATCH 39/92] lint --- viscy/data/combined.py | 1 - 1 file changed, 1 deletion(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index 45072909..d70b9333 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -3,7 +3,6 @@ from lightning.pytorch import LightningDataModule from lightning.pytorch.utilities.combined_loader import CombinedLoader -from torch import Tensor from torch.utils.data import ConcatDataset, DataLoader from viscy.data.hcs import _collate_samples From 306f3efadce651298647a1a8e60dcdd95eccb6d0 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 29 Feb 2024 13:13:25 -0800 Subject: [PATCH 40/92] fixing preprocessing for varying array shapes (i.e aics dataset) --- viscy/utils/meta_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/viscy/utils/meta_utils.py b/viscy/utils/meta_utils.py index d644dadf..961b6696 100644 --- a/viscy/utils/meta_utils.py +++ b/viscy/utils/meta_utils.py @@ -104,8 +104,9 @@ def generate_normalization_metadata( positions, fov_sample_values = mp_utils.mp_sample_im_pixels( this_channels_args, num_workers ) - dataset_sample_values = np.stack(fov_sample_values, 0) - + dataset_sample_values = np.concatenate( + [arr.flatten() for arr in fov_sample_values] + ) fov_level_statistics = mp_utils.mp_get_val_stats(fov_sample_values, num_workers) dataset_level_statistics = mp_utils.get_val_stats(dataset_sample_values) From 1a0e3ced8711bcdae7c6698c899aff45f7bdc777 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 20:51:50 -0800 Subject: [PATCH 41/92] update loading scripts --- viscy/scripts/load_ctmc_v1.py | 38 ++++++++++----- viscy/scripts/load_livecell.py | 85 ++++++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 11 deletions(-) create mode 100644 viscy/scripts/load_livecell.py diff --git a/viscy/scripts/load_ctmc_v1.py b/viscy/scripts/load_ctmc_v1.py index e5c19094..41cef698 100644 --- a/viscy/scripts/load_ctmc_v1.py +++ b/viscy/scripts/load_ctmc_v1.py @@ -5,7 +5,11 @@ from monai.transforms import ( CenterSpatialCropd, NormalizeIntensityd, + RandAdjustContrastd, RandAffined, + RandFlipd, + RandGaussianNoised, + RandGaussianSmoothd, RandScaleIntensityd, ) from tqdm import tqdm @@ -13,10 +17,11 @@ from viscy.data.ctmc_v1 import CTMCv1DataModule # %% -data_path = Path("") +channel = "DIC" +data_path = Path("/hpc/reference/imaging/ctmc") -normalize_transform = NormalizeIntensityd(keys=["DIC"], channel_wise=True) -crop_transform = CenterSpatialCropd(keys=["DIC"], roi_size=[1, 256, 256]) +normalize_transform = NormalizeIntensityd(keys=[channel], channel_wise=True) +crop_transform = CenterSpatialCropd(keys=[channel], roi_size=[1, 224, 224]) data = CTMCv1DataModule( train_data_path=data_path / "CTMCV1_test.zarr", @@ -24,19 +29,29 @@ train_transforms=[ normalize_transform, RandAffined( - keys=["DIC"], + keys=[channel], rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.3, 0.3], - scale_range=[0.0, 0.3, 0.3], + scale_range=[0.0, [-0.6, 0.1], [-0.6, 0.1]], prob=0.8, + padding_mode="zeros", + ), + RandFlipd(keys=[channel], prob=0.5, spatial_axis=(1,2)), + RandAdjustContrastd(keys=[channel], prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=[channel], factors=0.3, prob=0.5), + RandGaussianNoised(keys=[channel], prob=0.5, mean=0.0, std=0.2), + RandGaussianSmoothd( + keys=[channel], + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, ), - RandScaleIntensityd(keys=["DIC"], factors=0.3, prob=0.5), crop_transform, ], val_transforms=[normalize_transform, crop_transform], - batch_size=4, + batch_size=32, num_workers=0, - channel_name="DIC", + channel_name=channel, ) # %% @@ -47,7 +62,8 @@ # %% for batch in tqdm(dmt): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") @@ -57,7 +73,7 @@ # %% for batch in tqdm(dmv): img = batch["source"] - f, ax = plt.subplots(4, 4, figsize=(12, 12)) + f, ax = plt.subplots(5, 5, figsize=(15, 15)) for sample, a in zip(img, ax.flatten()): a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) a.axis("off") diff --git a/viscy/scripts/load_livecell.py b/viscy/scripts/load_livecell.py new file mode 100644 index 00000000..cfaf2dfe --- /dev/null +++ b/viscy/scripts/load_livecell.py @@ -0,0 +1,85 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +from monai.transforms import ( + CenterSpatialCrop, + NormalizeIntensity, + RandAdjustContrast, + RandAffine, + RandFlip, + RandGaussianNoise, + RandGaussianSmooth, + RandScaleIntensity, + RandSpatialCrop, +) +from tqdm import tqdm + +from viscy.data.livecell import LiveCellDataModule + +# %% +data_path = Path("/hpc/reference/imaging/livecell") + +normalize_transform = NormalizeIntensity(channel_wise=True) +crop_transform = CenterSpatialCrop(roi_size=[1, 224, 224]) + +data = LiveCellDataModule( + train_val_images=data_path / "images" / "livecell_train_val_images", + train_annotations=data_path + / "annotations" + / "livecell_coco_train_images_only.json", + val_annotations=data_path / "annotations" / "livecell_coco_val_images_only.json", + train_transforms=[ + normalize_transform, + RandSpatialCrop(roi_size=[1, 384, 384]), + RandAffine( + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, [-0.2, 0.8], [-0.2, 0.8]], + prob=0.8, + padding_mode="zeros", + ), + RandFlip(prob=0.5, spatial_axis=(1, 2)), + RandAdjustContrast(prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensity(factors=0.3, prob=0.5), + RandGaussianNoise(prob=0.5, mean=0.0, std=0.3), + RandGaussianSmooth( + sigma_x=(0.05, 0.3), + sigma_y=(0.05, 0.3), + sigma_z=(0.05, 0.0), + prob=0.5, + ), + crop_transform, + ], + val_transforms=[normalize_transform, crop_transform], + batch_size=16, + num_workers=0, +) + +# %% +data.setup("fit") +dmt = data.train_dataloader() +dmv = data.val_dataloader() + +# %% +for batch in tqdm(dmt): + img = batch["target"] + img[:, :, :, 32:64, 32:64] = 0 + f, ax = plt.subplots(4, 4, figsize=(15, 15)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + +# %% +for batch in tqdm(dmv): + img = batch["source"] + f, ax = plt.subplots(4, 4, figsize=(12, 12)) + for sample, a in zip(img, ax.flatten()): + a.imshow(sample[0, 0].cpu().numpy(), cmap="gray", vmin=-5, vmax=5) + a.axis("off") + f.tight_layout() + break + + +# %% From d3ec94d2c0142bf073b2019ab9ac6eba4312eddd Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 1 Mar 2024 21:26:12 -0800 Subject: [PATCH 42/92] fix CombineMode --- viscy/data/combined.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/viscy/data/combined.py b/viscy/data/combined.py index d70b9333..13e64e21 100644 --- a/viscy/data/combined.py +++ b/viscy/data/combined.py @@ -36,10 +36,10 @@ def __init__( ): super().__init__() self.data_modules = data_modules - self.train_mode = train_mode - self.val_mode = val_mode - self.test_mode = test_mode - self.predict_mode = predict_mode + self.train_mode = CombineMode(train_mode).value + self.val_mode = CombineMode(val_mode).value + self.test_mode = CombineMode(test_mode).value + self.predict_mode = CombineMode(predict_mode).value def prepare_data(self): for dm in self.data_modules: From dd3471229f54a94b0162d93da667580217b773b5 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 4 Mar 2024 11:13:36 -0800 Subject: [PATCH 43/92] added model and annotation code draft --- .../Infection_annotator.py | 55 +++++++ .../Infection_classification_model.py | 142 ++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 examples/infection_phenotyping/Infection_annotator.py create mode 100644 examples/infection_phenotyping/Infection_classification_model.py diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotator.py new file mode 100644 index 00000000..e933a773 --- /dev/null +++ b/examples/infection_phenotyping/Infection_annotator.py @@ -0,0 +1,55 @@ + + +#%% use napari to annotate infected cells in segmented data + +import napari +from iohub.ngff import open_ome_zarr +import numpy as np + +file_in_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2.zarr' +zarr_input = open_ome_zarr( + file_in_path, + layout="hcs", + mode="r+", +) +chan_names = zarr_input.channel_names +# zarr_input.append_channel('Inf_mask',resize_arrays=True) + +file_out_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2_infMarked_rev2.zarr' +zarr_output = open_ome_zarr( + file_out_path, + layout="hcs", + mode="w-", + channel_names=['Sensor','Nucl_mask','Inf_mask'], +) + +v = napari.Viewer() + + +#%% Load label image to napari +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + if well_name == 'A' and well_no == '2': + + for pos_name, pos_data in well_data.positions(): + # if int(pos_name) > 1: + v.layers.clear() + data = pos_data.data + + FITC = data[0,0,...] + v.add_image(FITC, name='FITC', colormap='green', blending='additive') + Inf_mask = data[0,1,...].astype(int) + v.add_labels(Inf_mask) + input("Press Enter") + + label_layer = v.layers['Inf_mask'] + label_array = label_layer.data + label_array = np.expand_dims(label_array, axis=(0, 1)) + # zarr_input.create_image('Inf_mask',label_array) + out_data = np.concatenate((data, label_array), axis=1) + position = zarr_output.create_position(well_name, well_no, pos_name) + position["0"] = out_data + + +# %% diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py new file mode 100644 index 00000000..7f8667fa --- /dev/null +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -0,0 +1,142 @@ + +# %% +import torch +import sys +from viscy.data.hcs import HCSDataModule +# from lightning.pytorch import CustomDataset +from lightning.pytorch.callbacks import Callback +from lightning.pytorch import LightningDataModule +# import cv2 +import numpy as np +import torch.nn as nn +import torchvision.models as models +import lightning.pytorch as pl +import torch.nn.functional as F +from viscy.light.engine import VSUNet + +# %% Create a dataloader and visualize the batches. +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" + +# Create an instance of HCSDataModule +data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[256, 256], split_ratio=0.8, z_window_size=1, architecture = '2D') + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage = "fit") + +# Create a dataloader +dataloader = data_module.train_dataloader() + +# Visualize the dataset and the batch using napari +import napari +from pytorch_lightning.loggers import TensorBoardLogger +# import os + +# Set the display +# os.environ['DISPLAY'] = ':1' + +# Create a napari viewer +viewer = napari.Viewer() + +# Add the dataset to the viewer +for batch in dataloader: + if isinstance(batch, dict): + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# Start the napari event loop +napari.run() + +# %% use 2D Unet from viscy with a softmax layer at end for 4 label classification +# use for image translation from instance segmentation to annotated image + +# use diceloss function from here: https://gist.github.com/weiliu620/52d140b22685cf9552da4899e2160183 +def dice_loss(pred, target): + """This definition generalize to real valued pred and target vector. This should be differentiable. + pred: tensor with first dimension as batch + target: tensor with first dimension as batch + """ + + smooth = 1. + + # have to use contiguous since they may from a torch.view op + iflat = pred.contiguous().view(-1) + tflat = target.contiguous().view(-1) + intersection = (iflat * tflat).sum() + + A_sum = torch.sum(tflat * iflat) + B_sum = torch.sum(tflat * tflat) + + return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) + +unet_model = VSUNet( + architecture='2D', + loss_function=dice_loss, + lr=1e-3, + example_input_xy_shape=(64,64) + ) + +# Define the optimizer +optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) + +# Iterate over the batches +for batch in dataloader: + # Extract the input and target from the batch + input_data, target = batch['source'], batch['target'] + + # Forward pass through the model + output = unet_model(input_data) + + # Apply softmax activation to the output + output = F.softmax(output, dim=1) + + # Calculate the loss + loss = dice_loss(output, target) + + # Perform backpropagation and update the model's parameters + loss.backward() + optimizer.step() + optimizer.zero_grad() + +#%% use the batch for training the unet model using the lightning module + +# Train the model +# Create a TensorBoard logger +logger = TensorBoardLogger("logs", name="infection_classification_model") + +# Pass the logger to the Trainer +trainer = pl.Trainer(gpus=1, logger=logger) + +# Fit the model +trainer.fit(unet_model, data_module) + +# %% test the model on the test set +# Load the test dataset +test_dataloader = data_module.test_dataloader() + +# Set the model to evaluation mode +unet_model.eval() + +# Create a list to store the predictions +predictions = [] + +# Iterate over the test batches +for batch in test_dataloader: + # Extract the input from the batch + input_data = batch['source'] + + # Forward pass through the model + output = unet_model(input_data) + + # Append the predictions to the list + predictions.append(output.detach().cpu().numpy()) + +# Convert the predictions to a numpy array +predictions = np.stack(predictions) + +# Save the predictions as added channel in zarr format +zarr.save('predictions.zarr', predictions) From 5fc9da2756285dd7bcc35f576910e873e1e9bdbf Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 4 Mar 2024 19:35:06 -0800 Subject: [PATCH 44/92] chnaged to simple unet model --- .../Infection_classification_model.py | 104 ++++++++++++++---- 1 file changed, 84 insertions(+), 20 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 7f8667fa..f038cc50 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -1,25 +1,24 @@ # %% import torch -import sys from viscy.data.hcs import HCSDataModule -# from lightning.pytorch import CustomDataset -from lightning.pytorch.callbacks import Callback -from lightning.pytorch import LightningDataModule -# import cv2 + import numpy as np import torch.nn as nn -import torchvision.models as models import lightning.pytorch as pl import torch.nn.functional as F -from viscy.light.engine import VSUNet + +import napari +from pytorch_lightning.loggers import TensorBoardLogger +from monai.transforms import Zoom +from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" # Create an instance of HCSDataModule -data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[256, 256], split_ratio=0.8, z_window_size=1, architecture = '2D') +data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[128,128], split_ratio=0.8, z_window_size=1, architecture = '2D') # Prepare the data data_module.prepare_data() @@ -31,10 +30,6 @@ dataloader = data_module.train_dataloader() # Visualize the dataset and the batch using napari -import napari -from pytorch_lightning.loggers import TensorBoardLogger -# import os - # Set the display # os.environ['DISPLAY'] = ':1' @@ -73,23 +68,86 @@ def dice_loss(pred, target): return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) -unet_model = VSUNet( - architecture='2D', - loss_function=dice_loss, - lr=1e-3, - example_input_xy_shape=(64,64) - ) +# load 2D UNet from viscy +# unet_model = VSUNet( +# architecture='2D', +# model_config={"in_channels": 2, "out_channels": 1}, +# loss_function=dice_loss, +# lr=1e-3, +# example_input_xy_shape=(128,128), +# ) + +# Define the data augmentations +# Define the augmentations +# transforms = Compose([ +# RandRotate(range_x=15, prob=0.5), +# Resize(spatial_size=[64, 64],mode='linear'), +# Zoom([0.5,2], mode='bilinear'), +# Flip(spatial_axis=[0,1]), +# RandFlip(spatial_axis=[0,1], prob=0.5), +# RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), +# RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), +# RandGaussianNoise(prob=0.5), +# ]) + +transforms = Compose([ + Flip(spatial_axis=[0,1]), +]) + +# create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image +class UNet(nn.Module): + def __init__(self, in_channels, out_channels): + super(UNet, self).__init__() + + self.encoder = nn.Sequential( + nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=2, stride=2), + nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.ReLU(inplace=True) + ) + + # Define the decoder part of the U-Net architecture + self.decoder = nn.Sequential( + nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(64, out_channels, kernel_size=1) + ) + + def forward(self, x): + # Apply the encoder to the input + x = self.encoder(x) + + # Apply the decoder to the output of the encoder + x = self.decoder(x) + + return x + +unet_model = UNet(in_channels=2, out_channels=1) # Define the optimizer optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) -# Iterate over the batches +#%% Iterate over the batches for batch in dataloader: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] + # Apply the augmentations to your data + augmented_input = transforms(input_data) + # Forward pass through the model - output = unet_model(input_data) + output = unet_model(augmented_input,target) # Apply softmax activation to the output output = F.softmax(output, dim=1) @@ -102,6 +160,11 @@ def dice_loss(pred, target): optimizer.step() optimizer.zero_grad() +# Visualize sample of the augmented data using napari +for i in range(augmented_data.shape[0]): + viewer.add_image(augmented_data[i].cpu().numpy().astype(np.float32)) + + #%% use the batch for training the unet model using the lightning module # Train the model @@ -139,4 +202,5 @@ def dice_loss(pred, target): predictions = np.stack(predictions) # Save the predictions as added channel in zarr format +# use iohub or viscy to save the predictions!!! zarr.save('predictions.zarr', predictions) From e6274887bdda8f4a6fd718f510fd783d05a0a1f3 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 10:21:25 -0800 Subject: [PATCH 45/92] start with lesser augmentations --- .../Infection_classification_model.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f038cc50..91a7078e 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -10,7 +10,6 @@ import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import Zoom from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised # %% Create a dataloader and visualize the batches. @@ -91,7 +90,11 @@ def dice_loss(pred, target): # ]) transforms = Compose([ + RandRotate(range_x=15, prob=0.5), Flip(spatial_axis=[0,1]), + RandFlip(spatial_axis=[0,1], prob=0.5), + RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), + RandGaussianNoise(prob=0.5), ]) # create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image @@ -121,7 +124,12 @@ def __init__(self, in_channels, out_channels): nn.ReLU(inplace=True), nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(64, out_channels, kernel_size=1) + nn.Conv2d(128, out_channels, kernel_size=1), + nn.Softmax(dim=1), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Softmax(dim=1), + nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Softmax(dim=1) ) def forward(self, x): @@ -145,9 +153,10 @@ def forward(self, x): # Apply the augmentations to your data augmented_input = transforms(input_data) + viewer.add_image(augmented_input.cpu().numpy().astype(np.float32)) # Forward pass through the model - output = unet_model(augmented_input,target) + output = unet_model(augmented_input) # Apply softmax activation to the output output = F.softmax(output, dim=1) @@ -161,8 +170,8 @@ def forward(self, x): optimizer.zero_grad() # Visualize sample of the augmented data using napari -for i in range(augmented_data.shape[0]): - viewer.add_image(augmented_data[i].cpu().numpy().astype(np.float32)) +# for i in range(augmented_input.shape[0]): +# viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) #%% use the batch for training the unet model using the lightning module From 310ba7091444683a4c1b3e52995746f62568848c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 10:22:59 -0800 Subject: [PATCH 46/92] added readme file --- examples/infection_phenotyping/readme.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 examples/infection_phenotyping/readme.md diff --git a/examples/infection_phenotyping/readme.md b/examples/infection_phenotyping/readme.md new file mode 100644 index 00000000..74dbc500 --- /dev/null +++ b/examples/infection_phenotyping/readme.md @@ -0,0 +1,7 @@ +# Infection Classification Model + +This repository contains the code for the infection classification model (`infection_classification_model.py`) used in the infection phenotyping project. + +## Overview + +The `infection_classification_model.py` file implements a machine learning model for classifying infections based on various features. The model is trained on a labeled dataset, either fluorescence or label-free images, and can be used to predict the infection type for new samples. \ No newline at end of file From 34b81b952f34c8a00e190e354aa544b51fee5f82 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 6 Mar 2024 15:56:07 -0800 Subject: [PATCH 47/92] added tensorboard logging --- .../Infection_classification_model.py | 174 ++++++++++-------- 1 file changed, 96 insertions(+), 78 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 91a7078e..aab76139 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -10,14 +10,26 @@ import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import Compose, RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from pytorch_lightning.callbacks import ModelCheckpoint # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" # Create an instance of HCSDataModule -data_module = HCSDataModule(dataset_path, source_channel=['Sensor','Nucl_mask'], target_channel=['Inf_mask'],yx_patch_size=[128,128], split_ratio=0.8, z_window_size=1, architecture = '2D') +data_module = HCSDataModule( + dataset_path, + source_channel=['Sensor','Nucl_mask'], + target_channel=['Inf_mask'], + yx_patch_size=[128,128], + split_ratio=0.8, + z_window_size=1, + architecture = '2D', + num_workers=1, + batch_size=12, + augmentations=[], +) # Prepare the data data_module.prepare_data() @@ -32,41 +44,41 @@ # Set the display # os.environ['DISPLAY'] = ':1' -# Create a napari viewer -viewer = napari.Viewer() +# # Create a napari viewer +# viewer = napari.Viewer() -# Add the dataset to the viewer -for batch in dataloader: - if isinstance(batch, dict): - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - viewer.add_image(v.cpu().numpy().astype(np.float32)) +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) -# Start the napari event loop -napari.run() +# # Start the napari event loop +# napari.run() # %% use 2D Unet from viscy with a softmax layer at end for 4 label classification # use for image translation from instance segmentation to annotated image -# use diceloss function from here: https://gist.github.com/weiliu620/52d140b22685cf9552da4899e2160183 -def dice_loss(pred, target): - """This definition generalize to real valued pred and target vector. This should be differentiable. - pred: tensor with first dimension as batch - target: tensor with first dimension as batch - """ - - smooth = 1. - - # have to use contiguous since they may from a torch.view op - iflat = pred.contiguous().view(-1) - tflat = target.contiguous().view(-1) - intersection = (iflat * tflat).sum() - - A_sum = torch.sum(tflat * iflat) - B_sum = torch.sum(tflat * tflat) +# use diceloss function from here: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch +class DiceLoss(nn.Module): + def __init__(self, weight=None, size_average=True): + super(DiceLoss, self).__init__() + + def forward(self, inputs, targets, smooth=1): + + #comment out if your model contains a sigmoid or equivalent activation layer + inputs = F.sigmoid(inputs) + + #flatten label and prediction tensors + inputs = inputs.view(-1) + targets = targets.view(-1) + + intersection = (inputs * targets).sum() + dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) + + return 1 - dice - return 1 - ((2. * intersection + smooth) / (A_sum + B_sum + smooth) ) - # load 2D UNet from viscy # unet_model = VSUNet( # architecture='2D', @@ -76,67 +88,43 @@ def dice_loss(pred, target): # example_input_xy_shape=(128,128), # ) -# Define the data augmentations -# Define the augmentations -# transforms = Compose([ -# RandRotate(range_x=15, prob=0.5), -# Resize(spatial_size=[64, 64],mode='linear'), -# Zoom([0.5,2], mode='bilinear'), -# Flip(spatial_axis=[0,1]), -# RandFlip(spatial_axis=[0,1], prob=0.5), -# RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5), -# RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), -# RandGaussianNoise(prob=0.5), -# ]) - -transforms = Compose([ - RandRotate(range_x=15, prob=0.5), - Flip(spatial_axis=[0,1]), - RandFlip(spatial_axis=[0,1], prob=0.5), - RandRotate90(spatial_axes=(0,1), prob=0.2, max_k=3), - RandGaussianNoise(prob=0.5), -]) - # create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image class UNet(nn.Module): def __init__(self, in_channels, out_channels): super(UNet, self).__init__() self.encoder = nn.Sequential( - nn.Conv2d(in_channels, 64, kernel_size=3, padding=1), + nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(64, 128, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(64, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(128, 256, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(128, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.MaxPool2d(kernel_size=2, stride=2), - nn.Conv2d(256, 512, kernel_size=3, padding=1), + nn.MaxPool3d(kernel_size=1, stride=1), + nn.Conv3d(256, 512, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) - # Define the decoder part of the U-Net architecture self.decoder = nn.Sequential( - nn.Conv2d(512, 256, kernel_size=3, padding=1), + nn.Conv3d(512, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.Conv3d(256, 128, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.Conv3d(128, 64, kernel_size=3, padding=1), nn.ReLU(inplace=True), - nn.Conv2d(128, out_channels, kernel_size=1), + nn.Conv3d(64, out_channels, kernel_size=1), nn.Softmax(dim=1), - nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Conv3d(out_channels, out_channels, kernel_size=1), nn.Softmax(dim=1), - nn.Conv2d(out_channels, out_channels, kernel_size=1), + nn.Conv3d(out_channels, out_channels, kernel_size=1), nn.Softmax(dim=1) ) def forward(self, x): - # Apply the encoder to the input x = self.encoder(x) - # Apply the decoder to the output of the encoder x = self.decoder(x) return x @@ -150,19 +138,13 @@ def forward(self, x): for batch in dataloader: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] - - # Apply the augmentations to your data - augmented_input = transforms(input_data) - viewer.add_image(augmented_input.cpu().numpy().astype(np.float32)) + # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) # Forward pass through the model - output = unet_model(augmented_input) - - # Apply softmax activation to the output - output = F.softmax(output, dim=1) + output = unet_model(input_data) # Calculate the loss - loss = dice_loss(output, target) + loss = DiceLoss()(output, target) # Perform backpropagation and update the model's parameters loss.backward() @@ -178,10 +160,46 @@ def forward(self, x): # Train the model # Create a TensorBoard logger -logger = TensorBoardLogger("logs", name="infection_classification_model") +class LightningUNet(pl.LightningModule): + def __init__(self, in_channels, out_channels): + super(LightningUNet, self).__init__() + self.unet_model = UNet(in_channels, out_channels) + + def forward(self, x): + return self.unet_model(x) + + def training_step(self, batch, batch_idx): + input_data, target = batch['source'], batch['target'] + output = self(input_data) + loss = DiceLoss()(output, target) + self.log('train_loss', loss) + return loss + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer + +# Create an instance of the LightningUNet class +unet_model = LightningUNet(in_channels=2, out_channels=1) + +# Define the logger +logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") # Pass the logger to the Trainer -trainer = pl.Trainer(gpus=1, logger=logger) +trainer = pl.Trainer(logger=logger, max_epochs=10, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath='/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints', + filename='checkpoint_{epoch:02d}', + save_top_k=-1, + verbose=True, + monitor='val_loss', + mode='min' +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) # Fit the model trainer.fit(unet_model, data_module) From a4e2f0d683553c7971d536c11ef1ee5486c6ad71 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 7 Mar 2024 14:05:57 -0800 Subject: [PATCH 48/92] added validation step --- .../Infection_classification_model.py | 35 ++++++++++++++++--- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index aab76139..cab56b5c 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -38,7 +38,9 @@ data_module.setup(stage = "fit") # Create a dataloader -dataloader = data_module.train_dataloader() +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() # Visualize the dataset and the batch using napari # Set the display @@ -135,7 +137,7 @@ def forward(self, x): optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) #%% Iterate over the batches -for batch in dataloader: +for batch in train_dm: # Extract the input and target from the batch input_data, target = batch['source'], batch['target'] # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) @@ -151,6 +153,17 @@ def forward(self, x): optimizer.step() optimizer.zero_grad() +for batch in val_dm: + # Extract the input and target from the batch + input_data, target = batch['source'], batch['target'] + + # Forward pass through the model + output = unet_model(input_data) + + # Calculate the loss + loss = DiceLoss()(output, target) + + # Visualize sample of the augmented data using napari # for i in range(augmented_input.shape[0]): # viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) @@ -175,6 +188,12 @@ def training_step(self, batch, batch_idx): self.log('train_loss', loss) return loss + def validation_step(self, batch, batch_idx): + input_data, target = batch['source'], batch['target'] + output = self(input_data) + loss = DiceLoss()(output, target) + self.log('val_loss', loss) + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer @@ -186,7 +205,7 @@ def configure_optimizers(self): logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") # Pass the logger to the Trainer -trainer = pl.Trainer(logger=logger, max_epochs=10, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) +trainer = pl.Trainer(logger=logger, max_epochs=30, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( @@ -205,8 +224,14 @@ def configure_optimizers(self): trainer.fit(unet_model, data_module) # %% test the model on the test set -# Load the test dataset -test_dataloader = data_module.test_dataloader() +test_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr' + +test_dm = HCSDataModule( + test_datapath, + source_channel=['Sensor','Nuclei_mask'], +) +# Load the predict dataset +test_dataloader = test_dm.test_dataloader() # Set the model to evaluation mode unet_model.eval() From 0ebb5df40e6597d9af2b797e6ba89cb25bb8491a Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 11 Mar 2024 08:14:18 -0700 Subject: [PATCH 49/92] chnaged to viscy 2d unet --- .../Infection_classification_model.py | 76 ++----------------- 1 file changed, 7 insertions(+), 69 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index cab56b5c..f97e654c 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -12,6 +12,8 @@ from pytorch_lightning.loggers import TensorBoardLogger from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised from pytorch_lightning.callbacks import ModelCheckpoint +from monai.losses import DiceLoss +from viscy.light.engine import VSUNet # %% Create a dataloader and visualize the batches. # Set the path to the dataset @@ -61,77 +63,13 @@ # %% use 2D Unet from viscy with a softmax layer at end for 4 label classification # use for image translation from instance segmentation to annotated image - -# use diceloss function from here: https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch -class DiceLoss(nn.Module): - def __init__(self, weight=None, size_average=True): - super(DiceLoss, self).__init__() - - def forward(self, inputs, targets, smooth=1): - - #comment out if your model contains a sigmoid or equivalent activation layer - inputs = F.sigmoid(inputs) - - #flatten label and prediction tensors - inputs = inputs.view(-1) - targets = targets.view(-1) - - intersection = (inputs * targets).sum() - dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) - - return 1 - dice # load 2D UNet from viscy -# unet_model = VSUNet( -# architecture='2D', -# model_config={"in_channels": 2, "out_channels": 1}, -# loss_function=dice_loss, -# lr=1e-3, -# example_input_xy_shape=(128,128), -# ) - -# create a small unet for image translation which accepts two input images (a label image and a microscopy image) and outputs one label image -class UNet(nn.Module): - def __init__(self, in_channels, out_channels): - super(UNet, self).__init__() - - self.encoder = nn.Sequential( - nn.Conv3d(in_channels, 64, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(64, 128, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(128, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.MaxPool3d(kernel_size=1, stride=1), - nn.Conv3d(256, 512, kernel_size=3, padding=1), - nn.ReLU(inplace=True) - ) - - self.decoder = nn.Sequential( - nn.Conv3d(512, 256, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(256, 128, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(128, 64, kernel_size=3, padding=1), - nn.ReLU(inplace=True), - nn.Conv3d(64, out_channels, kernel_size=1), - nn.Softmax(dim=1), - nn.Conv3d(out_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1), - nn.Conv3d(out_channels, out_channels, kernel_size=1), - nn.Softmax(dim=1) - ) - - def forward(self, x): - x = self.encoder(x) - - x = self.decoder(x) - - return x - -unet_model = UNet(in_channels=2, out_channels=1) +unet_model = VSUNet( + architecture='2D', + model_config={"in_channels": 2, "out_channels": 4, "task": "reg"}, + lr=1e-3, +) # Define the optimizer optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) From a0e426a8edb9b09f3dc839dd029c40849f1c462f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 12 Mar 2024 12:22:36 -0700 Subject: [PATCH 50/92] used crossentropyloss with one-hot encoding --- .../Infection_classification_model.py | 211 ++++++++++-------- 1 file changed, 124 insertions(+), 87 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f97e654c..b2230d9b 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -1,4 +1,3 @@ - # %% import torch from viscy.data.hcs import HCSDataModule @@ -7,13 +6,31 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F +import torchview +from typing import Literal, Union -import napari +# import napari from pytorch_lightning.loggers import TensorBoardLogger -from monai.transforms import RandRotate, Resize, Zoom, Flip, RandFlip, RandZoom, RandRotate90, RandRotate, RandAffine, Rand2DElastic, Rand3DElastic, RandGaussianNoise, RandGaussianNoised +from monai.transforms import ( + RandRotate, + Resize, + Zoom, + Flip, + RandFlip, + RandZoom, + RandRotate90, + RandRotate, + RandAffine, + Rand2DElastic, + Rand3DElastic, + RandGaussianNoise, + RandGaussianNoised, +) from pytorch_lightning.callbacks import ModelCheckpoint from monai.losses import DiceLoss from viscy.light.engine import VSUNet +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample # %% Create a dataloader and visualize the batches. # Set the path to the dataset @@ -21,13 +38,13 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( - dataset_path, - source_channel=['Sensor','Nucl_mask'], - target_channel=['Inf_mask'], - yx_patch_size=[128,128], - split_ratio=0.8, - z_window_size=1, - architecture = '2D', + dataset_path, + source_channel=["Sensor"], + target_channel=["Inf_mask"], + yx_patch_size=[128, 128], + split_ratio=0.8, + z_window_size=1, + architecture="2D", num_workers=1, batch_size=12, augmentations=[], @@ -37,7 +54,7 @@ data_module.prepare_data() # Setup the data -data_module.setup(stage = "fit") +data_module.setup(stage="fit") # Create a dataloader train_dm = data_module.train_dataloader() @@ -61,112 +78,132 @@ # # Start the napari event loop # napari.run() -# %% use 2D Unet from viscy with a softmax layer at end for 4 label classification -# use for image translation from instance segmentation to annotated image - -# load 2D UNet from viscy -unet_model = VSUNet( - architecture='2D', - model_config={"in_channels": 2, "out_channels": 4, "task": "reg"}, - lr=1e-3, -) - -# Define the optimizer -optimizer = torch.optim.Adam(unet_model.parameters(), lr=1e-3) - -#%% Iterate over the batches -for batch in train_dm: - # Extract the input and target from the batch - input_data, target = batch['source'], batch['target'] - # viewer.add_image(input_data.cpu().numpy().astype(np.float32)) - - # Forward pass through the model - output = unet_model(input_data) - - # Calculate the loss - loss = DiceLoss()(output, target) - - # Perform backpropagation and update the model's parameters - loss.backward() - optimizer.step() - optimizer.zero_grad() - -for batch in val_dm: - # Extract the input and target from the batch - input_data, target = batch['source'], batch['target'] - - # Forward pass through the model - output = unet_model(input_data) - - # Calculate the loss - loss = DiceLoss()(output, target) - - -# Visualize sample of the augmented data using napari -# for i in range(augmented_input.shape[0]): -# viewer.add_image(augmented_input[i].cpu().numpy().astype(np.float32)) +# %% use 2D Unet and Lightning module -#%% use the batch for training the unet model using the lightning module - # Train the model # Create a TensorBoard logger class LightningUNet(pl.LightningModule): - def __init__(self, in_channels, out_channels): + def __init__( + self, + in_channels, + out_channels, + lr: float = 1e-3, + loss_function: nn.CrossEntropyLoss = None, + schedule: Literal["WarmupCosine", "Constant"] = "Constant", + log_batches_per_epoch: int = 2, + log_samples_per_batch: int = 1, + ): super(LightningUNet, self).__init__() - self.unet_model = UNet(in_channels, out_channels) + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + self.lr = lr + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + self.training_step_outputs = [] + self.validation_step_outputs = [] def forward(self, x): return self.unet_model(x) - def training_step(self, batch, batch_idx): - input_data, target = batch['source'], batch['target'] - output = self(input_data) - loss = DiceLoss()(output, target) - self.log('train_loss', loss) - return loss - - def validation_step(self, batch, batch_idx): - input_data, target = batch['source'], batch['target'] - output = self(input_data) - loss = DiceLoss()(output, target) - self.log('val_loss', loss) - def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer -# Create an instance of the LightningUNet class -unet_model = LightningUNet(in_channels=2, out_channels=1) + def training_step(self, batch: Sample, batch_idx: int): + + # Extract the input and target from the batch + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + + # Convert the target image to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert target to float type + # Calculate the loss + train_loss = self.loss_function(pred, target_one_hot) + # if batch_idx < self.log_batches_per_epoch: + # self.training_step_outputs.extend( + # self._detach_sample((source, target_one_hot, pred)) + # ) + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss + + def validation_step(self, batch: Sample, batch_idx: int): + + # Extract the input and target from the batch + source = batch["source"] + target = batch["target"] + pred = self.forward(source) + + # Convert the target image to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert target to float type + # Calculate the loss + loss = self.loss_function(pred, target_one_hot) + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) + # if batch_idx < self.log_batches_per_epoch: + # self.validation_step_outputs.extend( + # self._detach_sample((source, target, pred)) + # ) + return loss + -# Define the logger -logger = TensorBoardLogger("/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", name="infection_classification_model") +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + name="infection_classification_model", +) # Pass the logger to the Trainer -trainer = pl.Trainer(logger=logger, max_epochs=30, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1) +trainer = pl.Trainer( + logger=logger, + max_epochs=30, + default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + log_every_n_steps=1, +) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath='/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints', - filename='checkpoint_{epoch:02d}', + dirpath="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints", + filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, - monitor='val_loss', - mode='min' + monitor="loss/validate", + mode="min", ) # Add the checkpoint callback to the trainer trainer.callbacks.append(checkpoint_callback) # Fit the model -trainer.fit(unet_model, data_module) +model = LightningUNet( + in_channels=1, + out_channels=4, + loss_function=nn.CrossEntropyLoss(), +) +trainer.fit(model, data_module) + # %% test the model on the test set -test_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr' +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr" test_dm = HCSDataModule( - test_datapath, - source_channel=['Sensor','Nuclei_mask'], + test_datapath, + source_channel=["Sensor", "Nuclei_mask"], ) # Load the predict dataset test_dataloader = test_dm.test_dataloader() @@ -180,7 +217,7 @@ def configure_optimizers(self): # Iterate over the test batches for batch in test_dataloader: # Extract the input from the batch - input_data = batch['source'] + input_data = batch["source"] # Forward pass through the model output = unet_model(input_data) @@ -193,4 +230,4 @@ def configure_optimizers(self): # Save the predictions as added channel in zarr format # use iohub or viscy to save the predictions!!! -zarr.save('predictions.zarr', predictions) +zarr.save("predictions.zarr", predictions) From 5ecbde022611e06314190d26fb3bfe2fef2e6a37 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 12 Mar 2024 16:17:53 -0700 Subject: [PATCH 51/92] added sample image logging --- .../Infection_classification_model.py | 69 +++++++++++++++---- 1 file changed, 57 insertions(+), 12 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index b2230d9b..30af733f 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -7,10 +7,13 @@ import lightning.pytorch as pl import torch.nn.functional as F import torchview -from typing import Literal, Union +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap # import napari from pytorch_lightning.loggers import TensorBoardLogger +from torch import Tensor from monai.transforms import ( RandRotate, Resize, @@ -125,10 +128,10 @@ def training_step(self, batch: Sample, batch_idx: int): target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss train_loss = self.loss_function(pred, target_one_hot) - # if batch_idx < self.log_batches_per_epoch: - # self.training_step_outputs.extend( - # self._detach_sample((source, target_one_hot, pred)) - # ) + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) self.log( "loss/train", train_loss, @@ -154,13 +157,55 @@ def validation_step(self, batch: Sample, batch_idx: int): target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss loss = self.loss_function(pred, target_one_hot) - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False) - # if batch_idx < self.log_batches_per_epoch: - # self.validation_step_outputs.extend( - # self._detach_sample((source, target, pred)) - # ) + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target, pred)) + ) + self.log( + "loss/validate", + loss, + sync_dist=True, + add_dataloader_idx=False, + ) return loss + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) + return self._predict_pad.inverse(self.forward(source)) + + def on_train_epoch_end(self): + self._log_samples("train_samples", self.training_step_outputs) + self.training_step_outputs = [] + + def on_validation_epoch_end(self): + self._log_samples("val_samples", self.validation_step_outputs) + self.validation_step_outputs = [] + + def _detach_sample(self, imgs: Sequence[Tensor]): + num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] + for sample_images in imgs: + images_row = [] + for i, image in enumerate(sample_images): + cm_name = "gray" if i == 0 else "inferno" + if image.ndim == 2: + image = image[np.newaxis] + for channel in image: + channel = rescale_intensity(channel, out_range=(0, 1)) + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] + images_row.append(render) + images_grid.append(np.concatenate(images_row, axis=1)) + grid = np.concatenate(images_grid, axis=0) + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # %% Define the logger logger = TensorBoardLogger( @@ -171,7 +216,7 @@ def validation_step(self, batch: Sample, batch_idx: int): # Pass the logger to the Trainer trainer = pl.Trainer( logger=logger, - max_epochs=30, + max_epochs=50, default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", log_every_n_steps=1, ) @@ -193,7 +238,7 @@ def validation_step(self, batch: Sample, batch_idx: int): model = LightningUNet( in_channels=1, out_channels=4, - loss_function=nn.CrossEntropyLoss(), + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.4, 0.4, 0.1])), ) trainer.fit(model, data_module) From 58b7fa56fae4efec6975fea03a4e19d710cfbe7c Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 12 Mar 2024 19:38:01 -0700 Subject: [PATCH 52/92] attempt to build magicgui annotation --- .../Infection_annotator.py | 86 +++++++++++++++---- 1 file changed, 71 insertions(+), 15 deletions(-) diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotator.py index e933a773..5117ce91 100644 --- a/examples/infection_phenotyping/Infection_annotator.py +++ b/examples/infection_phenotyping/Infection_annotator.py @@ -1,55 +1,111 @@ +# %% Run this to display napari on the remote server while running the script in local IDE +import os - -#%% use napari to annotate infected cells in segmented data +os.environ["DISPLAY"] = ":1" +# %% use napari to annotate infected cells in segmented data import napari from iohub.ngff import open_ome_zarr import numpy as np +from pathlib import Path + +dataset_folder = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets" +) + +input_file = dataset_folder / "Exp_2023_09_28_DENV_A2.zarr" +output_file = ( + dataset_folder / "Exp_2023_09_28_DENV_A2_infMarked_test_annotation_pipeline.zarr" +) -file_in_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2.zarr' zarr_input = open_ome_zarr( - file_in_path, + input_file, layout="hcs", mode="r+", ) chan_names = zarr_input.channel_names # zarr_input.append_channel('Inf_mask',resize_arrays=True) -file_out_path = '/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/Exp_2023_09_28_DENV_A2_infMarked_rev2.zarr' zarr_output = open_ome_zarr( - file_out_path, + output_file, layout="hcs", - mode="w-", - channel_names=['Sensor','Nucl_mask','Inf_mask'], + mode="w", + channel_names=["Sensor", "Nucl_mask", "Inf_mask"], ) v = napari.Viewer() -#%% Load label image to napari +# %% Load label image to napari for well_id, well_data in zarr_input.wells(): well_name, well_no = well_id.split("/") - if well_name == 'A' and well_no == '2': + if well_name == "A" and well_no == "2": for pos_name, pos_data in well_data.positions(): # if int(pos_name) > 1: v.layers.clear() data = pos_data.data - FITC = data[0,0,...] - v.add_image(FITC, name='FITC', colormap='green', blending='additive') - Inf_mask = data[0,1,...].astype(int) + FITC = data[0, 0, ...] + v.add_image(FITC, name="FITC", colormap="green", blending="additive") + Inf_mask = data[0, 1, ...].astype(int) v.add_labels(Inf_mask) input("Press Enter") - label_layer = v.layers['Inf_mask'] + label_layer = v.layers["Inf_mask"] label_array = label_layer.data label_array = np.expand_dims(label_array, axis=(0, 1)) # zarr_input.create_image('Inf_mask',label_array) out_data = np.concatenate((data, label_array), axis=1) position = zarr_output.create_position(well_name, well_no, pos_name) position["0"] = out_data - + +# %% Template for magicgui based annotation workflow. +from magicgui import magicgui +from napari.types import ImageData + + +# Create an enumeration of all wells +wells = list(w[0] for w in zarr_input.wells()) +well_id, well_data = next(zarr_input.wells()) +positions = list(p[0] for p in well_data.positions()) +channel_names = zarr_input.channel_names + + +@magicgui( + call_button="load data", + wells={"choices", ["A/1", "A/2", "A/3", "A/4", "A/5"]}, + positions={"choices", ["0", "1", "2", "3", "4"]}, +) # defines the widget. +def load_well(well: str, position: str): # defines the callback. + # Load all data from specified well and position + for well_id, well_data in zarr_input.wells(): + if well_id == well: + for pos_name, pos_data in well_data.positions(): + if pos_name == position: + for i, ch in enumerate(channel_names): + data = pos_data.data + v.add_image( + data[0, i, ...], + name=ch, + colormap="gray", + blending="additive", + ) + break + break + + +@magicgui(call_button="save annotations") # defines the widget. +def save_annotations( + annotation_layer: ImageData, output_path: Path +): # defines the callback. + # Save the output to the specified path + print("save") + + +# Add both widgets to napari +v.window.add_dock_widget(load_well(wells, "0")) +v.window.add_dock_widget(save_annotations) # %% From 35ead0cc7aef5d001ad64f3ab06a004adee2df00 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 13 Mar 2024 16:03:51 -0700 Subject: [PATCH 53/92] renamed infection annotation tool --- .../{Infection_annotator.py => Infection_annotation_refiner.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/infection_phenotyping/{Infection_annotator.py => Infection_annotation_refiner.py} (100%) diff --git a/examples/infection_phenotyping/Infection_annotator.py b/examples/infection_phenotyping/Infection_annotation_refiner.py similarity index 100% rename from examples/infection_phenotyping/Infection_annotator.py rename to examples/infection_phenotyping/Infection_annotation_refiner.py From 802ebc33ed705c3942ec41475beeaa9e52e0eaa9 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Sat, 23 Mar 2024 10:06:22 -0700 Subject: [PATCH 54/92] added normalization and augmentations --- .../Infection_classification_model.py | 101 +++++++----------- 1 file changed, 37 insertions(+), 64 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index 30af733f..d9a045cc 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -6,7 +6,8 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F -import torchview + +# import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap @@ -14,43 +15,47 @@ # import napari from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor -from monai.transforms import ( - RandRotate, - Resize, - Zoom, - Flip, - RandFlip, - RandZoom, - RandRotate90, - RandRotate, - RandAffine, - Rand2DElastic, - Rand3DElastic, - RandGaussianNoise, - RandGaussianNoised, -) from pytorch_lightning.callbacks import ModelCheckpoint -from monai.losses import DiceLoss -from viscy.light.engine import VSUNet + +# from monai.losses import DiceLoss +# from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample +from viscy.transforms import RandWeightedCropd, RandGaussianNoised +from viscy.transforms import NormalizeSampled # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_28_DENV_A2_infMarked.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined.zarr" # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Sensor"], + source_channel=["Sensor", "Phase"], target_channel=["Inf_mask"], yx_patch_size=[128, 128], split_ratio=0.8, z_window_size=1, architecture="2D", num_workers=1, - batch_size=12, - augmentations=[], + batch_size=64, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=8, + spatial_size=[-1, 128, 128], + keys=["Sensor", "Phase", "Inf_mask"], + w_key="Inf_mask", + ), + RandGaussianNoised(keys=["Sensor", "Phase"], mean=0.0, std=1.0, prob=0.5), + ], ) # Prepare the data @@ -159,13 +164,14 @@ def validation_step(self, batch: Sample, batch_idx: int): loss = self.loss_function(pred, target_one_hot) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( - self._detach_sample((source, target, pred)) + self._detach_sample((source, target_one_hot, pred)) ) self.log( "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, + logger=True, ) return loss @@ -209,21 +215,21 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # %% Define the logger logger = TensorBoardLogger( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", - name="infection_classification_model", + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", + name="logs_wPhase", ) # Pass the logger to the Trainer trainer = pl.Trainer( logger=logger, - max_epochs=50, - default_root_dir="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/logs", + max_epochs=100, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", log_every_n_steps=1, ) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/Infection_phenotyping_data/checkpoints", + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/", filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, @@ -236,43 +242,10 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Fit the model model = LightningUNet( - in_channels=1, + in_channels=2, out_channels=4, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.4, 0.4, 0.1])), + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.3, 0.3, 0.3])), ) trainer.fit(model, data_module) - -# %% test the model on the test set -test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2023_12_08-BJ5a-calibration/5_classify/2023_12_08_BJ5a_pAL040_72HPI_Calibration_1.zarr" - -test_dm = HCSDataModule( - test_datapath, - source_channel=["Sensor", "Nuclei_mask"], -) -# Load the predict dataset -test_dataloader = test_dm.test_dataloader() - -# Set the model to evaluation mode -unet_model.eval() - -# Create a list to store the predictions -predictions = [] - -# Iterate over the test batches -for batch in test_dataloader: - # Extract the input from the batch - input_data = batch["source"] - - # Forward pass through the model - output = unet_model(input_data) - - # Append the predictions to the list - predictions.append(output.detach().cpu().numpy()) - -# Convert the predictions to a numpy array -predictions = np.stack(predictions) - -# Save the predictions as added channel in zarr format -# use iohub or viscy to save the predictions!!! -zarr.save("predictions.zarr", predictions) +# %% From 908039a82a53a7c5e4100d309a5c30f58ccbda7f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 16:09:05 -0700 Subject: [PATCH 55/92] added model testing code --- .../test_infection_classifier.py | 96 +++++++++++++++++++ 1 file changed, 96 insertions(+) create mode 100644 examples/infection_phenotyping/test_infection_classifier.py diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/examples/infection_phenotyping/test_infection_classifier.py new file mode 100644 index 00000000..3fed6319 --- /dev/null +++ b/examples/infection_phenotyping/test_infection_classifier.py @@ -0,0 +1,96 @@ +# %% +import numpy as np +from viscy.data.hcs import HCSDataModule +from viscy.transforms import NormalizeSampled +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample +import lightning.pytorch as pl +import torch + +from viscy.light.predict_writer import HCSPredictionWriter +from monai.transforms import DivisiblePad + +# %% test the model on the test set +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +data_module = HCSDataModule( + test_datapath, + source_channel=["Sensor", "Phase"], + target_channel=[], + split_ratio=0.8, + z_window_size=1, + architecture="2D", + num_workers=1, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +data_module.setup(stage="predict") +test_dm = data_module.test_dataloader() +sample = next(iter(test_dm)) + +# %% +class LightningUNet(pl.LightningModule): + def __init__( + self, + in_channels, + out_channels, + ckpt_path, + ): + super(LightningUNet, self).__init__() + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + + def forward(self, x): + return self.unet_model(x) + + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) + pred_class = self.forward(source) + pred_int = torch.argmax(pred_class, dim=4, keepdim=True) + return self._predict_pad.inverse(pred_int) + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + +# %% create trainer and input + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=True)], +) +model = LightningUNet( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", +) + +trainer.predict( + model=model, + datamodule=data_module, + return_predictions=True, +) + +# %% test the model on the test set and write to zarr store \ No newline at end of file From 88615d5f5f3006e9956e86cdb452600990c8f151 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 16:28:07 -0700 Subject: [PATCH 56/92] removed annotation refiner --- .../Infection_annotation_refiner.py | 111 ------------------ 1 file changed, 111 deletions(-) delete mode 100644 examples/infection_phenotyping/Infection_annotation_refiner.py diff --git a/examples/infection_phenotyping/Infection_annotation_refiner.py b/examples/infection_phenotyping/Infection_annotation_refiner.py deleted file mode 100644 index 5117ce91..00000000 --- a/examples/infection_phenotyping/Infection_annotation_refiner.py +++ /dev/null @@ -1,111 +0,0 @@ -# %% Run this to display napari on the remote server while running the script in local IDE -import os - -os.environ["DISPLAY"] = ":1" -# %% use napari to annotate infected cells in segmented data - -import napari -from iohub.ngff import open_ome_zarr -import numpy as np -from pathlib import Path - -dataset_folder = Path( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets" -) - -input_file = dataset_folder / "Exp_2023_09_28_DENV_A2.zarr" -output_file = ( - dataset_folder / "Exp_2023_09_28_DENV_A2_infMarked_test_annotation_pipeline.zarr" -) - -zarr_input = open_ome_zarr( - input_file, - layout="hcs", - mode="r+", -) -chan_names = zarr_input.channel_names -# zarr_input.append_channel('Inf_mask',resize_arrays=True) - -zarr_output = open_ome_zarr( - output_file, - layout="hcs", - mode="w", - channel_names=["Sensor", "Nucl_mask", "Inf_mask"], -) - -v = napari.Viewer() - - -# %% Load label image to napari -for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - if well_name == "A" and well_no == "2": - - for pos_name, pos_data in well_data.positions(): - # if int(pos_name) > 1: - v.layers.clear() - data = pos_data.data - - FITC = data[0, 0, ...] - v.add_image(FITC, name="FITC", colormap="green", blending="additive") - Inf_mask = data[0, 1, ...].astype(int) - v.add_labels(Inf_mask) - input("Press Enter") - - label_layer = v.layers["Inf_mask"] - label_array = label_layer.data - label_array = np.expand_dims(label_array, axis=(0, 1)) - # zarr_input.create_image('Inf_mask',label_array) - out_data = np.concatenate((data, label_array), axis=1) - position = zarr_output.create_position(well_name, well_no, pos_name) - position["0"] = out_data - - -# %% Template for magicgui based annotation workflow. -from magicgui import magicgui -from napari.types import ImageData - - -# Create an enumeration of all wells -wells = list(w[0] for w in zarr_input.wells()) -well_id, well_data = next(zarr_input.wells()) -positions = list(p[0] for p in well_data.positions()) -channel_names = zarr_input.channel_names - - -@magicgui( - call_button="load data", - wells={"choices", ["A/1", "A/2", "A/3", "A/4", "A/5"]}, - positions={"choices", ["0", "1", "2", "3", "4"]}, -) # defines the widget. -def load_well(well: str, position: str): # defines the callback. - # Load all data from specified well and position - for well_id, well_data in zarr_input.wells(): - if well_id == well: - for pos_name, pos_data in well_data.positions(): - if pos_name == position: - for i, ch in enumerate(channel_names): - data = pos_data.data - v.add_image( - data[0, i, ...], - name=ch, - colormap="gray", - blending="additive", - ) - break - break - - -@magicgui(call_button="save annotations") # defines the widget. -def save_annotations( - annotation_layer: ImageData, output_path: Path -): # defines the callback. - # Save the output to the specified path - print("save") - - -# Add both widgets to napari -v.window.add_dock_widget(load_well(wells, "0")) -v.window.add_dock_widget(save_annotations) -# %% From 82428ed4b988087cc386d25673af44c56534179f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 17:47:17 -0700 Subject: [PATCH 57/92] corrected conversion of class to int --- .../Infection_classification_model.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index d9a045cc..b6ef2b1f 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -6,10 +6,12 @@ import torch.nn as nn import lightning.pytorch as pl import torch.nn.functional as F +import torchmetrics # import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity +from sklearn.metrics import ConfusionMatrixDisplay from matplotlib.cm import get_cmap # import napari @@ -26,7 +28,7 @@ # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_09_27_DENV_A2_infMarked_refined.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" # Create an instance of HCSDataModule data_module = HCSDataModule( @@ -38,7 +40,7 @@ z_window_size=1, architecture="2D", num_workers=1, - batch_size=64, + batch_size=128, normalizations=[ NormalizeSampled( keys=["Sensor", "Phase"], @@ -53,8 +55,7 @@ spatial_size=[-1, 128, 128], keys=["Sensor", "Phase", "Inf_mask"], w_key="Inf_mask", - ), - RandGaussianNoised(keys=["Sensor", "Phase"], mean=0.0, std=1.0, prob=0.5), + ) ], ) @@ -111,6 +112,9 @@ def __init__( self.log_samples_per_batch = log_samples_per_batch self.training_step_outputs = [] self.validation_step_outputs = [] + self.val_cm = torchmetrics.classification.ConfusionMatrix( + task="multiclass", num_classes=self.n_classes + ) def forward(self, x): return self.unet_model(x) @@ -127,7 +131,7 @@ def training_step(self, batch: Sample, batch_idx: int): pred = self.forward(source) # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( 0, 4, 1, 2, 3 ) target_one_hot = target_one_hot.float() # Convert target to float type @@ -156,12 +160,13 @@ def validation_step(self, batch: Sample, batch_idx: int): pred = self.forward(source) # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=4).permute( + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( 0, 4, 1, 2, 3 ) target_one_hot = target_one_hot.float() # Convert target to float type # Calculate the loss loss = self.loss_function(pred, target_one_hot) + self.val_cm(target_one_hot, pred) if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend( self._detach_sample((source, target_one_hot, pred)) @@ -187,6 +192,16 @@ def on_validation_epoch_end(self): self._log_samples("val_samples", self.validation_step_outputs) self.validation_step_outputs = [] + # Log the confusion matrix at the end of the epoch + confusion_matrix = self.val_cm.compute().cpu().numpy() + + self.logger.experiment.add_figure( + "Validation Confusion Matrix", + ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), + self.current_epoch, + ) + self.val_cm.reset() + def _detach_sample(self, imgs: Sequence[Tensor]): num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) return [ @@ -243,8 +258,8 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Fit the model model = LightningUNet( in_channels=2, - out_channels=4, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.3, 0.3, 0.3])), + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), ) trainer.fit(model, data_module) From b470ed1fd87850145e66d8b9832b492a561c8b14 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 25 Mar 2024 17:55:36 -0700 Subject: [PATCH 58/92] corrected prediction module --- .../test_infection_classifier.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/examples/infection_phenotyping/test_infection_classifier.py index 3fed6319..d8918f82 100644 --- a/examples/infection_phenotyping/test_infection_classifier.py +++ b/examples/infection_phenotyping/test_infection_classifier.py @@ -6,7 +6,7 @@ from viscy.data.hcs import Sample import lightning.pytorch as pl import torch - +import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter from monai.transforms import DivisiblePad @@ -16,7 +16,7 @@ data_module = HCSDataModule( test_datapath, source_channel=["Sensor", "Phase"], - target_channel=[], + target_channel=["inf_mask"], split_ratio=0.8, z_window_size=1, architecture="2D", @@ -36,8 +36,6 @@ data_module.prepare_data() data_module.setup(stage="predict") -test_dm = data_module.test_dataloader() -sample = next(iter(test_dm)) # %% class LightningUNet(pl.LightningModule): @@ -49,6 +47,9 @@ def __init__( ): super(LightningUNet, self).__init__() self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + # self.pred_cm = torchmetrics.classification.ConfusionMatrix( + # task="multiclass", num_classes=self.n_classes + # ) if ckpt_path is not None: state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ "state_dict" @@ -62,8 +63,8 @@ def forward(self, x): def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) pred_class = self.forward(source) - pred_int = torch.argmax(pred_class, dim=4, keepdim=True) - return self._predict_pad.inverse(pred_int) + pred_int = torch.argmax(pred_class, dim=1, keepdim=True) + return pred_int def on_predict_start(self): """Pad the input shape to be divisible by the downsampling factor. @@ -79,7 +80,7 @@ def on_predict_start(self): trainer = pl.Trainer( default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=True)], + callbacks=[HCSPredictionWriter(output_path, write_input=False)], ) model = LightningUNet( in_channels=2, From f3746f89a271e6c907292bf7fdbdd47d2d85ea6a Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 06:35:49 -0700 Subject: [PATCH 59/92] cleaned up the code and comments for the LightningUNet --- .../Infection_classification_model.py | 196 ++++++++---------- 1 file changed, 85 insertions(+), 111 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index b6ef2b1f..f03595d7 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -25,7 +25,7 @@ from viscy.data.hcs import Sample from viscy.transforms import RandWeightedCropd, RandGaussianNoised from viscy.transforms import NormalizeSampled - +] # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" @@ -93,140 +93,114 @@ # Train the model # Create a TensorBoard logger class LightningUNet(pl.LightningModule): + # Initialize the class def __init__( self, - in_channels, - out_channels, - lr: float = 1e-3, - loss_function: nn.CrossEntropyLoss = None, - schedule: Literal["WarmupCosine", "Constant"] = "Constant", - log_batches_per_epoch: int = 2, - log_samples_per_batch: int = 1, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.CrossEntropyLoss = None, # Loss function + schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch ): - super(LightningUNet, self).__init__() + super(LightningUNet, self).__init__() # Call the superclass initializer + # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - self.lr = lr + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule - self.log_batches_per_epoch = log_batches_per_epoch - self.log_samples_per_batch = log_samples_per_batch - self.training_step_outputs = [] - self.validation_step_outputs = [] - self.val_cm = torchmetrics.classification.ConfusionMatrix( - task="multiclass", num_classes=self.n_classes - ) - + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = log_batches_per_epoch # Set the number of batches to log per epoch + self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = [] # Initialize the list of validation step outputs + # Initialize the confusion matrix for validation + self.val_cm = torchmetrics.classification.ConfusionMatrix(task="multiclass", num_classes=out_channels) + + # Define the forward pass def forward(self, x): - return self.unet_model(x) + return self.unet_model(x) # Pass the input through the UNet model + # Define the optimizer def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) # Use the Adam optimizer return optimizer + # Define the training step def training_step(self, batch: Sample, batch_idx: int): - - # Extract the input and target from the batch - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - - # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert target to float type - # Calculate the loss - train_loss = self.loss_function(pred, target_one_hot) + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss + self.training_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + # Log the training loss + self.log("loss/train", train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + return train_loss # Return the training loss def validation_step(self, batch: Sample, batch_idx: int): - - # Extract the input and target from the batch - source = batch["source"] - target = batch["target"] - pred = self.forward(source) - - # Convert the target image to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert target to float type - # Calculate the loss - loss = self.loss_function(pred, target_one_hot) - self.val_cm(target_one_hot, pred) + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + self.val_cm(target_one_hot, pred) # Update the confusion matrix + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - self.log( - "loss/validate", - loss, - sync_dist=True, - add_dataloader_idx=False, - logger=True, - ) - return loss + self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + # Log the validation loss + self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True) + return loss # Return the validation loss + # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) - return self._predict_pad.inverse(self.forward(source)) + source = self._predict_pad(batch["source"]) # Pad the source + return self._predict_pad.inverse(self.forward(source)) # Make a prediction and remove the padding + # Define what happens at the end of a training epoch def on_train_epoch_end(self): - self._log_samples("train_samples", self.training_step_outputs) - self.training_step_outputs = [] + self._log_samples("train_samples", self.training_step_outputs) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + # Define what happens at the end of a validation epoch def on_validation_epoch_end(self): - self._log_samples("val_samples", self.validation_step_outputs) - self.validation_step_outputs = [] - - # Log the confusion matrix at the end of the epoch - confusion_matrix = self.val_cm.compute().cpu().numpy() - - self.logger.experiment.add_figure( - "Validation Confusion Matrix", - ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), - self.current_epoch, - ) - self.val_cm.reset() - + self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + # Compute the confusion matrix + confusion_matrix = self.val_cm.compute().cpu().numpy() + # Log the confusion matrix + self.logger.experiment.add_figure("Validation Confusion Matrix", ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), self.current_epoch) + self.val_cm.reset() # Reset the confusion matrix + + # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): - num_samples = 2 # min(imgs[0].shape[0], self.log_samples_per_batch) - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] + # Detach the images and convert them to numpy arrays + return [[np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] for i in range(self.log_samples_per_batch)] + # Define a method to log samples def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] - for sample_images in imgs: - images_row = [] - for i, image in enumerate(sample_images): - cm_name = "gray" if i == 0 else "inferno" - if image.ndim == 2: - image = image[np.newaxis] - for channel in image: - channel = rescale_intensity(channel, out_range=(0, 1)) - render = get_cmap(cm_name)(channel, bytes=True)[..., :3] - images_row.append(render) - images_grid.append(np.concatenate(images_row, axis=1)) - grid = np.concatenate(images_grid, axis=0) - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate(sample_images): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity(channel, out_range=(0, 1)) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[..., :3] # Render the channel + images_row.append(render) # Append the render to the list of image rows + images_grid.append(np.concatenate(images_row, axis=1)) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image(key, grid, self.current_epoch, dataformats="HWC") # %% Define the logger logger = TensorBoardLogger( From 20655d6411a569be1f2f5e4bab3063d1895c5af9 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 07:20:22 -0700 Subject: [PATCH 60/92] removed confusion matrix code, finding runtime error with model --- .../Infection_classification_model.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/examples/infection_phenotyping/Infection_classification_model.py index f03595d7..5b541256 100644 --- a/examples/infection_phenotyping/Infection_classification_model.py +++ b/examples/infection_phenotyping/Infection_classification_model.py @@ -25,8 +25,9 @@ from viscy.data.hcs import Sample from viscy.transforms import RandWeightedCropd, RandGaussianNoised from viscy.transforms import NormalizeSampled -] + # %% Create a dataloader and visualize the batches. + # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" @@ -87,24 +88,24 @@ # # Start the napari event loop # napari.run() -# %% use 2D Unet and Lightning module +# %% +# Define a 2D UNet model for semantic segmentation as a lightning module. -# Train the model -# Create a TensorBoard logger -class LightningUNet(pl.LightningModule): - # Initialize the class + +class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. def __init__( self, in_channels: int, # Number of input channels out_channels: int, # Number of output channels lr: float = 1e-3, # Learning rate - loss_function: nn.CrossEntropyLoss = None, # Loss function + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule log_batches_per_epoch: int = 2, # Number of batches to log per epoch log_samples_per_batch: int = 2, # Number of samples to log per batch ): - super(LightningUNet, self).__init__() # Call the superclass initializer + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) self.lr = lr # Set the learning rate @@ -115,8 +116,7 @@ def __init__( self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch self.training_step_outputs = [] # Initialize the list of training step outputs self.validation_step_outputs = [] # Initialize the list of validation step outputs - # Initialize the confusion matrix for validation - self.val_cm = torchmetrics.classification.ConfusionMatrix(task="multiclass", num_classes=out_channels) + # Define the forward pass def forward(self, x): @@ -151,7 +151,6 @@ def validation_step(self, batch: Sample, batch_idx: int): target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) target_one_hot = target_one_hot.float() # Convert the target to float type loss = self.loss_function(pred, target_one_hot) # Calculate the loss - self.val_cm(target_one_hot, pred) # Update the confusion matrix # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) @@ -173,11 +172,7 @@ def on_train_epoch_end(self): def on_validation_epoch_end(self): self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs - # Compute the confusion matrix - confusion_matrix = self.val_cm.compute().cpu().numpy() - # Log the confusion matrix - self.logger.experiment.add_figure("Validation Confusion Matrix", ConfusionMatrixDisplay(confusion_matrix, self.index_to_label), self.current_epoch) - self.val_cm.reset() # Reset the confusion matrix + # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): @@ -230,11 +225,15 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): trainer.callbacks.append(checkpoint_callback) # Fit the model -model = LightningUNet( +model = SemanticSegUNet2D( in_channels=2, out_channels=3, loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), ) -trainer.fit(model, data_module) + +print(model) +# %% +# Run training. +# trainer.fit(model, data_module) # %% From d022dae9b61f9683964960fb2e72f4f5dff7f707 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 08:05:39 -0700 Subject: [PATCH 61/92] moved scripts to viscy.scripts.infection_phenotyping module to enable imports across scripts --- .../infection_phenotyping/Infection_classification_model.py | 0 {examples => viscy/scripts}/infection_phenotyping/readme.md | 0 .../scripts}/infection_phenotyping/test_infection_classifier.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename {examples => viscy/scripts}/infection_phenotyping/Infection_classification_model.py (100%) rename {examples => viscy/scripts}/infection_phenotyping/readme.md (100%) rename {examples => viscy/scripts}/infection_phenotyping/test_infection_classifier.py (100%) diff --git a/examples/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py similarity index 100% rename from examples/infection_phenotyping/Infection_classification_model.py rename to viscy/scripts/infection_phenotyping/Infection_classification_model.py diff --git a/examples/infection_phenotyping/readme.md b/viscy/scripts/infection_phenotyping/readme.md similarity index 100% rename from examples/infection_phenotyping/readme.md rename to viscy/scripts/infection_phenotyping/readme.md diff --git a/examples/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py similarity index 100% rename from examples/infection_phenotyping/test_infection_classifier.py rename to viscy/scripts/infection_phenotyping/test_infection_classifier.py From 901fd70c846542e155a33eaa24102287954d43be Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 10:03:26 -0700 Subject: [PATCH 62/92] combine the lightning modules for training and prediction, fix the DDP exception --- .../Infection_classification_model.py | 121 ++++++++++++++---- .../test_infection_classifier.py | 49 ++----- 2 files changed, 104 insertions(+), 66 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index 5b541256..4e442e0f 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -20,6 +20,8 @@ from pytorch_lightning.callbacks import ModelCheckpoint # from monai.losses import DiceLoss +from monai.transforms import DivisiblePad + # from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample @@ -88,7 +90,7 @@ # # Start the napari event loop # napari.run() -# %% +# %% # Define a 2D UNet model for semantic segmentation as a lightning module. @@ -101,9 +103,12 @@ def __init__( out_channels: int, # Number of output channels lr: float = 1e-3, # Learning rate loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal["WarmupCosine", "Constant"] = "Constant", # Learning rate schedule + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule log_batches_per_epoch: int = 2, # Number of batches to log per epoch log_samples_per_batch: int = 2, # Number of samples to log per batch + checkpoint_path: str = None, # Path to the checkpoint ): super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model @@ -112,11 +117,23 @@ def __init__( # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = log_batches_per_epoch # Set the number of batches to log per epoch - self.log_samples_per_batch = log_samples_per_batch # Set the number of samples to log per batch + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = [] # Initialize the list of validation step outputs - + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + if checkpoint_path is not None: + state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights # Define the forward pass def forward(self, x): @@ -124,7 +141,9 @@ def forward(self, x): # Define the optimizer def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) # Use the Adam optimizer + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer return optimizer # Define the training step @@ -133,14 +152,26 @@ def training_step(self, batch: Sample, batch_idx: int): target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) target_one_hot = target_one_hot.float() # Convert the target to float type train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss # Log the training step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) # Log the training loss - self.log("loss/train", train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) return train_loss # Return the training loss def validation_step(self, batch: Sample, batch_idx: int): @@ -148,54 +179,92 @@ def validation_step(self, batch: Sample, batch_idx: int): target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute(0, 4, 1, 2, 3) + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) target_one_hot = target_one_hot.float() # Convert the target to float type loss = self.loss_function(pred, target_one_hot) # Calculate the loss # Log the validation step outputs if the batch index is less than the number of batches to log per epoch if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend(self._detach_sample((source, target_one_hot, pred))) + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) # Log the validation loss - self.log("loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True) + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) return loss # Return the validation loss + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source - return self._predict_pad.inverse(self.forward(source)) # Make a prediction and remove the padding + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_map = F.softmax(logits, dim=1) # Calculate the probabilities + return prob_map # return the probabilities for computing metrics. # Define what happens at the end of a training epoch def on_train_epoch_end(self): - self._log_samples("train_samples", self.training_step_outputs) # Log the training samples + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples self.training_step_outputs = [] # Reset the list of training step outputs # Define what happens at the end of a validation epoch def on_validation_epoch_end(self): - self._log_samples("val_samples", self.validation_step_outputs) # Log the validation samples + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): # Detach the images and convert them to numpy arrays - return [[np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] for i in range(self.log_samples_per_batch)] + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] # Define a method to log samples def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): images_grid = [] # Initialize the list of image grids for sample_images in imgs: # For each sample image images_row = [] # Initialize the list of image rows - for i, image in enumerate(sample_images): # For each image in the sample images + for i, image in enumerate( + sample_images + ): # For each image in the sample images cm_name = "gray" if i == 0 else "inferno" # Set the colormap name if image.ndim == 2: # If the image is 2D image = image[np.newaxis] # Add a new axis for channel in image: # For each channel in the image - channel = rescale_intensity(channel, out_range=(0, 1)) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[..., :3] # Render the channel - images_row.append(render) # Append the render to the list of image rows - images_grid.append(np.concatenate(images_row, axis=1)) # Append the concatenated image rows to the list of image grids + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids # Log the image grid - self.logger.experiment.add_image(key, grid, self.current_epoch, dataformats="HWC") + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + # %% Define the logger logger = TensorBoardLogger( @@ -209,6 +278,7 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): max_epochs=100, default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) # Define the checkpoint callback @@ -232,8 +302,9 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): ) print(model) -# %% +# %% # Run training. -# trainer.fit(model, data_module) + +trainer.fit(model, data_module) # %% diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index d8918f82..0780b33e 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -8,7 +8,9 @@ import torch import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter -from monai.transforms import DivisiblePad +from viscy.scripts.infection_phenotyping.Infection_classification_model import ( + SemanticSegUNet2D, +) # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" @@ -37,52 +39,17 @@ data_module.setup(stage="predict") -# %% -class LightningUNet(pl.LightningModule): - def __init__( - self, - in_channels, - out_channels, - ckpt_path, - ): - super(LightningUNet, self).__init__() - self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - # self.pred_cm = torchmetrics.classification.ConfusionMatrix( - # task="multiclass", num_classes=self.n_classes - # ) - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - - def forward(self, x): - return self.unet_model(x) - - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) - pred_class = self.forward(source) - pred_int = torch.argmax(pred_class, dim=1, keepdim=True) - return pred_int - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # %% create trainer and input -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SM.zarr" trainer = pl.Trainer( default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) -model = LightningUNet( + +model = SemanticSegUNet2D( in_channels=2, out_channels=3, ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", @@ -94,4 +61,4 @@ def on_predict_start(self): return_predictions=True, ) -# %% test the model on the test set and write to zarr store \ No newline at end of file +# %% test the model on the test set and write to zarr store From 708a67ab3990dc31607b412cded0598dc47a4683 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 26 Mar 2024 12:34:43 -0700 Subject: [PATCH 63/92] all the stubs for computing and logging confusion matrix per cell --- .../Infection_classification_model.py | 128 +++++++++++++++++- .../test_infection_classifier.py | 2 - 2 files changed, 125 insertions(+), 5 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index 4e442e0f..ee60fd05 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -21,6 +21,7 @@ # from monai.losses import DiceLoss from monai.transforms import DivisiblePad +from skimage.measure import regionprops # from viscy.light.engine import VSUNet from viscy.unet.networks.Unet2D import Unet2d @@ -128,6 +129,9 @@ def __init__( [] ) # Initialize the list of validation step outputs + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + if checkpoint_path is not None: state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ "state_dict" @@ -205,11 +209,29 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source + target = batch["target"] # Extract the target from the batch logits = self._predict_pad.inverse( self.forward(source) ) # Predict and remove padding. - prob_map = F.softmax(logits, dim=1) # Calculate the probabilities - return prob_map # return the probabilities for computing metrics. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + labels_target = torch.argmax(target, dim=1) # Calculate the target labels + # FIXME: Check if compliant with lightning API + self.pred_cm = confusion_matrix_per_cell( + labels_target, labels_pred, num_classes=3 + ) + + return prob_pred # log the probabilities instead of logits. + + # Accumulate the confusion matrix at the end of prediction epoch and log. + def on_predict_epoch_end(self): + confusion_matrix = self.pred_cm.compute().cpu().numpy() + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + self.current_epoch, + ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): @@ -307,4 +329,104 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): trainer.fit(model, data_module) -# %% +# %% Methods to compute confusion matrix per cell using torchmetrics + + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + # Compute the confusion matrix per cell + confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( + nuclei_true(nuclei_true > 0), # indexing just non-background pixels. + nuclei_pred(nuclei_true > 0), + num_classes=num_classes, + task="multi_class", + ) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def image_class_to_nuclei_class( + y_true: torch.Tonser, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + nuclei_true = torch.zeros_like(y_true) + nuclie_pred = torch.zeros_like(y_pred) + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + for i in range(batch_size): + regions = regionprops(y_true[i].cpu().numpy()) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + centroid = region.centroid + pixel_ids = region.coords + # Find the class of the nuclei in the ground truth and prediction. + pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_true = np.mode(pix_labels_true[:]) + + pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_pred = np.mode(pix_labels_pred[:]) + nuclei_true[i, centroid[0], centroid[1]] = consensus_class_true + nuclei_pred[i, centroid[0], centroid[1]] = consensus_class_pred + + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + + return nuclei_true, nuclei_pred + + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + return fig diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index 0780b33e..14bff35c 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -60,5 +60,3 @@ datamodule=data_module, return_predictions=True, ) - -# %% test the model on the test set and write to zarr store From 6bb9ca38fe981ba4fe4036ec4fdac7bf80ebf282 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 1 Apr 2024 15:02:05 -0700 Subject: [PATCH 64/92] separated training and test scripts --- .../Infection_classification_model.py | 324 +----------------- .../test_infection_classifier.py | 52 ++- 2 files changed, 29 insertions(+), 347 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index ee60fd05..d8056a04 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -1,33 +1,15 @@ # %% import torch -from viscy.data.hcs import HCSDataModule - -import numpy as np -import torch.nn as nn import lightning.pytorch as pl -import torch.nn.functional as F -import torchmetrics - -# import torchview -from typing import Literal, Sequence -from skimage.exposure import rescale_intensity -from sklearn.metrics import ConfusionMatrixDisplay -from matplotlib.cm import get_cmap +import torch.nn as nn -# import napari from pytorch_lightning.loggers import TensorBoardLogger -from torch import Tensor from pytorch_lightning.callbacks import ModelCheckpoint -# from monai.losses import DiceLoss -from monai.transforms import DivisiblePad -from skimage.measure import regionprops - -# from viscy.light.engine import VSUNet -from viscy.unet.networks.Unet2D import Unet2d -from viscy.data.hcs import Sample -from viscy.transforms import RandWeightedCropd, RandGaussianNoised +from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D # %% Create a dataloader and visualize the batches. @@ -91,202 +73,6 @@ # # Start the napari event loop # napari.run() -# %% - -# Define a 2D UNet model for semantic segmentation as a lightning module. - - -class SemanticSegUNet2D(pl.LightningModule): - # Model for semantic segmentation. - def __init__( - self, - in_channels: int, # Number of input channels - out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate - loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function - schedule: Literal[ - "WarmupCosine", "Constant" - ] = "Constant", # Learning rate schedule - log_batches_per_epoch: int = 2, # Number of batches to log per epoch - log_samples_per_batch: int = 2, # Number of samples to log per batch - checkpoint_path: str = None, # Path to the checkpoint - ): - super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer - # Initialize the UNet model - self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) - self.lr = lr # Set the learning rate - # Set the loss function to CrossEntropyLoss if none is provided - self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() - self.schedule = schedule # Set the learning rate schedule - self.log_batches_per_epoch = ( - log_batches_per_epoch # Set the number of batches to log per epoch - ) - self.log_samples_per_batch = ( - log_samples_per_batch # Set the number of samples to log per batch - ) - self.training_step_outputs = [] # Initialize the list of training step outputs - self.validation_step_outputs = ( - [] - ) # Initialize the list of validation step outputs - - self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Background", "Infected", "Uninfected"] - - if checkpoint_path is not None: - state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights - - # Define the forward pass - def forward(self, x): - return self.unet_model(x) # Pass the input through the UNet model - - # Define the optimizer - def configure_optimizers(self): - optimizer = torch.optim.Adam( - self.parameters(), lr=self.lr - ) # Use the Adam optimizer - return optimizer - - # Define the training step - def training_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the training step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.training_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the training loss - self.log( - "loss/train", - train_loss, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=True, - ) - return train_loss # Return the training loss - - def validation_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - pred = self.forward(source) # Make a prediction using the source - # Convert the target to one-hot encoding - target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( - 0, 4, 1, 2, 3 - ) - target_one_hot = target_one_hot.float() # Convert the target to float type - loss = self.loss_function(pred, target_one_hot) # Calculate the loss - # Log the validation step outputs if the batch index is less than the number of batches to log per epoch - if batch_idx < self.log_batches_per_epoch: - self.validation_step_outputs.extend( - self._detach_sample((source, target_one_hot, pred)) - ) - # Log the validation loss - self.log( - "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True - ) - return loss # Return the validation loss - - def on_predict_start(self): - """Pad the input shape to be divisible by the downsampling factor. - The inverse of this transform crops the prediction to original shape. - """ - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - - # Define the prediction step - def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - source = self._predict_pad(batch["source"]) # Pad the source - target = batch["target"] # Extract the target from the batch - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - labels_target = torch.argmax(target, dim=1) # Calculate the target labels - # FIXME: Check if compliant with lightning API - self.pred_cm = confusion_matrix_per_cell( - labels_target, labels_pred, num_classes=3 - ) - - return prob_pred # log the probabilities instead of logits. - - # Accumulate the confusion matrix at the end of prediction epoch and log. - def on_predict_epoch_end(self): - confusion_matrix = self.pred_cm.compute().cpu().numpy() - self.logger.experiment.add_figure( - "Confusion Matrix per Cell", - plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), - self.current_epoch, - ) - - # Define what happens at the end of a training epoch - def on_train_epoch_end(self): - self._log_samples( - "train_samples", self.training_step_outputs - ) # Log the training samples - self.training_step_outputs = [] # Reset the list of training step outputs - - # Define what happens at the end of a validation epoch - def on_validation_epoch_end(self): - self._log_samples( - "val_samples", self.validation_step_outputs - ) # Log the validation samples - self.validation_step_outputs = [] # Reset the list of validation step outputs - # TODO: Log the confusion matrix - - # Define a method to detach a sample - def _detach_sample(self, imgs: Sequence[Tensor]): - # Detach the images and convert them to numpy arrays - num_samples = 3 - return [ - [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] - for i in range(num_samples) - ] - - # Define a method to log samples - def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): - images_grid = [] # Initialize the list of image grids - for sample_images in imgs: # For each sample image - images_row = [] # Initialize the list of image rows - for i, image in enumerate( - sample_images - ): # For each image in the sample images - cm_name = "gray" if i == 0 else "inferno" # Set the colormap name - if image.ndim == 2: # If the image is 2D - image = image[np.newaxis] # Add a new axis - for channel in image: # For each channel in the image - channel = rescale_intensity( - channel, out_range=(0, 1) - ) # Rescale the intensity of the channel - render = get_cmap(cm_name)(channel, bytes=True)[ - ..., :3 - ] # Render the channel - images_row.append( - render - ) # Append the render to the list of image rows - images_grid.append( - np.concatenate(images_row, axis=1) - ) # Append the concatenated image rows to the list of image grids - grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids - # Log the image grid - self.logger.experiment.add_image( - key, grid, self.current_epoch, dataformats="HWC" - ) - # %% Define the logger logger = TensorBoardLogger( @@ -328,105 +114,3 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Run training. trainer.fit(model, data_module) - -# %% Methods to compute confusion matrix per cell using torchmetrics - - -# The confusion matrix at the single-cell resolution. -def confusion_matrix_per_cell( - y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int -): - """Compute confusion matrix per cell. - - Args: - y_true (torch.Tensor): Ground truth label image (BXHXW). - y_pred (torch.Tensor): Predicted label image (BXHXW). - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Confusion matrix per cell (BXCXC). - """ - # Convert the image class to the nuclei class - nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) - # Compute the confusion matrix per cell - confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( - nuclei_true(nuclei_true > 0), # indexing just non-background pixels. - nuclei_pred(nuclei_true > 0), - num_classes=num_classes, - task="multi_class", - ) - return confusion_matrix_per_cell - - -# These images can be logged with prediction. -def image_class_to_nuclei_class( - y_true: torch.Tonser, y_pred: torch.Tensor, num_classes: int -): - """Convert the class of the image to the class of the nuclei. - - Args: - label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. - num_classes (int): Number of classes. - - Returns: - torch.Tensor: Label images with a consensus class at the centroid of nuclei. - """ - nuclei_true = torch.zeros_like(y_true) - nuclie_pred = torch.zeros_like(y_pred) - batch_size = y_true.size(0) - # find centroids of nuclei from y_true - for i in range(batch_size): - regions = regionprops(y_true[i].cpu().numpy()) - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - centroid = region.centroid - pixel_ids = region.coords - # Find the class of the nuclei in the ground truth and prediction. - pix_labels_true = y_true[i, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_true = np.mode(pix_labels_true[:]) - - pix_labels_pred = y_pred[i, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_pred = np.mode(pix_labels_pred[:]) - nuclei_true[i, centroid[0], centroid[1]] = consensus_class_true - nuclei_pred[i, centroid[0], centroid[1]] = consensus_class_pred - - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - - return nuclei_true, nuclei_pred - - -def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - return fig diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index 14bff35c..c12552ca 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -1,51 +1,37 @@ # %% -import numpy as np from viscy.data.hcs import HCSDataModule -from viscy.transforms import NormalizeSampled -from viscy.unet.networks.Unet2D import Unet2d -from viscy.data.hcs import Sample import lightning.pytorch as pl -import torch -import torchmetrics from viscy.light.predict_writer import HCSPredictionWriter -from viscy.scripts.infection_phenotyping.Infection_classification_model import ( - SemanticSegUNet2D, -) +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from pytorch_lightning.loggers import TensorBoardLogger # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" data_module = HCSDataModule( - test_datapath, - source_channel=["Sensor", "Phase"], - target_channel=["inf_mask"], + data_path=test_datapath, + source_channel=['Sensor','Phase'], + target_channel=['Inf_mask'], split_ratio=0.8, z_window_size=1, architecture="2D", - num_workers=1, + num_workers=0, batch_size=1, - normalizations=[ - NormalizeSampled( - keys=["Sensor", "Phase"], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], ) -# Prepare the data -data_module.prepare_data() - -data_module.setup(stage="predict") +data_module.setup(stage="test") # %% create trainer and input -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SM.zarr" +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", + name="logs_wPhase", +) trainer = pl.Trainer( + logger=logger, default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=False)], + log_every_n_steps=1, devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) @@ -55,6 +41,18 @@ ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) +trainer.test(model=model, datamodule=data_module) + +# %% predict the test set + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + trainer.predict( model=model, datamodule=data_module, From 99a387645882e7b66d439d649a35316736856c6c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 1 Apr 2024 15:14:16 -0700 Subject: [PATCH 65/92] lightning module --- .../classify_infection.py | 332 ++++++++++++++++++ 1 file changed, 332 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/classify_infection.py diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py new file mode 100644 index 00000000..10f26340 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -0,0 +1,332 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import torchmetrics +from statistics import mode + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet2D import Unet2d +from viscy.data.hcs import Sample + + +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + # Compute the confusion matrix per cell + confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( + nuclei_true[nuclei_true > 0], # indexing just non-background pixels. + nuclei_pred[nuclei_true > 0], + num_classes=num_classes, + task="multiclass", # Fix: Change "multi_class" to "multiclass" + ) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def image_class_to_nuclei_class( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) + nuclei_pred = torch.zeros_like(y_pred[:, 0, : , :]) + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + print(y_true_reshaped.shape) + regions = regionprops(y_true_reshaped.astype(int)) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + centroid = region.centroid + pixel_ids = region.coords + # Find the class of the nuclei in the ground truth and prediction. + pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_true = mode(pix_labels_true[:]) + + pix_labels_pred = y_pred[i, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + consensus_class_pred = mode(pix_labels_pred[:]) + nuclei_true[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) + nuclei_pred[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + + return nuclei_true, nuclei_pred + +# Define a 2d unet model for infection classification as a lightning module. + +class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def test_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + self.pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=3 + ) + + return self.pred_cm + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels + + return labels_pred # log the class predicted image + + # Accumulate the confusion matrix at the end of prediction epoch and log. + def on_test_epoch_end(self): + confusion_matrix = self.pred_cm.compute().cpu().numpy() + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + # TODO: Log the confusion matrix + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + plt.show(fig) # Show the figure + return fig \ No newline at end of file From 000a966bec2df456f2eeb42bc87ec017ec39669d Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 2 Apr 2024 16:38:55 -0700 Subject: [PATCH 66/92] corrected test cm compute --- .../Infection_classification_model.py | 2 + .../classify_infection.py | 163 ++++++++++-------- .../test_infection_classifier.py | 11 +- 3 files changed, 105 insertions(+), 71 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_model.py index d8056a04..37f606cb 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_model.py @@ -114,3 +114,5 @@ # Run training. trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 10f26340..51cac985 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -6,6 +6,7 @@ from torch import Tensor import torchmetrics from statistics import mode +# import napari # import torchview from typing import Literal, Sequence @@ -43,7 +44,7 @@ def confusion_matrix_per_cell( nuclei_true[nuclei_true > 0], # indexing just non-background pixels. nuclei_pred[nuclei_true > 0], num_classes=num_classes, - task="multiclass", # Fix: Change "multi_class" to "multiclass" + task="multiclass", ) return confusion_matrix_per_cell @@ -62,7 +63,8 @@ def image_class_to_nuclei_class( torch.Tensor: Label images with a consensus class at the centroid of nuclei. """ nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) - nuclei_pred = torch.zeros_like(y_pred[:, 0, : , :]) + nuclei_pred = torch.zeros_like(y_pred[:, 0, 0, :, :]) + batch_size = y_true.size(0) # find centroids of nuclei from y_true for i in range(batch_size): @@ -78,16 +80,51 @@ def image_class_to_nuclei_class( pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] consensus_class_true = mode(pix_labels_true[:]) - pix_labels_pred = y_pred[i, 0, pixel_ids[:, 0], pixel_ids[:, 1]] + pix_labels_pred = y_pred[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] consensus_class_pred = mode(pix_labels_pred[:]) - nuclei_true[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) - nuclei_pred[i, pixel_ids[0], pixel_ids[1]] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + nuclei_true[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) + nuclei_pred[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return nuclei_true, nuclei_pred +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig # Define a 2d unet model for infection classification as a lightning module. class SemanticSegUNet2D(pl.LightningModule): @@ -108,6 +145,12 @@ def __init__( super(SemanticSegUNet2D, self).__init__() # Call the superclass initializer # Initialize the UNet model self.unet_model = Unet2d(in_channels=in_channels, out_channels=out_channels) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights self.lr = lr # Set the learning rate # Set the loss function to CrossEntropyLoss if none is provided self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() @@ -126,12 +169,7 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Background", "Infected", "Uninfected"] - if ckpt_path is not None: - state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ - "state_dict" - ] - state_dict.pop("loss_function.weight", None) # Remove the unexpected key - self.load_state_dict(state_dict) # loading only weights + # Define the forward pass def forward(self, x): @@ -193,23 +231,6 @@ def validation_step(self, batch: Sample, batch_idx: int): ) return loss # Return the validation loss - def test_step(self, batch: Sample, batch_idx: int): - source = batch["source"] # Extract the source from the batch - target = batch["target"] # Extract the target from the batch - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. - prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - self.pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=3 - ) - - return self.pred_cm - def on_predict_start(self): """Pad the input shape to be divisible by the downsampling factor. The inverse of this transform crops the prediction to original shape. @@ -230,16 +251,55 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels return labels_pred # log the class predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((3, 3)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self.forward(source) + # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels + + # pred_img = logits.detach().cpu().numpy() + # v = napari.Viewer() + # v.add_image(pred_img) + # napari.run() - # Accumulate the confusion matrix at the end of prediction epoch and log. - def on_test_epoch_end(self): - confusion_matrix = self.pred_cm.compute().cpu().numpy() + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=3 + ) # Calculate the confusion matrix per cell + + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm self.logger.experiment.add_figure( "Confusion Matrix per Cell", - plot_confusion_matrix(confusion_matrix, self.index_to_label_dict), + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), self.current_epoch, ) + # def on_test_batch_end(self): + # # confusion_matrix_sum = torch.zeros((3, 3)) # Initialize the sum of confusion matrices + # # for pred_cm in self.pred_cm: # For each confusion matrix + # # confusion_matrix_sum += pred_cm # Accumulate the sum + # # confusion_matrix_sum = confusion_matrix_sum.cpu().numpy() # Convert to numpy array + # confusion_matrix_sum = torch.sum(torch.stack([tensor.cpu() for tensor in self.pred_cm], dim=0), dim=0) + # self.logger.experiment.add_figure( + # "Confusion Matrix batch-wise", + # plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + # self.current_epoch, + # ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): @@ -254,7 +314,6 @@ def on_validation_epoch_end(self): "val_samples", self.validation_step_outputs ) # Log the validation samples self.validation_step_outputs = [] # Reset the list of validation step outputs - # TODO: Log the confusion matrix # Define a method to detach a sample def _detach_sample(self, imgs: Sequence[Tensor]): @@ -293,40 +352,4 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Log the image grid self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" - ) - - def plot_confusion_matrix(confusion_matrix, index_to_label_dict): - # Create a figure and axis to plot the confusion matrix - fig, ax = plt.subplots() - - # Create a color heatmap for the confusion matrix - cax = ax.matshow(confusion_matrix, cmap="viridis") - - # Create a colorbar and set the label - fig.colorbar(cax, label="Frequency") - - # Set labels for the classes - - ax.set_xticks(np.arange(len(index_to_label_dict))) - ax.set_yticks(np.arange(len(index_to_label_dict))) - ax.set_xticklabels(index_to_label_dict.values(), rotation=45) - ax.set_yticklabels(index_to_label_dict.values()) - - # Set labels for the axes - ax.set_xlabel("Predicted") - ax.set_ylabel("True") - - # Add text annotations to the confusion matrix - for i in range(len(index_to_label_dict)): - for j in range(len(index_to_label_dict)): - ax.text( - j, - i, - str(int(confusion_matrix[i, j])), - ha="center", - va="center", - color="white", - ) - - plt.show(fig) # Show the figure - return fig \ No newline at end of file + ) \ No newline at end of file diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index c12552ca..fadd3e75 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -4,6 +4,7 @@ from viscy.light.predict_writer import HCSPredictionWriter from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D from pytorch_lightning.loggers import TensorBoardLogger +from viscy.transforms import NormalizeSampled # %% test the model on the test set test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" @@ -17,6 +18,14 @@ architecture="2D", num_workers=0, batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], ) data_module.setup(stage="test") @@ -38,7 +47,7 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", ) trainer.test(model=model, datamodule=data_module) From 688336e0d80953a1598cc7b1cb738a93b9817efd Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 3 Apr 2024 11:16:31 -0700 Subject: [PATCH 67/92] corrected test module --- .../classify_infection.py | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 51cac985..497efe96 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -4,9 +4,10 @@ import lightning.pytorch as pl import torch.nn.functional as F from torch import Tensor -import torchmetrics +# from torchmetrics.functional import confusion_matrix from statistics import mode # import napari +from sklearn.metrics import ConfusionMatrixDisplay # import torchview from typing import Literal, Sequence @@ -39,13 +40,17 @@ def confusion_matrix_per_cell( """ # Convert the image class to the nuclei class nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) + + nuclei_true_np = nuclei_true.cpu().numpy() + nuclei_pred_np = nuclei_pred.cpu().numpy() + # Compute the confusion matrix per cell - confusion_matrix_per_cell = torchmetrics.functional.confusion_matrix( - nuclei_true[nuclei_true > 0], # indexing just non-background pixels. - nuclei_pred[nuclei_true > 0], - num_classes=num_classes, - task="multiclass", + confusion_matrix_per_cell = ConfusionMatrixDisplay.from_predictions( + nuclei_true_np[nuclei_true_np > 0], # indexing just non-background pixels. + nuclei_pred_np[nuclei_true_np > 0], + labels=range(num_classes), ) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) return confusion_matrix_per_cell @@ -70,7 +75,6 @@ def image_class_to_nuclei_class( for i in range(batch_size): y_true_cpu = y_true[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - print(y_true_reshaped.shape) regions = regionprops(y_true_reshaped.astype(int)) # Find centroids, pixel coordinates from the ground truth. for region in regions: @@ -263,16 +267,10 @@ def test_step(self, batch: Sample): # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels - # pred_img = logits.detach().cpu().numpy() - # v = napari.Viewer() - # v.add_image(pred_img) - # napari.run() - target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=3 ) # Calculate the confusion matrix per cell - self.pred_cm += pred_cm # Append the confusion matrix to pred_cm self.logger.experiment.add_figure( @@ -289,17 +287,6 @@ def on_test_end(self): plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), self.current_epoch, ) - # def on_test_batch_end(self): - # # confusion_matrix_sum = torch.zeros((3, 3)) # Initialize the sum of confusion matrices - # # for pred_cm in self.pred_cm: # For each confusion matrix - # # confusion_matrix_sum += pred_cm # Accumulate the sum - # # confusion_matrix_sum = confusion_matrix_sum.cpu().numpy() # Convert to numpy array - # confusion_matrix_sum = torch.sum(torch.stack([tensor.cpu() for tensor in self.pred_cm], dim=0), dim=0) - # self.logger.experiment.add_figure( - # "Confusion Matrix batch-wise", - # plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), - # self.current_epoch, - # ) # Define what happens at the end of a training epoch def on_train_epoch_end(self): From 6b58f34dd951ad7ddccde50b8e3a6ba0b4bb31e9 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 3 Apr 2024 11:18:09 -0700 Subject: [PATCH 68/92] separated test and prediction scripts --- .../predict_infection_classifier.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/predict_infection_classifier.py diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py new file mode 100644 index 00000000..14600325 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py @@ -0,0 +1,54 @@ + + +from viscy.light.predict_writer import HCSPredictionWriter +from viscy.data.hcs import HCSDataModule +import lightning.pytorch as pl +from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.transforms import NormalizeSampled + +# %% # %% write the predictions to a zarr file + +pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +data_module = HCSDataModule( + data_path=pred_datapath, + source_channel=['Sensor','Phase'], + target_channel=['Inf_mask'], + split_ratio=0.8, + z_window_size=1, + architecture="2D", + num_workers=0, + batch_size=1, + normalizations=[ + NormalizeSampled( + keys=["Sensor", "Phase"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) + +data_module.setup(stage="predict") + +model = SemanticSegUNet2D( + in_channels=2, + out_channels=3, + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", +) + +# %% perform prediction + +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" + +trainer = pl.Trainer( + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + callbacks=[HCSPredictionWriter(output_path, write_input=False)], + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +trainer.predict( + model=model, + datamodule=data_module, + return_predictions=True, +) From b6ad254bf5bcb7e6fe78b6cd4d4c38d075c4a0b2 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 4 Apr 2024 21:56:44 -0700 Subject: [PATCH 69/92] changed confusion matrix compute --- .../classify_infection.py | 88 ++++++++--------- .../predict_infection_classifier.py | 12 ++- .../test_infection_classifier.py | 96 ++++++++++++++++--- 3 files changed, 129 insertions(+), 67 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection.py index 497efe96..ee904434 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection.py @@ -4,16 +4,13 @@ import lightning.pytorch as pl import torch.nn.functional as F from torch import Tensor -# from torchmetrics.functional import confusion_matrix -from statistics import mode -# import napari -from sklearn.metrics import ConfusionMatrixDisplay +import cv2 # import torchview from typing import Literal, Sequence from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap -from skimage.measure import regionprops +from skimage.measure import regionprops, label import numpy as np import matplotlib.pyplot as plt @@ -21,7 +18,7 @@ from viscy.unet.networks.Unet2D import Unet2d from viscy.data.hcs import Sample - +# # %% Methods to compute confusion matrix per cell using torchmetrics # The confusion matrix at the single-cell resolution. @@ -39,23 +36,13 @@ def confusion_matrix_per_cell( torch.Tensor: Confusion matrix per cell (BXCXC). """ # Convert the image class to the nuclei class - nuclei_true, nuclei_pred = image_class_to_nuclei_class(y_true, y_pred, num_classes) - - nuclei_true_np = nuclei_true.cpu().numpy() - nuclei_pred_np = nuclei_pred.cpu().numpy() - - # Compute the confusion matrix per cell - confusion_matrix_per_cell = ConfusionMatrixDisplay.from_predictions( - nuclei_true_np[nuclei_true_np > 0], # indexing just non-background pixels. - nuclei_pred_np[nuclei_true_np > 0], - labels=range(num_classes), - ) + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) return confusion_matrix_per_cell # These images can be logged with prediction. -def image_class_to_nuclei_class( +def compute_confusion_matrix( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int ): """Convert the class of the image to the class of the nuclei. @@ -67,32 +54,40 @@ def image_class_to_nuclei_class( Returns: torch.Tensor: Label images with a consensus class at the centroid of nuclei. """ - nuclei_true = torch.zeros_like(y_true[:, 0, 0, :, :]) - nuclei_pred = torch.zeros_like(y_pred[:, 0, 0, :, :]) batch_size = y_true.size(0) # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) for i in range(batch_size): y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) - regions = regionprops(y_true_reshaped.astype(int)) - # Find centroids, pixel coordinates from the ground truth. - for region in regions: - centroid = region.centroid - pixel_ids = region.coords - # Find the class of the nuclei in the ground truth and prediction. - pix_labels_true = y_true[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_true = mode(pix_labels_true[:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) - pix_labels_pred = y_pred[i, 0, 0, pixel_ids[:, 0], pixel_ids[:, 1]] - consensus_class_pred = mode(pix_labels_pred[:]) - nuclei_true[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_true]).to(y_true.dtype) - nuclei_pred[i, int(centroid[0]), int(centroid[1])] = torch.FloatTensor([consensus_class_pred]).to(y_pred.dtype) + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. - - return nuclei_true, nuclei_pred + return conf_mat def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix @@ -171,7 +166,7 @@ def __init__( ) # Initialize the list of validation step outputs self.pred_cm = None # Initialize the confusion matrix - self.index_to_label_dict = ["Background", "Infected", "Uninfected"] + self.index_to_label_dict = ["Infected", "Uninfected"] @@ -244,32 +239,28 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): - down_factor = 2**self.unet_model.num_blocks - self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse( - self.forward(source) - ) # Predict and remove padding. + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1) # Calculate the predicted labels - + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + return labels_pred # log the class predicted image def on_test_start(self): - self.pred_cm = torch.zeros((3, 3)) + self.pred_cm = torch.zeros((2,2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source - logits = self.forward(source) - # prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(logits, dim=1, keepdim=True) # Calculate the predicted labels + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( - target, labels_pred, num_classes=3 + target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm @@ -339,4 +330,5 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): # Log the image grid self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" - ) \ No newline at end of file + ) +# %% diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py index 14600325..783c1334 100644 --- a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/predict_infection_classifier.py @@ -1,4 +1,4 @@ - +# %% from viscy.light.predict_writer import HCSPredictionWriter from viscy.data.hcs import HCSDataModule @@ -17,24 +17,24 @@ split_ratio=0.8, z_window_size=1, architecture="2D", - num_workers=0, + num_workers=1, batch_size=1, normalizations=[ NormalizeSampled( - keys=["Sensor", "Phase"], + keys=["Phase", "Sensor"], level="fov_statistics", subtrahend="median", divisor="iqr", ) ], ) - +data_module.prepare_data() data_module.setup(stage="predict") model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) # %% perform prediction @@ -52,3 +52,5 @@ datamodule=data_module, return_predictions=True, ) + +# %% diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/viscy/scripts/infection_phenotyping/test_infection_classifier.py index fadd3e75..5ed14094 100644 --- a/viscy/scripts/infection_phenotyping/test_infection_classifier.py +++ b/viscy/scripts/infection_phenotyping/test_infection_classifier.py @@ -1,7 +1,6 @@ # %% from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.light.predict_writer import HCSPredictionWriter from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D from pytorch_lightning.loggers import TensorBoardLogger from viscy.transforms import NormalizeSampled @@ -47,23 +46,92 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_74/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", ) trainer.test(model=model, datamodule=data_module) -# %% predict the test set -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" -trainer = pl.Trainer( - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", - callbacks=[HCSPredictionWriter(output_path, write_input=False)], - devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -) -trainer.predict( - model=model, - datamodule=data_module, - return_predictions=True, -) +# # %% script to develop confusion matrix for infected cell classifier + +# from iohub.ngff import open_ome_zarr +# import numpy as np +# from skimage.measure import regionprops, label +# import cv2 +# import seaborn as sns +# import matplotlib.pyplot as plt + +# # %% load the predicted zarr and the human-in-loop annotations zarr + +# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" +# test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" + +# pred_dataset = open_ome_zarr( +# pred_datapath, +# layout="hcs", +# mode="r+", +# ) +# chan_pred = pred_dataset.channel_names + +# test_dataset = open_ome_zarr( +# test_datapath, +# layout="hcs", +# mode="r+", +# ) +# chan_test = test_dataset.channel_names + +# # %% compute confusion matrix for one image +# for well_id, well_data in pred_dataset.wells(): +# well_name, well_no = well_id.split("/") + +# for pos_name, pos_data in well_data.positions(): + +# pred_data = pos_data.data +# pred_pos_data = pred_data.numpy() +# T,C,Z,X,Y = pred_pos_data.shape + +# test_data = test_dataset[well_id + "/" + pos_name + "/0"] +# test_pos_data = test_data.numpy() + +# # compute confusion matrix for each time point and add to total +# conf_mat = np.zeros((2, 2)) +# for time in range(T): +# pred_img = pred_pos_data[time, chan_pred.index("Inf_mask_prediction"), 0, : , :] +# test_img = test_pos_data[time, chan_test.index("Inf_mask"), 0, : , :] + +# test_img_rz = cv2.resize(test_img, dsize=(2016,2048), interpolation=cv2.INTER_NEAREST)# pred_img = +# pred_img = np.where(test_img_rz > 0, pred_img, 0) + +# # find objects in every image +# label_img = label(test_img_rz) +# regions_label = regionprops(label_img) + +# # store pixel id for every label in pred and test imgs +# for region in regions_label: +# if region.area > 0: +# row, col = region.centroid +# pred_id = pred_img[int(row), int(col)] +# test_id = test_img_rz[int(row), int(col)] +# if pred_id == 1 and test_id == 1: +# conf_mat[1,1] += 1 +# if pred_id == 1 and test_id == 2: +# conf_mat[1,0] += 1 +# if pred_id == 2 and test_id == 1: +# conf_mat[0,1] += 1 +# if pred_id == 2 and test_id == 2: +# conf_mat[0,0] += 1 + +# # display the confusion matrix +# ax= plt.subplot() +# sns.heatmap(conf_mat, annot=True, fmt='g', ax=ax); #annot=True to annotate cells, ftm='g' to disable scientific notation + +# # labels, title and ticks +# ax.set_xlabel('annotated labels');ax.set_ylabel('predicted labels'); +# ax.set_title('Confusion Matrix'); +# ax.xaxis.set_ticklabels(['infected', 'uninfected']); ax.yaxis.set_ticklabels(['infected', 'uninfected']); + + +# # %% +# %% From 9c9ce41b27f0ada2e3cc50f255e702f60e955654 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Fri, 12 Apr 2024 16:12:12 -0700 Subject: [PATCH 70/92] fix merge error --- viscy/data/hcs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 27992bfc..f33b6121 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -427,7 +427,7 @@ def _setup_test(self, dataset_settings: dict): [p for _, p in plate.positions()], transform=test_transform, ground_truth_masks=self.ground_truth_masks, - norm_meta=plate.zattrs["normalization"] ** dataset_settings, + **dataset_settings, ) else: self.test_dataset = SlidingWindowDataset( From 6b0a42d525c18f740054f7e82897f32a2cf70210 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Thu, 23 May 2024 14:01:11 -0700 Subject: [PATCH 71/92] split 2D and 2.5D model scripts --- .../Infection_classification_25DModel.py | 154 ++++++++ ...py => Infection_classification_2Dmodel.py} | 2 +- .../classify_infection_25D.py | 335 ++++++++++++++++++ ..._infection.py => classify_infection_2D.py} | 5 +- 4 files changed, 494 insertions(+), 2 deletions(-) create mode 100644 viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py rename viscy/scripts/infection_phenotyping/{Infection_classification_model.py => Infection_classification_2Dmodel.py} (97%) create mode 100644 viscy/scripts/infection_phenotyping/classify_infection_25D.py rename viscy/scripts/infection_phenotyping/{classify_infection.py => classify_infection_2D.py} (98%) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py new file mode 100644 index 00000000..91702497 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py @@ -0,0 +1,154 @@ +# %% +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_25D import SemanticSegUNet25D + +from iohub.ngff import open_ome_zarr + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z*X*Y + +pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_sum = sum(pixel_ratio_1) +pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] + +# %% craete data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90"], + target_channel=["Inf_mask"], + yx_patch_size=[512, 512], + split_ratio=0.8, + z_window_size=5, + architecture="2.5D", + num_workers=3, + batch_size=32, + normalizations=[ + NormalizeSampled( + keys=["Phase","HSP90"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 512, 512], + keys=["Phase","HSP90"], + w_key="Inf_mask", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# Visualize the dataset and the batch using napari +# Set the display +# os.environ['DISPLAY'] = ':1' + +# # Create a napari viewer +# viewer = napari.Viewer() + +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# # Start the napari event loop +# napari.run() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_model.py b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py similarity index 97% rename from viscy/scripts/infection_phenotyping/Infection_classification_model.py rename to viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py index 37f606cb..52af4673 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_model.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py @@ -9,7 +9,7 @@ from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled from viscy.data.hcs import HCSDataModule -from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D # %% Create a dataloader and visualize the batches. diff --git a/viscy/scripts/infection_phenotyping/classify_infection_25D.py b/viscy/scripts/infection_phenotyping/classify_infection_25D.py new file mode 100644 index 00000000..c78a7e8f --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection_25D.py @@ -0,0 +1,335 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import cv2 + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops, label +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet25D import Unet25d +from viscy.data.hcs import Sample + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = Unet25d(in_channels=in_channels, out_channels=out_channels, num_blocks=4, num_block_layers=4) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection.py b/viscy/scripts/infection_phenotyping/classify_infection_2D.py similarity index 98% rename from viscy/scripts/infection_phenotyping/classify_infection.py rename to viscy/scripts/infection_phenotyping/classify_infection_2D.py index ee904434..b4269c74 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_2D.py @@ -16,6 +16,7 @@ from monai.transforms import DivisiblePad from viscy.unet.networks.Unet2D import Unet2d +# from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample # @@ -244,8 +245,10 @@ def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image def on_test_start(self): self.pred_cm = torch.zeros((2,2)) From 2ea889227c4370e9a9419e9375986728d2903db6 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Mon, 27 May 2024 09:15:09 -0700 Subject: [PATCH 72/92] added covnext script --- .../Infection_classification_covnextModel.py | 154 ++++++++ .../classify_infection_covnext.py | 347 ++++++++++++++++++ 2 files changed, 501 insertions(+) create mode 100644 viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py create mode 100644 viscy/scripts/infection_phenotyping/classify_infection_covnext.py diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py new file mode 100644 index 00000000..2b8a1634 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -0,0 +1,154 @@ +# %% +import torch +import lightning.pytorch as pl +import torch.nn as nn + +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint + +from viscy.transforms import RandWeightedCropd +from viscy.transforms import NormalizeSampled +from viscy.data.hcs import HCSDataModule +from viscy.scripts.infection_phenotyping.classify_infection_covnext import SemanticSegUNet25D + +from iohub.ngff import open_ome_zarr + +# %% Create a dataloader and visualize the batches. + +# Set the path to the dataset +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z*X*Y + +pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_sum = sum(pixel_ratio_1) +pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] + +# %% craete data module + +# Create an instance of HCSDataModule +data_module = HCSDataModule( + dataset_path, + source_channel=["Phase", "HSP90"], + target_channel=["Inf_mask"], + yx_patch_size=[256, 256], + split_ratio=0.8, + z_window_size=5, + architecture="2.2D", + num_workers=3, + batch_size=16, + normalizations=[ + NormalizeSampled( + keys=["Phase","HSP90"], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], + augmentations=[ + RandWeightedCropd( + num_samples=4, + spatial_size=[-1, 256, 256], + keys=["Phase","HSP90"], + w_key="Inf_mask", + ) + ], +) + +# Prepare the data +data_module.prepare_data() + +# Setup the data +data_module.setup(stage="fit") + +# Create a dataloader +train_dm = data_module.train_dataloader() + +val_dm = data_module.val_dataloader() + +# Visualize the dataset and the batch using napari +# Set the display +# os.environ['DISPLAY'] = ':1' + +# # Create a napari viewer +# viewer = napari.Viewer() + +# # Add the dataset to the viewer +# for batch in dataloader: +# if isinstance(batch, dict): +# for k, v in batch.items(): +# if isinstance(v, torch.Tensor): +# viewer.add_image(v.cpu().numpy().astype(np.float32)) + +# # Start the napari event loop +# napari.run() + + +# %% Define the logger +logger = TensorBoardLogger( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/", + name="logs", +) + +# Pass the logger to the Trainer +trainer = pl.Trainer( + logger=logger, + max_epochs=200, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + log_every_n_steps=1, + devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +) + +# Define the checkpoint callback +checkpoint_callback = ModelCheckpoint( + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/mantis_phase_hsp90/logs/", + filename="checkpoint_{epoch:02d}", + save_top_k=-1, + verbose=True, + monitor="loss/validate", + mode="min", +) + +# Add the checkpoint callback to the trainer +trainer.callbacks.append(checkpoint_callback) + +# Fit the model +model = SemanticSegUNet25D( + in_channels=2, + out_channels=3, + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), +) + +print(model) + +# %% +# Run training. + +trainer.fit(model, data_module) + +# %% diff --git a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py new file mode 100644 index 00000000..edf2feb4 --- /dev/null +++ b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py @@ -0,0 +1,347 @@ + +import torch +import torch.nn as nn +import lightning.pytorch as pl +import torch.nn.functional as F +from torch import Tensor +import cv2 + +# import torchview +from typing import Literal, Sequence +from skimage.exposure import rescale_intensity +from matplotlib.cm import get_cmap +from skimage.measure import regionprops, label +import numpy as np +import matplotlib.pyplot as plt + +from monai.transforms import DivisiblePad +from viscy.unet.networks.Unet25D import Unet25d +from viscy.data.hcs import Sample +from viscy.light.engine import VSUNet + +# +# %% Methods to compute confusion matrix per cell using torchmetrics + +# The confusion matrix at the single-cell resolution. +def confusion_matrix_per_cell( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Compute confusion matrix per cell. + + Args: + y_true (torch.Tensor): Ground truth label image (BXHXW). + y_pred (torch.Tensor): Predicted label image (BXHXW). + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Confusion matrix per cell (BXCXC). + """ + # Convert the image class to the nuclei class + confusion_matrix_per_cell = compute_confusion_matrix(y_true, y_pred, num_classes) + confusion_matrix_per_cell = torch.tensor(confusion_matrix_per_cell) + return confusion_matrix_per_cell + + +# These images can be logged with prediction. +def compute_confusion_matrix( + y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int +): + """Convert the class of the image to the class of the nuclei. + + Args: + label_image (torch.Tensor): Label image (BXHXW). Values of tensor are integers that represent semantic segmentation. + num_classes (int): Number of classes. + + Returns: + torch.Tensor: Label images with a consensus class at the centroid of nuclei. + """ + + batch_size = y_true.size(0) + # find centroids of nuclei from y_true + conf_mat = np.zeros((num_classes, num_classes)) + for i in range(batch_size): + y_true_cpu = y_true[i].cpu().numpy() + y_pred_cpu = y_pred[i].cpu().numpy() + y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) + y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) + + # find objects in every image + label_img = label(y_true_reshaped) + regions = regionprops(label_img) + + # Find centroids, pixel coordinates from the ground truth. + for region in regions: + if region.area > 0: + row, col = region.centroid + pred_id = y_pred_resized[int(row), int(col)] + test_id = y_true_reshaped[int(row), int(col)] + + if pred_id == 1 and test_id == 1: + conf_mat[1,1] += 1 + if pred_id == 1 and test_id == 2: + conf_mat[0,1] += 1 + if pred_id == 2 and test_id == 1: + conf_mat[1,0] += 1 + if pred_id == 2 and test_id == 2: + conf_mat[0,0] += 1 + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. + return conf_mat + +def plot_confusion_matrix(confusion_matrix, index_to_label_dict): + # Create a figure and axis to plot the confusion matrix + fig, ax = plt.subplots() + + # Create a color heatmap for the confusion matrix + cax = ax.matshow(confusion_matrix, cmap="viridis") + + # Create a colorbar and set the label + index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + fig.colorbar(cax, label="Frequency") + + # Set labels for the classes + ax.set_xticks(np.arange(len(index_to_label_dict))) + ax.set_yticks(np.arange(len(index_to_label_dict))) + ax.set_xticklabels(index_to_label_dict.values(), rotation=45) + ax.set_yticklabels(index_to_label_dict.values()) + + # Set labels for the axes + ax.set_xlabel("Predicted") + ax.set_ylabel("True") + + # Add text annotations to the confusion matrix + for i in range(len(index_to_label_dict)): + for j in range(len(index_to_label_dict)): + ax.text( + j, + i, + str(int(confusion_matrix[i, j])), + ha="center", + va="center", + color="white", + ) + + # plt.show(fig) # Show the figure + return fig +# Define a 25d unet model for infection classification as a lightning module. + +class SemanticSegUNet25D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( + self, + in_channels: int, # Number of input channels + out_channels: int, # Number of output channels + lr: float = 1e-3, # Learning rate + loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function + schedule: Literal[ + "WarmupCosine", "Constant" + ] = "Constant", # Learning rate schedule + log_batches_per_epoch: int = 2, # Number of batches to log per epoch + log_samples_per_batch: int = 2, # Number of samples to log per batch + ckpt_path: str = None, # Path to the checkpoint + ): + super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + # Initialize the UNet model + self.unet_model = VSUNet( + architecture="2.2D", + model_config={ + "in_channels": 1, + "out_channels": 3, + "in_stack_depth": 5, + "backbone": "convnextv2_tiny", + "stem_kernel_size": (5, 4, 4), + "decoder_mode": "pixelshuffle", + "head_expansion_ratio": 4, + }, + ) + if ckpt_path is not None: + state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ + "state_dict" + ] + state_dict.pop("loss_function.weight", None) # Remove the unexpected key + self.load_state_dict(state_dict) # loading only weights + self.lr = lr # Set the learning rate + # Set the loss function to CrossEntropyLoss if none is provided + self.loss_function = loss_function if loss_function else nn.CrossEntropyLoss() + self.schedule = schedule # Set the learning rate schedule + self.log_batches_per_epoch = ( + log_batches_per_epoch # Set the number of batches to log per epoch + ) + self.log_samples_per_batch = ( + log_samples_per_batch # Set the number of samples to log per batch + ) + self.training_step_outputs = [] # Initialize the list of training step outputs + self.validation_step_outputs = ( + [] + ) # Initialize the list of validation step outputs + + self.pred_cm = None # Initialize the confusion matrix + self.index_to_label_dict = ["Infected", "Uninfected"] + + + # Define the forward pass + def forward(self, x): + return self.unet_model(x) # Pass the input through the UNet model + + # Define the optimizer + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.parameters(), lr=self.lr + ) # Use the Adam optimizer + return optimizer + + # Define the training step + def training_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + train_loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the training step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.training_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the training loss + self.log( + "loss/train", + train_loss, + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=True, + ) + return train_loss # Return the training loss + + def validation_step(self, batch: Sample, batch_idx: int): + source = batch["source"] # Extract the source from the batch + target = batch["target"] # Extract the target from the batch + pred = self.forward(source) # Make a prediction using the source + # Convert the target to one-hot encoding + target_one_hot = F.one_hot(target.squeeze(1).long(), num_classes=3).permute( + 0, 4, 1, 2, 3 + ) + target_one_hot = target_one_hot.float() # Convert the target to float type + loss = self.loss_function(pred, target_one_hot) # Calculate the loss + # Log the validation step outputs if the batch index is less than the number of batches to log per epoch + if batch_idx < self.log_batches_per_epoch: + self.validation_step_outputs.extend( + self._detach_sample((source, target_one_hot, pred)) + ) + # Log the validation loss + self.log( + "loss/validate", loss, sync_dist=True, add_dataloader_idx=False, logger=True + ) + return loss # Return the validation loss + + def on_predict_start(self): + """Pad the input shape to be divisible by the downsampling factor. + The inverse of this transform crops the prediction to original shape. + """ + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + # Define the prediction step + def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + # Go from probabilities/one-hot encoded data to class labels. + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] + # prob_chan = prob_chan.unsqueeze(1) + return labels_pred # log the class predicted image + # return prob_chan # log the probability predicted image + + def on_test_start(self): + self.pred_cm = torch.zeros((2,2)) + down_factor = 2**self.unet_model.num_blocks + self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + + def test_step(self, batch: Sample): + source = self._predict_pad(batch["source"]) # Pad the source + logits = self._predict_pad.inverse(self.forward(source)) + prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities + labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + + target = self._predict_pad(batch["target"]) # Extract the target from the batch + pred_cm = confusion_matrix_per_cell( + target, labels_pred, num_classes=2 + ) # Calculate the confusion matrix per cell + self.pred_cm += pred_cm # Append the confusion matrix to pred_cm + + self.logger.experiment.add_figure( + "Confusion Matrix per Cell", + plot_confusion_matrix(pred_cm, self.index_to_label_dict), + self.current_epoch, + ) + + # Accumulate the confusion matrix at the end of test epoch and log. + def on_test_end(self): + confusion_matrix_sum = self.pred_cm + self.logger.experiment.add_figure( + "Confusion Matrix", + plot_confusion_matrix(confusion_matrix_sum, self.index_to_label_dict), + self.current_epoch, + ) + + # Define what happens at the end of a training epoch + def on_train_epoch_end(self): + self._log_samples( + "train_samples", self.training_step_outputs + ) # Log the training samples + self.training_step_outputs = [] # Reset the list of training step outputs + + # Define what happens at the end of a validation epoch + def on_validation_epoch_end(self): + self._log_samples( + "val_samples", self.validation_step_outputs + ) # Log the validation samples + self.validation_step_outputs = [] # Reset the list of validation step outputs + + # Define a method to detach a sample + def _detach_sample(self, imgs: Sequence[Tensor]): + # Detach the images and convert them to numpy arrays + num_samples = 3 + return [ + [np.squeeze(img[i].detach().cpu().numpy().max(axis=1)) for img in imgs] + for i in range(num_samples) + ] + + # Define a method to log samples + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + images_grid = [] # Initialize the list of image grids + for sample_images in imgs: # For each sample image + images_row = [] # Initialize the list of image rows + for i, image in enumerate( + sample_images + ): # For each image in the sample images + cm_name = "gray" if i == 0 else "inferno" # Set the colormap name + if image.ndim == 2: # If the image is 2D + image = image[np.newaxis] # Add a new axis + for channel in image: # For each channel in the image + channel = rescale_intensity( + channel, out_range=(0, 1) + ) # Rescale the intensity of the channel + render = get_cmap(cm_name)(channel, bytes=True)[ + ..., :3 + ] # Render the channel + images_row.append( + render + ) # Append the render to the list of image rows + images_grid.append( + np.concatenate(images_row, axis=1) + ) # Append the concatenated image rows to the list of image grids + grid = np.concatenate(images_grid, axis=0) # Concatenate the image grids + # Log the image grid + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) +# %% From 220eba131dd8dd20bd536823ea6f238f93a4bc37 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 28 May 2024 13:15:33 -0700 Subject: [PATCH 73/92] fix model input parameter --- .../Infection_classification_covnextModel.py | 4 +++- .../infection_phenotyping/classify_infection_covnext.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py index 2b8a1634..60ce0b5a 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -1,4 +1,6 @@ # %% +# import sys +# sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") import torch import lightning.pytorch as pl import torch.nn as nn @@ -134,7 +136,7 @@ mode="min", ) -# Add the checkpoint callback to the trainer +# Add the checkpoint callback to the trainer`` trainer.callbacks.append(checkpoint_callback) # Fit the model diff --git a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py index edf2feb4..2ba698ee 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_covnext.py @@ -147,8 +147,8 @@ def __init__( self.unet_model = VSUNet( architecture="2.2D", model_config={ - "in_channels": 1, - "out_channels": 3, + "in_channels": in_channels, + "out_channels": out_channels, "in_stack_depth": 5, "backbone": "convnextv2_tiny", "stem_kernel_size": (5, 4, 4), From c4839da3248bee2bde61e58c733f86ecbe486b1c Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 29 May 2024 09:33:10 -0700 Subject: [PATCH 74/92] update input file --- .../Infection_classification_covnextModel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py index 60ce0b5a..0ecd6bdd 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py @@ -18,7 +18,7 @@ # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" # find ratio of background, uninfected and infected pixels zarr_input = open_ome_zarr( @@ -56,7 +56,7 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Phase", "HSP90"], + source_channel=["Phase", "HSP90", "phase_nucl_iqr","hsp90_skew"], target_channel=["Inf_mask"], yx_patch_size=[256, 256], split_ratio=0.8, @@ -66,7 +66,7 @@ batch_size=16, normalizations=[ NormalizeSampled( - keys=["Phase","HSP90"], + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -76,7 +76,7 @@ RandWeightedCropd( num_samples=4, spatial_size=[-1, 256, 256], - keys=["Phase","HSP90"], + keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], w_key="Inf_mask", ) ], @@ -141,7 +141,7 @@ # Fit the model model = SemanticSegUNet25D( - in_channels=2, + in_channels=4, out_channels=3, loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), ) From c04b4ace519e725ffab3a02b4ba513080fcde28f Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 2 Jul 2024 07:02:08 -0700 Subject: [PATCH 75/92] add augmentations --- .../Infection_classification_2Dmodel.py | 73 +++++++++++++++---- .../classify_infection_2D.py | 2 +- 2 files changed, 60 insertions(+), 15 deletions(-) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py index 52af4673..aab7c86b 100644 --- a/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py +++ b/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py @@ -6,30 +6,63 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint -from viscy.transforms import RandWeightedCropd -from viscy.transforms import NormalizeSampled +from viscy.transforms import RandWeightedCropd, NormalizeSampled, RandScaleIntensityd, RandGaussianSmoothd from viscy.data.hcs import HCSDataModule from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from iohub.ngff import open_ome_zarr + # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_trainVal.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/0_training_test_data/2024_05_03_DENV_eFT226_Timecourse_trainVal_2D.zarr" + +# find ratio of background, uninfected and infected pixels +zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", +) +in_chan_names = zarr_input.channel_names + +num_pixels_bkg = 0 +num_pixels_uninf = 0 +num_pixels_inf = 0 +num_pixels = 0 +for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T,C,Z,Y,X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z*X*Y + +pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_sum = sum(pixel_ratio_1) +pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] +# %% # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Sensor", "Phase"], + source_channel=["mCherry", "Phase3D"], target_channel=["Inf_mask"], yx_patch_size=[128, 128], - split_ratio=0.8, + split_ratio=0.5, z_window_size=1, architecture="2D", num_workers=1, batch_size=128, normalizations=[ NormalizeSampled( - keys=["Sensor", "Phase"], + keys=["Phase3D", "mCherry"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -39,9 +72,21 @@ RandWeightedCropd( num_samples=8, spatial_size=[-1, 128, 128], - keys=["Sensor", "Phase", "Inf_mask"], + keys=["mCherry", "Phase3D", "Inf_mask"], w_key="Inf_mask", - ) + ), + RandScaleIntensityd( + keys=["mCherry", "Phase3D"], + factors=[0.1, 0.5], + prob=0.5, + ), + RandGaussianSmoothd( + keys=["mCherry", "Phase3D"], + prob=0.5, + sigma_x=[0.5, 1.0], + sigma_y=[0.5, 1.0], + sigma_z=[0.5, 1.0], + ), ], ) @@ -76,22 +121,22 @@ # %% Define the logger logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", - name="logs_wPhase", + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/", + name="logs", ) # Pass the logger to the Trainer trainer = pl.Trainer( logger=logger, - max_epochs=100, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + max_epochs=500, + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/logs/", log_every_n_steps=1, devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/", + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/logs/", filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, @@ -106,7 +151,7 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - loss_function=nn.CrossEntropyLoss(weight=torch.tensor([0.05, 0.25, 0.7])), + loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), ) print(model) diff --git a/viscy/scripts/infection_phenotyping/classify_infection_2D.py b/viscy/scripts/infection_phenotyping/classify_infection_2D.py index b4269c74..261252cd 100644 --- a/viscy/scripts/infection_phenotyping/classify_infection_2D.py +++ b/viscy/scripts/infection_phenotyping/classify_infection_2D.py @@ -133,7 +133,7 @@ def __init__( self, in_channels: int, # Number of input channels out_channels: int, # Number of output channels - lr: float = 1e-3, # Learning rate + lr: float = 1e-5, # Learning rate loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function schedule: Literal[ "WarmupCosine", "Constant" From 418d6d9bd8d0a901068341f57126e9486c543c84 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Wed, 3 Jul 2024 09:00:08 -0700 Subject: [PATCH 76/92] refactor infection_classification code to viscy/applications --- .../Infection_classification_25DModel.py | 0 .../infection_classification}/Infection_classification_2Dmodel.py | 0 .../Infection_classification_covnextModel.py | 0 .../infection_classification}/classify_infection_25D.py | 0 .../infection_classification}/classify_infection_2D.py | 0 .../infection_classification}/classify_infection_covnext.py | 0 .../infection_classification}/predict_infection_classifier.py | 0 .../infection_classification}/readme.md | 0 .../infection_classification}/test_infection_classifier.py | 0 9 files changed, 0 insertions(+), 0 deletions(-) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/Infection_classification_25DModel.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/Infection_classification_2Dmodel.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/Infection_classification_covnextModel.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/classify_infection_25D.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/classify_infection_2D.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/classify_infection_covnext.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/predict_infection_classifier.py (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/readme.md (100%) rename {viscy/scripts/infection_phenotyping => applications/infection_classification}/test_infection_classifier.py (100%) diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py similarity index 100% rename from viscy/scripts/infection_phenotyping/Infection_classification_25DModel.py rename to applications/infection_classification/Infection_classification_25DModel.py diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py similarity index 100% rename from viscy/scripts/infection_phenotyping/Infection_classification_2Dmodel.py rename to applications/infection_classification/Infection_classification_2Dmodel.py diff --git a/viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py similarity index 100% rename from viscy/scripts/infection_phenotyping/Infection_classification_covnextModel.py rename to applications/infection_classification/Infection_classification_covnextModel.py diff --git a/viscy/scripts/infection_phenotyping/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py similarity index 100% rename from viscy/scripts/infection_phenotyping/classify_infection_25D.py rename to applications/infection_classification/classify_infection_25D.py diff --git a/viscy/scripts/infection_phenotyping/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py similarity index 100% rename from viscy/scripts/infection_phenotyping/classify_infection_2D.py rename to applications/infection_classification/classify_infection_2D.py diff --git a/viscy/scripts/infection_phenotyping/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py similarity index 100% rename from viscy/scripts/infection_phenotyping/classify_infection_covnext.py rename to applications/infection_classification/classify_infection_covnext.py diff --git a/viscy/scripts/infection_phenotyping/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py similarity index 100% rename from viscy/scripts/infection_phenotyping/predict_infection_classifier.py rename to applications/infection_classification/predict_infection_classifier.py diff --git a/viscy/scripts/infection_phenotyping/readme.md b/applications/infection_classification/readme.md similarity index 100% rename from viscy/scripts/infection_phenotyping/readme.md rename to applications/infection_classification/readme.md diff --git a/viscy/scripts/infection_phenotyping/test_infection_classifier.py b/applications/infection_classification/test_infection_classifier.py similarity index 100% rename from viscy/scripts/infection_phenotyping/test_infection_classifier.py rename to applications/infection_classification/test_infection_classifier.py From 67b330c556aa73a34efe1707c19e78a3cc88b243 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 9 Jul 2024 10:14:30 -0700 Subject: [PATCH 77/92] changes made for BJ5 classification --- .../Infection_classification_2Dmodel.py | 26 +++++++-------- .../classify_infection_2D.py | 23 ++++++++++++- .../predict_infection_classifier.py | 33 +++++++++++++------ .../test_infection_classifier.py | 20 +++++------ 4 files changed, 68 insertions(+), 34 deletions(-) diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py index aab7c86b..1983dbf8 100644 --- a/applications/infection_classification/Infection_classification_2Dmodel.py +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -15,7 +15,7 @@ # %% Create a dataloader and visualize the batches. # Set the path to the dataset -dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/0_training_test_data/2024_05_03_DENV_eFT226_Timecourse_trainVal_2D.zarr" +dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr" # find ratio of background, uninfected and infected pixels zarr_input = open_ome_zarr( @@ -52,17 +52,17 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["mCherry", "Phase3D"], + source_channel=["TXR_Density3D", "Phase3D"], target_channel=["Inf_mask"], yx_patch_size=[128, 128], - split_ratio=0.5, + split_ratio=0.7, z_window_size=1, architecture="2D", num_workers=1, - batch_size=128, + batch_size=256, normalizations=[ NormalizeSampled( - keys=["Phase3D", "mCherry"], + keys=["Phase3D", "TXR_Density3D"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -70,18 +70,18 @@ ], augmentations=[ RandWeightedCropd( - num_samples=8, + num_samples=16, spatial_size=[-1, 128, 128], - keys=["mCherry", "Phase3D", "Inf_mask"], + keys=["TXR_Density3D", "Phase3D", "Inf_mask"], w_key="Inf_mask", ), RandScaleIntensityd( - keys=["mCherry", "Phase3D"], - factors=[0.1, 0.5], + keys=["TXR_Density3D", "Phase3D"], + factors=[0.5, 0.5], prob=0.5, ), RandGaussianSmoothd( - keys=["mCherry", "Phase3D"], + keys=["TXR_Density3D", "Phase3D"], prob=0.5, sigma_x=[0.5, 1.0], sigma_y=[0.5, 1.0], @@ -121,7 +121,7 @@ # %% Define the logger logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/", + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/", name="logs", ) @@ -129,14 +129,14 @@ trainer = pl.Trainer( logger=logger, max_epochs=500, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/logs/", + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/", log_every_n_steps=1, devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) # Define the checkpoint callback checkpoint_callback = ModelCheckpoint( - dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_05_03_DENV_eFT226_Timecourse/4-infection-classification/1_model_training/logs/", + dirpath="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/", filename="checkpoint_{epoch:02d}", save_top_k=-1, verbose=True, diff --git a/applications/infection_classification/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py index 261252cd..8d5d83af 100644 --- a/applications/infection_classification/classify_infection_2D.py +++ b/applications/infection_classification/classify_infection_2D.py @@ -18,6 +18,7 @@ from viscy.unet.networks.Unet2D import Unet2d # from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample +# from skimage.io import imsave # # %% Methods to compute confusion matrix per cell using torchmetrics @@ -64,6 +65,10 @@ def compute_confusion_matrix( y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) + + # imsave(f"y_pred_{i}.png", y_pred_reshaped.astype(np.uint8)) + # imsave(f"y_true_{i}.png", y_true_reshaped.astype(np.uint8)) + y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) @@ -127,13 +132,21 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): return fig # Define a 2d unet model for infection classification as a lightning module. +# write a prediction writre to save the predictions as png files +# def predict_writer(label_pred, file_name): +# output_path = f"/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/pred_debug_2024_07_08/{file_name}" +# label_pred_cpu = label_pred.cpu().numpy() +# write_npy = label_pred_cpu[0,0,0,:,:] +# # label_pred_reshaped = label_pred_cpu.reshape(label_pred_cpu.shape[-2:]) +# imsave(output_path, write_npy) + class SemanticSegUNet2D(pl.LightningModule): # Model for semantic segmentation. def __init__( self, in_channels: int, # Number of input channels out_channels: int, # Number of output channels - lr: float = 1e-5, # Learning rate + lr: float = 1e-4, # Learning rate loss_function: nn.Module = nn.CrossEntropyLoss(), # Loss function schedule: Literal[ "WarmupCosine", "Constant" @@ -241,10 +254,12 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source + # predict_writer(batch["source"], f"pred_source.npy") logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # prob_chan = prob_pred[:, 2, :, :] # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image @@ -254,13 +269,19 @@ def on_test_start(self): self.pred_cm = torch.zeros((2,2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) + # self.i_num = 0 def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source + # predict_writer(batch["source"], f"test_source_{self.i_num}.npy") logits = self._predict_pad.inverse(self.forward(source)) prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + # self.i_num += 1 + # Save the prediction as a png file + # predict_writer(labels_pred, f"predict_{self.i_num}.png") + target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=2 diff --git a/applications/infection_classification/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py index 783c1334..e809f4ff 100644 --- a/applications/infection_classification/predict_infection_classifier.py +++ b/applications/infection_classification/predict_infection_classifier.py @@ -3,46 +3,59 @@ from viscy.light.predict_writer import HCSPredictionWriter from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D from viscy.transforms import NormalizeSampled # %% # %% write the predictions to a zarr file -pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" +# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/2024_04_25_BJ5a_DENV_TimeCourse_2D.zarr" +pred_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/0-train_test_data/2024_02_04_A549_DENV_ZIKV_timelapse_test_2D.zarr' data_module = HCSDataModule( data_path=pred_datapath, - source_channel=['Sensor','Phase'], + source_channel=['RFP', 'Phase3D'], target_channel=['Inf_mask'], - split_ratio=0.8, + split_ratio=0.7, z_window_size=1, architecture="2D", num_workers=1, batch_size=1, normalizations=[ NormalizeSampled( - keys=["Phase", "Sensor"], + keys=["RFP", "Phase3D"], level="fov_statistics", subtrahend="median", divisor="iqr", ) ], ) -data_module.prepare_data() + data_module.setup(stage="predict") +# model = SemanticSegUNet2D( +# in_channels=2, +# out_channels=3, +# ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/checkpoint_epoch=206.ckpt", +# ) model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/1-model_train/logs/version_0/checkpoints/epoch=199-step=800.ckpt", ) # %% perform prediction -output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred_SP.zarr" +# output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/2024_04_25_BJ5a_DENV_TimeCourse_2D_pred.zarr" +output_path = '/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/2-predict_infection/2024_02_04_A549_DENV_ZIKV_timelapse_pred_2D_new.zarr' + +# trainer = pl.Trainer( +# default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs", +# callbacks=[HCSPredictionWriter(output_path, write_input=False)], +# devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs +# ) trainer = pl.Trainer( - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/1-model_train/logs", callbacks=[HCSPredictionWriter(output_path, write_input=False)], devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) @@ -53,4 +66,4 @@ return_predictions=True, ) -# %% +# %% \ No newline at end of file diff --git a/applications/infection_classification/test_infection_classifier.py b/applications/infection_classification/test_infection_classifier.py index 5ed14094..1d15a18c 100644 --- a/applications/infection_classification/test_infection_classifier.py +++ b/applications/infection_classification/test_infection_classifier.py @@ -1,25 +1,25 @@ # %% from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.scripts.infection_phenotyping.classify_infection import SemanticSegUNet2D +from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D from pytorch_lightning.loggers import TensorBoardLogger from viscy.transforms import NormalizeSampled # %% test the model on the test set -test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" +test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/test_data.zarr" data_module = HCSDataModule( data_path=test_datapath, - source_channel=['Sensor','Phase'], + source_channel=['TXR_Density3D','Phase3D'], target_channel=['Inf_mask'], - split_ratio=0.8, + split_ratio=0.7, z_window_size=1, architecture="2D", - num_workers=0, + num_workers=1, batch_size=1, normalizations=[ NormalizeSampled( - keys=["Sensor", "Phase"], + keys=["TXR_Density3D", "Phase3D"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -32,13 +32,13 @@ # %% create trainer and input logger = TensorBoardLogger( - "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/", - name="logs_wPhase", + "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/", + name="logs", ) trainer = pl.Trainer( logger=logger, - default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase", + default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs", log_every_n_steps=1, devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs ) @@ -46,7 +46,7 @@ model = SemanticSegUNet2D( in_channels=2, out_channels=3, - ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/sensorInf_phenotyping/logs_wPhase/version_34/checkpoints/epoch=99-step=300.ckpt", + ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/checkpoint_epoch=206.ckpt", ) trainer.test(model=model, datamodule=data_module) From d420e80e769fab4b1d86bda34e3c79e0bd7137c0 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 9 Jul 2024 10:30:57 -0700 Subject: [PATCH 78/92] format code --- .../Infection_classification_25DModel.py | 40 +++----- .../Infection_classification_2Dmodel.py | 55 +++++------ .../Infection_classification_covnextModel.py | 44 ++++----- .../classify_infection_25D.py | 60 ++++++++---- .../classify_infection_2D.py | 89 ++++++++++-------- .../classify_infection_covnext.py | 57 ++++++++---- .../predict_infection_classifier.py | 27 ++---- .../test_infection_classifier.py | 93 ++----------------- 8 files changed, 196 insertions(+), 269 deletions(-) diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py index 91702497..55d208f0 100644 --- a/applications/infection_classification/Infection_classification_25DModel.py +++ b/applications/infection_classification/Infection_classification_25DModel.py @@ -9,7 +9,9 @@ from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled from viscy.data.hcs import HCSDataModule -from viscy.scripts.infection_phenotyping.classify_infection_25D import SemanticSegUNet25D +from applications.infection_classification.classify_infection_25D import ( + SemanticSegUNet25D, +) from iohub.ngff import open_ome_zarr @@ -35,17 +37,21 @@ for pos_name, pos_data in well_data.positions(): data = pos_data.data - T,C,Z,Y,X = data.shape + T, C, Z, Y, X = data.shape out_data = data.numpy() for time in range(T): - Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z*X*Y + num_pixels = num_pixels + Z * X * Y -pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_1 = [ + num_pixels / num_pixels_bkg, + num_pixels / num_pixels_uninf, + num_pixels / num_pixels_inf, +] pixel_ratio_sum = sum(pixel_ratio_1) pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] @@ -64,7 +70,7 @@ batch_size=32, normalizations=[ NormalizeSampled( - keys=["Phase","HSP90"], + keys=["Phase", "HSP90"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -74,7 +80,7 @@ RandWeightedCropd( num_samples=4, spatial_size=[-1, 512, 512], - keys=["Phase","HSP90"], + keys=["Phase", "HSP90"], w_key="Inf_mask", ) ], @@ -91,23 +97,6 @@ val_dm = data_module.val_dataloader() -# Visualize the dataset and the batch using napari -# Set the display -# os.environ['DISPLAY'] = ':1' - -# # Create a napari viewer -# viewer = napari.Viewer() - -# # Add the dataset to the viewer -# for batch in dataloader: -# if isinstance(batch, dict): -# for k, v in batch.items(): -# if isinstance(v, torch.Tensor): -# viewer.add_image(v.cpu().numpy().astype(np.float32)) - -# # Start the napari event loop -# napari.run() - # %% Define the logger logger = TensorBoardLogger( @@ -146,8 +135,7 @@ print(model) -# %% -# Run training. +# %% Run training. trainer.fit(model, data_module) diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py index 1983dbf8..f0978765 100644 --- a/applications/infection_classification/Infection_classification_2Dmodel.py +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -6,13 +6,19 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint -from viscy.transforms import RandWeightedCropd, NormalizeSampled, RandScaleIntensityd, RandGaussianSmoothd +from viscy.transforms import ( + RandWeightedCropd, + NormalizeSampled, + RandScaleIntensityd, + RandGaussianSmoothd, +) from viscy.data.hcs import HCSDataModule -from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D - +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) from iohub.ngff import open_ome_zarr -# %% Create a dataloader and visualize the batches. +# %% calculate the ratio of background, uninfected and infected pixels in the input dataset # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr" @@ -34,22 +40,26 @@ for pos_name, pos_data in well_data.positions(): data = pos_data.data - T,C,Z,Y,X = data.shape + T, C, Z, Y, X = data.shape out_data = data.numpy() for time in range(T): - Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z*X*Y + num_pixels = num_pixels + Z * X * Y -pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_1 = [ + num_pixels / num_pixels_bkg, + num_pixels / num_pixels_uninf, + num_pixels / num_pixels_inf, +] pixel_ratio_sum = sum(pixel_ratio_1) pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] -# %% -# Create an instance of HCSDataModule +# %% Create an instance of HCSDataModule + data_module = HCSDataModule( dataset_path, source_channel=["TXR_Density3D", "Phase3D"], @@ -101,25 +111,9 @@ val_dm = data_module.val_dataloader() -# Visualize the dataset and the batch using napari -# Set the display -# os.environ['DISPLAY'] = ':1' - -# # Create a napari viewer -# viewer = napari.Viewer() +# %% Set up for training -# # Add the dataset to the viewer -# for batch in dataloader: -# if isinstance(batch, dict): -# for k, v in batch.items(): -# if isinstance(v, torch.Tensor): -# viewer.add_image(v.cpu().numpy().astype(np.float32)) - -# # Start the napari event loop -# napari.run() - - -# %% Define the logger +# define the logger logger = TensorBoardLogger( "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/", name="logs", @@ -154,9 +148,10 @@ loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), ) +# visualize the model print(model) -# %% -# Run training. + +# %% Run training. trainer.fit(model, data_module) diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py index 0ecd6bdd..5af56ec8 100644 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ b/applications/infection_classification/Infection_classification_covnextModel.py @@ -11,7 +11,9 @@ from viscy.transforms import RandWeightedCropd from viscy.transforms import NormalizeSampled from viscy.data.hcs import HCSDataModule -from viscy.scripts.infection_phenotyping.classify_infection_covnext import SemanticSegUNet25D +from applications.infection_classification.classify_infection_covnext import ( + SemanticSegUNet22D, +) from iohub.ngff import open_ome_zarr @@ -37,17 +39,21 @@ for pos_name, pos_data in well_data.positions(): data = pos_data.data - T,C,Z,Y,X = data.shape + T, C, Z, Y, X = data.shape out_data = data.numpy() for time in range(T): - Inf_mask = out_data[time,in_chan_names.index("Inf_mask"),...] + Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z*X*Y + num_pixels = num_pixels + Z * X * Y -pixel_ratio_1 = [num_pixels/num_pixels_bkg, num_pixels/num_pixels_uninf, num_pixels/num_pixels_inf] +pixel_ratio_1 = [ + num_pixels / num_pixels_bkg, + num_pixels / num_pixels_uninf, + num_pixels / num_pixels_inf, +] pixel_ratio_sum = sum(pixel_ratio_1) pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] @@ -56,7 +62,7 @@ # Create an instance of HCSDataModule data_module = HCSDataModule( dataset_path, - source_channel=["Phase", "HSP90", "phase_nucl_iqr","hsp90_skew"], + source_channel=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], target_channel=["Inf_mask"], yx_patch_size=[256, 256], split_ratio=0.8, @@ -66,7 +72,7 @@ batch_size=16, normalizations=[ NormalizeSampled( - keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], + keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], level="fov_statistics", subtrahend="median", divisor="iqr", @@ -76,7 +82,7 @@ RandWeightedCropd( num_samples=4, spatial_size=[-1, 256, 256], - keys=["Phase","HSP90", "phase_nucl_iqr","hsp90_skew"], + keys=["Phase", "HSP90", "phase_nucl_iqr", "hsp90_skew"], w_key="Inf_mask", ) ], @@ -93,23 +99,6 @@ val_dm = data_module.val_dataloader() -# Visualize the dataset and the batch using napari -# Set the display -# os.environ['DISPLAY'] = ':1' - -# # Create a napari viewer -# viewer = napari.Viewer() - -# # Add the dataset to the viewer -# for batch in dataloader: -# if isinstance(batch, dict): -# for k, v in batch.items(): -# if isinstance(v, torch.Tensor): -# viewer.add_image(v.cpu().numpy().astype(np.float32)) - -# # Start the napari event loop -# napari.run() - # %% Define the logger logger = TensorBoardLogger( @@ -140,7 +129,7 @@ trainer.callbacks.append(checkpoint_callback) # Fit the model -model = SemanticSegUNet25D( +model = SemanticSegUNet22D( in_channels=4, out_channels=3, loss_function=nn.CrossEntropyLoss(weight=torch.tensor(pixel_ratio)), @@ -148,8 +137,7 @@ print(model) -# %% -# Run training. +# %% Run training. trainer.fit(model, data_module) diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py index c78a7e8f..464cecc1 100644 --- a/applications/infection_classification/classify_infection_25D.py +++ b/applications/infection_classification/classify_infection_25D.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import lightning.pytorch as pl @@ -18,9 +17,10 @@ from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample -# +# # %% Methods to compute confusion matrix per cell using torchmetrics + # The confusion matrix at the single-cell resolution. def confusion_matrix_per_cell( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int @@ -63,7 +63,11 @@ def compute_confusion_matrix( y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) # find objects in every image @@ -78,17 +82,18 @@ def compute_confusion_matrix( test_id = y_true_reshaped[int(row), int(col)] if pred_id == 1 and test_id == 1: - conf_mat[1,1] += 1 + conf_mat[1, 1] += 1 if pred_id == 1 and test_id == 2: - conf_mat[0,1] += 1 + conf_mat[0, 1] += 1 if pred_id == 2 and test_id == 1: - conf_mat[1,0] += 1 + conf_mat[1, 0] += 1 if pred_id == 2 and test_id == 2: - conf_mat[0,0] += 1 + conf_mat[0, 0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return conf_mat + def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix fig, ax = plt.subplots() @@ -97,7 +102,9 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): cax = ax.matshow(confusion_matrix, cmap="viridis") # Create a colorbar and set the label - index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary fig.colorbar(cax, label="Frequency") # Set labels for the classes @@ -124,8 +131,11 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # plt.show(fig) # Show the figure return fig + + # Define a 25d unet model for infection classification as a lightning module. + class SemanticSegUNet25D(pl.LightningModule): # Model for semantic segmentation. def __init__( @@ -143,7 +153,12 @@ def __init__( ): super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer # Initialize the UNet model - self.unet_model = Unet25d(in_channels=in_channels, out_channels=out_channels, num_blocks=4, num_block_layers=4) + self.unet_model = Unet25d( + in_channels=in_channels, + out_channels=out_channels, + num_blocks=4, + num_block_layers=4, + ) if ckpt_path is not None: state_dict = torch.load(ckpt_path, map_location=torch.device("cpu"))[ "state_dict" @@ -168,7 +183,6 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Infected", "Uninfected"] - # Define the forward pass def forward(self, x): return self.unet_model(x) # Pass the input through the UNet model @@ -239,39 +253,45 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels # prob_chan = prob_pred[:, 2, :, :] # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image # return prob_chan # log the probability predicted image - + def on_test_start(self): - self.pred_cm = torch.zeros((2,2)) + self.pred_cm = torch.zeros((2, 2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - + def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source logits = self._predict_pad.inverse(self.forward(source)) prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - + self.logger.experiment.add_figure( "Confusion Matrix per Cell", plot_confusion_matrix(pred_cm, self.index_to_label_dict), self.current_epoch, ) - # Accumulate the confusion matrix at the end of test epoch and log. + # Accumulate the confusion matrix at the end of test epoch and log. def on_test_end(self): confusion_matrix_sum = self.pred_cm self.logger.experiment.add_figure( @@ -332,4 +352,6 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + # %% diff --git a/applications/infection_classification/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py index 8d5d83af..1227cf8a 100644 --- a/applications/infection_classification/classify_infection_2D.py +++ b/applications/infection_classification/classify_infection_2D.py @@ -1,3 +1,4 @@ +# %% lightning moules for infection classification using the viscy library import torch import torch.nn as nn @@ -16,13 +17,16 @@ from monai.transforms import DivisiblePad from viscy.unet.networks.Unet2D import Unet2d + # from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample + # from skimage.io import imsave -# +# # %% Methods to compute confusion matrix per cell using torchmetrics + # The confusion matrix at the single-cell resolution. def confusion_matrix_per_cell( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int @@ -43,7 +47,7 @@ def confusion_matrix_per_cell( return confusion_matrix_per_cell -# These images can be logged with prediction. +# confusion matrix computation def compute_confusion_matrix( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int ): @@ -66,10 +70,11 @@ def compute_confusion_matrix( y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - # imsave(f"y_pred_{i}.png", y_pred_reshaped.astype(np.uint8)) - # imsave(f"y_true_{i}.png", y_true_reshaped.astype(np.uint8)) - - y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) # find objects in every image @@ -84,17 +89,19 @@ def compute_confusion_matrix( test_id = y_true_reshaped[int(row), int(col)] if pred_id == 1 and test_id == 1: - conf_mat[1,1] += 1 + conf_mat[1, 1] += 1 if pred_id == 1 and test_id == 2: - conf_mat[0,1] += 1 + conf_mat[0, 1] += 1 if pred_id == 2 and test_id == 1: - conf_mat[1,0] += 1 + conf_mat[1, 0] += 1 if pred_id == 2 and test_id == 2: - conf_mat[0,0] += 1 + conf_mat[0, 0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return conf_mat + +# plot the computed confusion matrix def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix fig, ax = plt.subplots() @@ -103,7 +110,9 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): cax = ax.matshow(confusion_matrix, cmap="viridis") # Create a colorbar and set the label - index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary fig.colorbar(cax, label="Frequency") # Set labels for the classes @@ -128,20 +137,13 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): color="white", ) - # plt.show(fig) # Show the figure return fig -# Define a 2d unet model for infection classification as a lightning module. -# write a prediction writre to save the predictions as png files -# def predict_writer(label_pred, file_name): -# output_path = f"/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/pred_debug_2024_07_08/{file_name}" -# label_pred_cpu = label_pred.cpu().numpy() -# write_npy = label_pred_cpu[0,0,0,:,:] -# # label_pred_reshaped = label_pred_cpu.reshape(label_pred_cpu.shape[-2:]) -# imsave(output_path, write_npy) class SemanticSegUNet2D(pl.LightningModule): + # Model for semantic segmentation. + def __init__( self, in_channels: int, # Number of input channels @@ -182,8 +184,6 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Infected", "Uninfected"] - - # Define the forward pass def forward(self, x): return self.unet_model(x) # Pass the input through the UNet model @@ -197,6 +197,10 @@ def configure_optimizers(self): # Define the training step def training_step(self, batch: Sample, batch_idx: int): + """ + The training step for the model. + This method is called for every batch during the training process. + """ source = batch["source"] # Extract the source from the batch target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source @@ -224,6 +228,9 @@ def training_step(self, batch: Sample, batch_idx: int): return train_loss # Return the training loss def validation_step(self, batch: Sample, batch_idx: int): + """ + The validation step for the model. + """ source = batch["source"] # Extract the source from the batch target = batch["target"] # Extract the target from the batch pred = self.forward(source) # Make a prediction using the source @@ -253,48 +260,46 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): + source = self._predict_pad(batch["source"]) # Pad the source - # predict_writer(batch["source"], f"pred_source.npy") - logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - - # prob_chan = prob_pred[:, 2, :, :] - # prob_chan = prob_chan.unsqueeze(1) + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + return labels_pred # log the class predicted image - # return prob_chan # log the probability predicted image - + def on_test_start(self): - self.pred_cm = torch.zeros((2,2)) + self.pred_cm = torch.zeros((2, 2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - # self.i_num = 0 - + def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source # predict_writer(batch["source"], f"test_source_{self.i_num}.npy") logits = self._predict_pad.inverse(self.forward(source)) prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - - # self.i_num += 1 - # Save the prediction as a png file - # predict_writer(labels_pred, f"predict_{self.i_num}.png") - + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - + self.logger.experiment.add_figure( "Confusion Matrix per Cell", plot_confusion_matrix(pred_cm, self.index_to_label_dict), self.current_epoch, ) - # Accumulate the confusion matrix at the end of test epoch and log. + # Accumulate the confusion matrix at the end of test epoch and log. def on_test_end(self): confusion_matrix_sum = self.pred_cm self.logger.experiment.add_figure( @@ -355,4 +360,6 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + # %% diff --git a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py index 2ba698ee..385da925 100644 --- a/applications/infection_classification/classify_infection_covnext.py +++ b/applications/infection_classification/classify_infection_covnext.py @@ -1,4 +1,3 @@ - import torch import torch.nn as nn import lightning.pytorch as pl @@ -19,9 +18,10 @@ from viscy.data.hcs import Sample from viscy.light.engine import VSUNet -# +# # %% Methods to compute confusion matrix per cell using torchmetrics + # The confusion matrix at the single-cell resolution. def confusion_matrix_per_cell( y_true: torch.Tensor, y_pred: torch.Tensor, num_classes: int @@ -64,7 +64,11 @@ def compute_confusion_matrix( y_pred_cpu = y_pred[i].cpu().numpy() y_true_reshaped = y_true_cpu.reshape(y_true_cpu.shape[-2:]) y_pred_reshaped = y_pred_cpu.reshape(y_pred_cpu.shape[-2:]) - y_pred_resized = cv2.resize(y_pred_reshaped, dsize=y_true_reshaped.shape[::-1], interpolation=cv2.INTER_NEAREST) + y_pred_resized = cv2.resize( + y_pred_reshaped, + dsize=y_true_reshaped.shape[::-1], + interpolation=cv2.INTER_NEAREST, + ) y_pred_resized = np.where(y_true_reshaped > 0, y_pred_resized, 0) # find objects in every image @@ -79,17 +83,18 @@ def compute_confusion_matrix( test_id = y_true_reshaped[int(row), int(col)] if pred_id == 1 and test_id == 1: - conf_mat[1,1] += 1 + conf_mat[1, 1] += 1 if pred_id == 1 and test_id == 2: - conf_mat[0,1] += 1 + conf_mat[0, 1] += 1 if pred_id == 2 and test_id == 1: - conf_mat[1,0] += 1 + conf_mat[1, 0] += 1 if pred_id == 2 and test_id == 2: - conf_mat[0,0] += 1 + conf_mat[0, 0] += 1 # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. # Find all instances of nuclei in ground truth and compute the class of the nuclei in both ground truth and prediction. return conf_mat + def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # Create a figure and axis to plot the confusion matrix fig, ax = plt.subplots() @@ -98,7 +103,9 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): cax = ax.matshow(confusion_matrix, cmap="viridis") # Create a colorbar and set the label - index_to_label_dict = dict(enumerate(index_to_label_dict)) # Convert list to dictionary + index_to_label_dict = dict( + enumerate(index_to_label_dict) + ) # Convert list to dictionary fig.colorbar(cax, label="Frequency") # Set labels for the classes @@ -125,9 +132,12 @@ def plot_confusion_matrix(confusion_matrix, index_to_label_dict): # plt.show(fig) # Show the figure return fig + + # Define a 25d unet model for infection classification as a lightning module. -class SemanticSegUNet25D(pl.LightningModule): + +class SemanticSegUNet22D(pl.LightningModule): # Model for semantic segmentation. def __init__( self, @@ -142,7 +152,7 @@ def __init__( log_samples_per_batch: int = 2, # Number of samples to log per batch ckpt_path: str = None, # Path to the checkpoint ): - super(SemanticSegUNet25D, self).__init__() # Call the superclass initializer + super(SemanticSegUNet22D, self).__init__() # Call the superclass initializer # Initialize the UNet model self.unet_model = VSUNet( architecture="2.2D", @@ -180,7 +190,6 @@ def __init__( self.pred_cm = None # Initialize the confusion matrix self.index_to_label_dict = ["Infected", "Uninfected"] - # Define the forward pass def forward(self, x): return self.unet_model(x) # Pass the input through the UNet model @@ -251,39 +260,45 @@ def on_predict_start(self): # Define the prediction step def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0): source = self._predict_pad(batch["source"]) # Pad the source - logits = self._predict_pad.inverse(self.forward(source)) # Predict and remove padding. + logits = self._predict_pad.inverse( + self.forward(source) + ) # Predict and remove padding. prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities # Go from probabilities/one-hot encoded data to class labels. - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels # prob_chan = prob_pred[:, 2, :, :] # prob_chan = prob_chan.unsqueeze(1) return labels_pred # log the class predicted image # return prob_chan # log the probability predicted image - + def on_test_start(self): - self.pred_cm = torch.zeros((2,2)) + self.pred_cm = torch.zeros((2, 2)) down_factor = 2**self.unet_model.num_blocks self._predict_pad = DivisiblePad((0, 0, down_factor, down_factor)) - + def test_step(self, batch: Sample): source = self._predict_pad(batch["source"]) # Pad the source logits = self._predict_pad.inverse(self.forward(source)) prob_pred = F.softmax(logits, dim=1) # Calculate the probabilities - labels_pred = torch.argmax(prob_pred, dim=1, keepdim=True) # Calculate the predicted labels - + labels_pred = torch.argmax( + prob_pred, dim=1, keepdim=True + ) # Calculate the predicted labels + target = self._predict_pad(batch["target"]) # Extract the target from the batch pred_cm = confusion_matrix_per_cell( target, labels_pred, num_classes=2 ) # Calculate the confusion matrix per cell self.pred_cm += pred_cm # Append the confusion matrix to pred_cm - + self.logger.experiment.add_figure( "Confusion Matrix per Cell", plot_confusion_matrix(pred_cm, self.index_to_label_dict), self.current_epoch, ) - # Accumulate the confusion matrix at the end of test epoch and log. + # Accumulate the confusion matrix at the end of test epoch and log. def on_test_end(self): confusion_matrix_sum = self.pred_cm self.logger.experiment.add_figure( @@ -344,4 +359,6 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): self.logger.experiment.add_image( key, grid, self.current_epoch, dataformats="HWC" ) + + # %% diff --git a/applications/infection_classification/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py index e809f4ff..9eedad32 100644 --- a/applications/infection_classification/predict_infection_classifier.py +++ b/applications/infection_classification/predict_infection_classifier.py @@ -3,18 +3,19 @@ from viscy.light.predict_writer import HCSPredictionWriter from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) from viscy.transforms import NormalizeSampled # %% # %% write the predictions to a zarr file -# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/2024_04_25_BJ5a_DENV_TimeCourse_2D.zarr" -pred_datapath = '/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/0-train_test_data/2024_02_04_A549_DENV_ZIKV_timelapse_test_2D.zarr' +pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/0-train_test_data/2024_02_04_A549_DENV_ZIKV_timelapse_test_2D.zarr" data_module = HCSDataModule( data_path=pred_datapath, - source_channel=['RFP', 'Phase3D'], - target_channel=['Inf_mask'], + source_channel=["RFP", "Phase3D"], + target_channel=["Inf_mask"], split_ratio=0.7, z_window_size=1, architecture="2D", @@ -32,11 +33,6 @@ data_module.setup(stage="predict") -# model = SemanticSegUNet2D( -# in_channels=2, -# out_channels=3, -# ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/checkpoint_epoch=206.ckpt", -# ) model = SemanticSegUNet2D( in_channels=2, out_channels=3, @@ -45,14 +41,7 @@ # %% perform prediction -# output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/1-predict_infection/2024_04_25_BJ5a_DENV_TimeCourse_2D_pred.zarr" -output_path = '/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/2-predict_infection/2024_02_04_A549_DENV_ZIKV_timelapse_pred_2D_new.zarr' - -# trainer = pl.Trainer( -# default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs", -# callbacks=[HCSPredictionWriter(output_path, write_input=False)], -# devices=1, # Set the number of GPUs to use. This avoids run-time exception from distributed training when the node has multiple GPUs -# ) +output_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/2-predict_infection/2024_02_04_A549_DENV_ZIKV_timelapse_pred_2D_new.zarr" trainer = pl.Trainer( default_root_dir="/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/A549_63X/2024_02_04_A549_DENV_ZIKV_timelapse/1-model_train/logs", @@ -66,4 +55,4 @@ return_predictions=True, ) -# %% \ No newline at end of file +# %% diff --git a/applications/infection_classification/test_infection_classifier.py b/applications/infection_classification/test_infection_classifier.py index 1d15a18c..c4c5d761 100644 --- a/applications/infection_classification/test_infection_classifier.py +++ b/applications/infection_classification/test_infection_classifier.py @@ -1,7 +1,9 @@ # %% from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl -from viscy.scripts.infection_phenotyping.classify_infection_2D import SemanticSegUNet2D +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) from pytorch_lightning.loggers import TensorBoardLogger from viscy.transforms import NormalizeSampled @@ -10,8 +12,8 @@ data_module = HCSDataModule( data_path=test_datapath, - source_channel=['TXR_Density3D','Phase3D'], - target_channel=['Inf_mask'], + source_channel=["TXR_Density3D", "Phase3D"], + target_channel=["Inf_mask"], split_ratio=0.7, z_window_size=1, architecture="2D", @@ -49,89 +51,8 @@ ckpt_path="/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/5-infection_classifier/0-model_training/logs/checkpoint_epoch=206.ckpt", ) -trainer.test(model=model, datamodule=data_module) - - - - -# # %% script to develop confusion matrix for infected cell classifier - -# from iohub.ngff import open_ome_zarr -# import numpy as np -# from skimage.measure import regionprops, label -# import cv2 -# import seaborn as sns -# import matplotlib.pyplot as plt - -# # %% load the predicted zarr and the human-in-loop annotations zarr - -# pred_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/pred/Exp_2024_02_13_DENV_3infMarked_pred.zarr" -# test_datapath = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2024_02_13_DENV_3infMarked_test.zarr" - -# pred_dataset = open_ome_zarr( -# pred_datapath, -# layout="hcs", -# mode="r+", -# ) -# chan_pred = pred_dataset.channel_names - -# test_dataset = open_ome_zarr( -# test_datapath, -# layout="hcs", -# mode="r+", -# ) -# chan_test = test_dataset.channel_names - -# # %% compute confusion matrix for one image -# for well_id, well_data in pred_dataset.wells(): -# well_name, well_no = well_id.split("/") - -# for pos_name, pos_data in well_data.positions(): - -# pred_data = pos_data.data -# pred_pos_data = pred_data.numpy() -# T,C,Z,X,Y = pred_pos_data.shape - -# test_data = test_dataset[well_id + "/" + pos_name + "/0"] -# test_pos_data = test_data.numpy() - -# # compute confusion matrix for each time point and add to total -# conf_mat = np.zeros((2, 2)) -# for time in range(T): -# pred_img = pred_pos_data[time, chan_pred.index("Inf_mask_prediction"), 0, : , :] -# test_img = test_pos_data[time, chan_test.index("Inf_mask"), 0, : , :] - -# test_img_rz = cv2.resize(test_img, dsize=(2016,2048), interpolation=cv2.INTER_NEAREST)# pred_img = -# pred_img = np.where(test_img_rz > 0, pred_img, 0) - -# # find objects in every image -# label_img = label(test_img_rz) -# regions_label = regionprops(label_img) - -# # store pixel id for every label in pred and test imgs -# for region in regions_label: -# if region.area > 0: -# row, col = region.centroid -# pred_id = pred_img[int(row), int(col)] -# test_id = test_img_rz[int(row), int(col)] -# if pred_id == 1 and test_id == 1: -# conf_mat[1,1] += 1 -# if pred_id == 1 and test_id == 2: -# conf_mat[1,0] += 1 -# if pred_id == 2 and test_id == 1: -# conf_mat[0,1] += 1 -# if pred_id == 2 and test_id == 2: -# conf_mat[0,0] += 1 - -# # display the confusion matrix -# ax= plt.subplot() -# sns.heatmap(conf_mat, annot=True, fmt='g', ax=ax); #annot=True to annotate cells, ftm='g' to disable scientific notation - -# # labels, title and ticks -# ax.set_xlabel('annotated labels');ax.set_ylabel('predicted labels'); -# ax.set_title('Confusion Matrix'); -# ax.xaxis.set_ticklabels(['infected', 'uninfected']); ax.yaxis.set_ticklabels(['infected', 'uninfected']); +# %% test the model +trainer.test(model=model, datamodule=data_module) -# # %% # %% From bd23f3b654e7d17f64b4b6c2d86a6243b2fa53f9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Tue, 9 Jul 2024 11:23:48 -0700 Subject: [PATCH 79/92] add explicit packaging list --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 8f6978de..18cefec8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dev = [ [project.scripts] viscy = "viscy.cli.cli:main" +[tool.setuptools] +packages = ["viscy"] + [tool.setuptools_scm] write_to = "viscy/_version.py" From 701ea7732286d3a05592dc165acdd5a1837eff07 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Tue, 9 Jul 2024 11:36:29 -0700 Subject: [PATCH 80/92] rename testing script --- ...st_infection_classifier.py => infection_classifier_testing.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename applications/infection_classification/{test_infection_classifier.py => infection_classifier_testing.py} (100%) diff --git a/applications/infection_classification/test_infection_classifier.py b/applications/infection_classification/infection_classifier_testing.py similarity index 100% rename from applications/infection_classification/test_infection_classifier.py rename to applications/infection_classification/infection_classifier_testing.py From 8ddb58e035c50304d5eecfe66f0b66ac09dca494 Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 10 Jul 2024 12:03:29 -0700 Subject: [PATCH 81/92] update readme --- applications/infection_classification/readme.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/applications/infection_classification/readme.md b/applications/infection_classification/readme.md index 74dbc500..2013d302 100644 --- a/applications/infection_classification/readme.md +++ b/applications/infection_classification/readme.md @@ -1,7 +1,13 @@ # Infection Classification Model -This repository contains the code for the infection classification model (`infection_classification_model.py`) used in the infection phenotyping project. +This repository contains the code for developing the infection classification model used in the infection phenotyping project. Infection classification models can be trained on human annotated ground truth with fluorescence sensor channel and phase channel to predict the state of infection of single cells. The pixels are predicted to be background (class 0), uninfected (class 1) or infected (class 2) by the model. ## Overview -The `infection_classification_model.py` file implements a machine learning model for classifying infections based on various features. The model is trained on a labeled dataset, either fluorescence or label-free images, and can be used to predict the infection type for new samples. \ No newline at end of file +The following scripts are available: + +Training: `infection_classification_*model.py` file implements a machine learning model for classifying infections based on various features. The model is trained on a labeled dataset, with fluorescence and label-free images. + +Testing: `infection_classifier_testing.py` file tests the 2D infection classification model trained on a 2D dataset. + +Prediction: `predict_classifier_testing.py` is an example script to perform prediction using 2D data and 2D model. It can be used to predict the infection type for new samples. \ No newline at end of file From a49cfba80246b15dac531b631463bee35fb87d6d Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 10 Jul 2024 12:15:08 -0700 Subject: [PATCH 82/92] move function to preprocessing --- .../Infection_classification_25DModel.py | 41 ++-------------- .../Infection_classification_2Dmodel.py | 38 +------------- .../Infection_classification_covnextModel.py | 39 +-------------- .../classify_infection_25D.py | 1 - .../classify_infection_2D.py | 2 - .../predict_infection_classifier.py | 2 +- viscy/preprocessing/compute_pixel_ratio.py | 49 +++++++++++++++++++ 7 files changed, 58 insertions(+), 114 deletions(-) create mode 100644 viscy/preprocessing/compute_pixel_ratio.py diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py index 55d208f0..7fdf2b63 100644 --- a/applications/infection_classification/Infection_classification_25DModel.py +++ b/applications/infection_classification/Infection_classification_25DModel.py @@ -12,7 +12,7 @@ from applications.infection_classification.classify_infection_25D import ( SemanticSegUNet25D, ) - +from viscy.preprocessing import calculate_pixel_ratio from iohub.ngff import open_ome_zarr # %% Create a dataloader and visualize the batches. @@ -20,42 +20,7 @@ # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_trainVal.zarr" -# find ratio of background, uninfected and infected pixels -zarr_input = open_ome_zarr( - dataset_path, - layout="hcs", - mode="r+", -) -in_chan_names = zarr_input.channel_names - -num_pixels_bkg = 0 -num_pixels_uninf = 0 -num_pixels_inf = 0 -num_pixels = 0 -for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - for pos_name, pos_data in well_data.positions(): - data = pos_data.data - T, C, Z, Y, X = data.shape - out_data = data.numpy() - for time in range(T): - Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] - # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' - num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() - num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() - num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z * X * Y - -pixel_ratio_1 = [ - num_pixels / num_pixels_bkg, - num_pixels / num_pixels_uninf, - num_pixels / num_pixels_inf, -] -pixel_ratio_sum = sum(pixel_ratio_1) -pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] - -# %% craete data module +# %% create data module # Create an instance of HCSDataModule data_module = HCSDataModule( @@ -86,6 +51,8 @@ ], ) +pixel_ratio = calculate_pixel_ratio(dataset_path,target_channel="Inf_mask") + # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py index f0978765..76ed41f8 100644 --- a/applications/infection_classification/Infection_classification_2Dmodel.py +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -16,48 +16,13 @@ from applications.infection_classification.classify_infection_2D import ( SemanticSegUNet2D, ) -from iohub.ngff import open_ome_zarr +from viscy.preprocessing import calculate_pixel_ratio # %% calculate the ratio of background, uninfected and infected pixels in the input dataset # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/2024_04_25_BJ5a_DENV_TimeCourse/4-human_annotation/train_data.zarr" -# find ratio of background, uninfected and infected pixels -zarr_input = open_ome_zarr( - dataset_path, - layout="hcs", - mode="r+", -) -in_chan_names = zarr_input.channel_names - -num_pixels_bkg = 0 -num_pixels_uninf = 0 -num_pixels_inf = 0 -num_pixels = 0 -for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - for pos_name, pos_data in well_data.positions(): - data = pos_data.data - T, C, Z, Y, X = data.shape - out_data = data.numpy() - for time in range(T): - Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] - # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' - num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() - num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() - num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z * X * Y - -pixel_ratio_1 = [ - num_pixels / num_pixels_bkg, - num_pixels / num_pixels_uninf, - num_pixels / num_pixels_inf, -] -pixel_ratio_sum = sum(pixel_ratio_1) -pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] - # %% Create an instance of HCSDataModule data_module = HCSDataModule( @@ -99,6 +64,7 @@ ), ], ) +pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py index 5af56ec8..1f453c60 100644 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ b/applications/infection_classification/Infection_classification_covnextModel.py @@ -14,49 +14,13 @@ from applications.infection_classification.classify_infection_covnext import ( SemanticSegUNet22D, ) - -from iohub.ngff import open_ome_zarr +from viscy.preprocessing import calculate_pixel_ratio # %% Create a dataloader and visualize the batches. # Set the path to the dataset dataset_path = "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/datasets/Exp_2023_11_08_Opencell_infection/OC43_infection_timelapse_all_curated_train.zarr" -# find ratio of background, uninfected and infected pixels -zarr_input = open_ome_zarr( - dataset_path, - layout="hcs", - mode="r+", -) -in_chan_names = zarr_input.channel_names - -num_pixels_bkg = 0 -num_pixels_uninf = 0 -num_pixels_inf = 0 -num_pixels = 0 -for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - for pos_name, pos_data in well_data.positions(): - data = pos_data.data - T, C, Z, Y, X = data.shape - out_data = data.numpy() - for time in range(T): - Inf_mask = out_data[time, in_chan_names.index("Inf_mask"), ...] - # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' - num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() - num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() - num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z * X * Y - -pixel_ratio_1 = [ - num_pixels / num_pixels_bkg, - num_pixels / num_pixels_uninf, - num_pixels / num_pixels_inf, -] -pixel_ratio_sum = sum(pixel_ratio_1) -pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] - # %% craete data module # Create an instance of HCSDataModule @@ -87,6 +51,7 @@ ) ], ) +pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py index 464cecc1..58bbd0d6 100644 --- a/applications/infection_classification/classify_infection_25D.py +++ b/applications/infection_classification/classify_infection_25D.py @@ -17,7 +17,6 @@ from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample -# # %% Methods to compute confusion matrix per cell using torchmetrics diff --git a/applications/infection_classification/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py index 1227cf8a..555a652b 100644 --- a/applications/infection_classification/classify_infection_2D.py +++ b/applications/infection_classification/classify_infection_2D.py @@ -21,8 +21,6 @@ # from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample -# from skimage.io import imsave - # # %% Methods to compute confusion matrix per cell using torchmetrics diff --git a/applications/infection_classification/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py index 9eedad32..6e5397f7 100644 --- a/applications/infection_classification/predict_infection_classifier.py +++ b/applications/infection_classification/predict_infection_classifier.py @@ -52,7 +52,7 @@ trainer.predict( model=model, datamodule=data_module, - return_predictions=True, + return_predictions=False, ) # %% diff --git a/viscy/preprocessing/compute_pixel_ratio.py b/viscy/preprocessing/compute_pixel_ratio.py new file mode 100644 index 00000000..3f0dc3f5 --- /dev/null +++ b/viscy/preprocessing/compute_pixel_ratio.py @@ -0,0 +1,49 @@ + +''' compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset''' + +from iohub.ngff import open_ome_zarr + +def calculate_pixel_ratio(dataset_path: str, target_channel: str): + """ + find ratio of background, uninfected and infected pixels + in the input dataset + Args: + dataset_path (str): Path to the dataset + Returns: + pixel_ratio (list): List of ratios of background, uninfected and infected pixels + """ + zarr_input = open_ome_zarr( + dataset_path, + layout="hcs", + mode="r+", + ) + in_chan_names = zarr_input.channel_names + + num_pixels_bkg = 0 + num_pixels_uninf = 0 + num_pixels_inf = 0 + num_pixels = 0 + for well_id, well_data in zarr_input.wells(): + well_name, well_no = well_id.split("/") + + for pos_name, pos_data in well_data.positions(): + data = pos_data.data + T, C, Z, Y, X = data.shape + out_data = data.numpy() + for time in range(T): + Inf_mask = out_data[time, in_chan_names.index(target_channel), ...] + # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' + num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() + num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() + num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() + num_pixels = num_pixels + Z * X * Y + + pixel_ratio_1 = [ + num_pixels / num_pixels_bkg, + num_pixels / num_pixels_uninf, + num_pixels / num_pixels_inf, + ] + pixel_ratio_sum = sum(pixel_ratio_1) + pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] + + return pixel_ratio \ No newline at end of file From 00baf9df31591098eb1016dbf0c5caa1a625757a Mon Sep 17 00:00:00 2001 From: Soorya Pradeep Date: Wed, 10 Jul 2024 12:18:14 -0700 Subject: [PATCH 83/92] format code --- viscy/preprocessing/compute_pixel_ratio.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/viscy/preprocessing/compute_pixel_ratio.py b/viscy/preprocessing/compute_pixel_ratio.py index 3f0dc3f5..6b30571a 100644 --- a/viscy/preprocessing/compute_pixel_ratio.py +++ b/viscy/preprocessing/compute_pixel_ratio.py @@ -1,14 +1,15 @@ - -''' compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset''' +""" compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset""" from iohub.ngff import open_ome_zarr + def calculate_pixel_ratio(dataset_path: str, target_channel: str): """ find ratio of background, uninfected and infected pixels in the input dataset Args: dataset_path (str): Path to the dataset + target_channel (str): Name of the manual annotation channel Returns: pixel_ratio (list): List of ratios of background, uninfected and infected pixels """ @@ -46,4 +47,4 @@ def calculate_pixel_ratio(dataset_path: str, target_channel: str): pixel_ratio_sum = sum(pixel_ratio_1) pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] - return pixel_ratio \ No newline at end of file + return pixel_ratio From cd35d2253e452aa95d77befadb1c4a859f61ba82 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:43:29 -0700 Subject: [PATCH 84/92] formatting --- .../Infection_classification_25DModel.py | 17 ++++++------- .../Infection_classification_2Dmodel.py | 19 +++++++-------- .../Infection_classification_covnextModel.py | 14 +++++------ .../classify_infection_25D.py | 24 +++++++++---------- .../classify_infection_2D.py | 23 +++++++++--------- .../classify_infection_covnext.py | 23 +++++++++--------- .../infection_classifier_testing.py | 3 ++- .../predict_infection_classifier.py | 5 ++-- .../infection_classification/readme.md | 2 +- 9 files changed, 62 insertions(+), 68 deletions(-) diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py index 7fdf2b63..015cd106 100644 --- a/applications/infection_classification/Infection_classification_25DModel.py +++ b/applications/infection_classification/Infection_classification_25DModel.py @@ -1,19 +1,16 @@ # %% -import torch import lightning.pytorch as pl +import torch import torch.nn as nn - -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import ModelCheckpoint - -from viscy.transforms import RandWeightedCropd -from viscy.transforms import NormalizeSampled -from viscy.data.hcs import HCSDataModule from applications.infection_classification.classify_infection_25D import ( SemanticSegUNet25D, ) +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule from viscy.preprocessing import calculate_pixel_ratio -from iohub.ngff import open_ome_zarr +from viscy.transforms import NormalizeSampled, RandWeightedCropd # %% Create a dataloader and visualize the batches. @@ -51,7 +48,7 @@ ], ) -pixel_ratio = calculate_pixel_ratio(dataset_path,target_channel="Inf_mask") +pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py index 76ed41f8..055c232d 100644 --- a/applications/infection_classification/Infection_classification_2Dmodel.py +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -1,22 +1,21 @@ # %% -import torch import lightning.pytorch as pl +import torch import torch.nn as nn - -from pytorch_lightning.loggers import TensorBoardLogger +from applications.infection_classification.classify_infection_2D import ( + SemanticSegUNet2D, +) from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +from viscy.data.hcs import HCSDataModule +from viscy.preprocessing import calculate_pixel_ratio from viscy.transforms import ( - RandWeightedCropd, NormalizeSampled, - RandScaleIntensityd, RandGaussianSmoothd, + RandScaleIntensityd, + RandWeightedCropd, ) -from viscy.data.hcs import HCSDataModule -from applications.infection_classification.classify_infection_2D import ( - SemanticSegUNet2D, -) -from viscy.preprocessing import calculate_pixel_ratio # %% calculate the ratio of background, uninfected and infected pixels in the input dataset diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py index 1f453c60..7fea6e76 100644 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ b/applications/infection_classification/Infection_classification_covnextModel.py @@ -1,20 +1,18 @@ # %% # import sys # sys.path.append("/hpc/mydata/soorya.pradeep/viscy_infection_phenotyping/Viscy/") -import torch import lightning.pytorch as pl +import torch import torch.nn as nn - -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.callbacks import ModelCheckpoint - -from viscy.transforms import RandWeightedCropd -from viscy.transforms import NormalizeSampled -from viscy.data.hcs import HCSDataModule from applications.infection_classification.classify_infection_covnext import ( SemanticSegUNet22D, ) +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule from viscy.preprocessing import calculate_pixel_ratio +from viscy.transforms import NormalizeSampled, RandWeightedCropd # %% Create a dataloader and visualize the batches. diff --git a/applications/infection_classification/classify_infection_25D.py b/applications/infection_classification/classify_infection_25D.py index 58bbd0d6..e16f56f4 100644 --- a/applications/infection_classification/classify_infection_25D.py +++ b/applications/infection_classification/classify_infection_25D.py @@ -1,21 +1,21 @@ +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn -import lightning.pytorch as pl import torch.nn.functional as F -from torch import Tensor -import cv2 - -# import torchview -from typing import Literal, Sequence -from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap -from skimage.measure import regionprops, label -import numpy as np -import matplotlib.pyplot as plt - from monai.transforms import DivisiblePad -from viscy.unet.networks.Unet25D import Unet25d +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor + from viscy.data.hcs import Sample +from viscy.unet.networks.Unet25D import Unet25d # %% Methods to compute confusion matrix per cell using torchmetrics diff --git a/applications/infection_classification/classify_infection_2D.py b/applications/infection_classification/classify_infection_2D.py index 555a652b..afd97ab7 100644 --- a/applications/infection_classification/classify_infection_2D.py +++ b/applications/infection_classification/classify_infection_2D.py @@ -1,25 +1,24 @@ # %% lightning moules for infection classification using the viscy library +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn -import lightning.pytorch as pl import torch.nn.functional as F -from torch import Tensor -import cv2 - -# import torchview -from typing import Literal, Sequence -from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap -from skimage.measure import regionprops, label -import numpy as np -import matplotlib.pyplot as plt - from monai.transforms import DivisiblePad -from viscy.unet.networks.Unet2D import Unet2d +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor # from viscy.unet.networks.Unet25D import Unet25d from viscy.data.hcs import Sample +from viscy.unet.networks.Unet2D import Unet2d # # %% Methods to compute confusion matrix per cell using torchmetrics diff --git a/applications/infection_classification/classify_infection_covnext.py b/applications/infection_classification/classify_infection_covnext.py index 385da925..5eddb236 100644 --- a/applications/infection_classification/classify_infection_covnext.py +++ b/applications/infection_classification/classify_infection_covnext.py @@ -1,20 +1,19 @@ +# import torchview +from typing import Literal, Sequence + +import cv2 +import lightning.pytorch as pl +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn -import lightning.pytorch as pl import torch.nn.functional as F -from torch import Tensor -import cv2 - -# import torchview -from typing import Literal, Sequence -from skimage.exposure import rescale_intensity from matplotlib.cm import get_cmap -from skimage.measure import regionprops, label -import numpy as np -import matplotlib.pyplot as plt - from monai.transforms import DivisiblePad -from viscy.unet.networks.Unet25D import Unet25d +from skimage.exposure import rescale_intensity +from skimage.measure import label, regionprops +from torch import Tensor + from viscy.data.hcs import Sample from viscy.light.engine import VSUNet diff --git a/applications/infection_classification/infection_classifier_testing.py b/applications/infection_classification/infection_classifier_testing.py index c4c5d761..fea8326d 100644 --- a/applications/infection_classification/infection_classifier_testing.py +++ b/applications/infection_classification/infection_classifier_testing.py @@ -1,10 +1,11 @@ # %% -from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl from applications.infection_classification.classify_infection_2D import ( SemanticSegUNet2D, ) from pytorch_lightning.loggers import TensorBoardLogger + +from viscy.data.hcs import HCSDataModule from viscy.transforms import NormalizeSampled # %% test the model on the test set diff --git a/applications/infection_classification/predict_infection_classifier.py b/applications/infection_classification/predict_infection_classifier.py index 6e5397f7..458fc670 100644 --- a/applications/infection_classification/predict_infection_classifier.py +++ b/applications/infection_classification/predict_infection_classifier.py @@ -1,11 +1,12 @@ # %% -from viscy.light.predict_writer import HCSPredictionWriter -from viscy.data.hcs import HCSDataModule import lightning.pytorch as pl from applications.infection_classification.classify_infection_2D import ( SemanticSegUNet2D, ) + +from viscy.data.hcs import HCSDataModule +from viscy.light.predict_writer import HCSPredictionWriter from viscy.transforms import NormalizeSampled # %% # %% write the predictions to a zarr file diff --git a/applications/infection_classification/readme.md b/applications/infection_classification/readme.md index 2013d302..a5d317b7 100644 --- a/applications/infection_classification/readme.md +++ b/applications/infection_classification/readme.md @@ -10,4 +10,4 @@ Training: `infection_classification_*model.py` file implements a machine learnin Testing: `infection_classifier_testing.py` file tests the 2D infection classification model trained on a 2D dataset. -Prediction: `predict_classifier_testing.py` is an example script to perform prediction using 2D data and 2D model. It can be used to predict the infection type for new samples. \ No newline at end of file +Prediction: `predict_classifier_testing.py` is an example script to perform prediction using 2D data and 2D model. It can be used to predict the infection type for new samples. From 9d528ca8dcfafc773e1e6e06a0da8c0eaf582f81 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 14:42:07 -0700 Subject: [PATCH 85/92] histogram with dask --- viscy/preprocessing/compute_pixel_ratio.py | 50 ---------------------- viscy/preprocessing/pixel_ratio.py | 23 ++++++++++ 2 files changed, 23 insertions(+), 50 deletions(-) delete mode 100644 viscy/preprocessing/compute_pixel_ratio.py create mode 100644 viscy/preprocessing/pixel_ratio.py diff --git a/viscy/preprocessing/compute_pixel_ratio.py b/viscy/preprocessing/compute_pixel_ratio.py deleted file mode 100644 index 6b30571a..00000000 --- a/viscy/preprocessing/compute_pixel_ratio.py +++ /dev/null @@ -1,50 +0,0 @@ -""" compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset""" - -from iohub.ngff import open_ome_zarr - - -def calculate_pixel_ratio(dataset_path: str, target_channel: str): - """ - find ratio of background, uninfected and infected pixels - in the input dataset - Args: - dataset_path (str): Path to the dataset - target_channel (str): Name of the manual annotation channel - Returns: - pixel_ratio (list): List of ratios of background, uninfected and infected pixels - """ - zarr_input = open_ome_zarr( - dataset_path, - layout="hcs", - mode="r+", - ) - in_chan_names = zarr_input.channel_names - - num_pixels_bkg = 0 - num_pixels_uninf = 0 - num_pixels_inf = 0 - num_pixels = 0 - for well_id, well_data in zarr_input.wells(): - well_name, well_no = well_id.split("/") - - for pos_name, pos_data in well_data.positions(): - data = pos_data.data - T, C, Z, Y, X = data.shape - out_data = data.numpy() - for time in range(T): - Inf_mask = out_data[time, in_chan_names.index(target_channel), ...] - # Calculate the number of pixels valued 0, 1, and 2 in 'Inf_mask' - num_pixels_bkg = num_pixels_bkg + (Inf_mask == 0).sum() - num_pixels_uninf = num_pixels_uninf + (Inf_mask == 1).sum() - num_pixels_inf = num_pixels_inf + (Inf_mask == 2).sum() - num_pixels = num_pixels + Z * X * Y - - pixel_ratio_1 = [ - num_pixels / num_pixels_bkg, - num_pixels / num_pixels_uninf, - num_pixels / num_pixels_inf, - ] - pixel_ratio_sum = sum(pixel_ratio_1) - pixel_ratio = [ratio / pixel_ratio_sum for ratio in pixel_ratio_1] - - return pixel_ratio diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py new file mode 100644 index 00000000..b40eb4e2 --- /dev/null +++ b/viscy/preprocessing/pixel_ratio.py @@ -0,0 +1,23 @@ +""" compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset""" + +import dask.array as da +from iohub.ngff import open_ome_zarr +from numpy.typing import NDArray + + +def sematic_class_weights(dataset_path: str, target_channel: str) -> NDArray: + """Computes class balancing weights for semantic segmentation. + The weights can be used for cross-entropy loss. + + :param str dataset_path: HCS OME-Zarr dataset path + :param str target_channel: target channel name + :return NDArray: inverted ratio of background, uninfected and infected pixels + """ + dataset = open_ome_zarr(dataset_path) + arrays = [da.from_zarr(pos, "0") for _, pos in dataset.positions()] + imgs = da.stack(arrays, axis=0)[ + :, dataset.get_channel_index(target_channel) + ] + ratio, _ = da.histogram(imgs, bins=range(4), density=True) + weights = 1 / ratio + return weights.compute() From 7e477f4cab3ca23763e33d6a906ed31c7616a8be Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:02:21 -0700 Subject: [PATCH 86/92] fix index and test --- tests/preprocessing/test_pixel_ratio.py | 12 ++++++++++++ viscy/preprocessing/pixel_ratio.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) create mode 100644 tests/preprocessing/test_pixel_ratio.py diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py new file mode 100644 index 00000000..7e8354d6 --- /dev/null +++ b/tests/preprocessing/test_pixel_ratio.py @@ -0,0 +1,12 @@ +from iohub.ngff import open_ome_zarr + +from viscy.preprocessing.pixel_ratio import sematic_class_weights + + +def test_sematic_class_weights(small_hcs_dataset): + weights = sematic_class_weights(small_hcs_dataset, "GFP") + assert weights.shape == (3,) + assert weights[0] == 1.0 + # infinity + assert weights[1] > 1.0 + assert weights[2] > 1.0 diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index b40eb4e2..8d3287c6 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -14,9 +14,9 @@ def sematic_class_weights(dataset_path: str, target_channel: str) -> NDArray: :return NDArray: inverted ratio of background, uninfected and infected pixels """ dataset = open_ome_zarr(dataset_path) - arrays = [da.from_zarr(pos, "0") for _, pos in dataset.positions()] + arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] imgs = da.stack(arrays, axis=0)[ - :, dataset.get_channel_index(target_channel) + :, :, dataset.get_channel_index(target_channel) ] ratio, _ = da.histogram(imgs, bins=range(4), density=True) weights = 1 / ratio From 7a007f24ad237d5e3c62b1c9df91df5dd2d5c518 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:04:37 -0700 Subject: [PATCH 87/92] fix import --- .../Infection_classification_25DModel.py | 4 ++-- .../Infection_classification_2Dmodel.py | 4 ++-- .../Infection_classification_covnextModel.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/applications/infection_classification/Infection_classification_25DModel.py b/applications/infection_classification/Infection_classification_25DModel.py index 015cd106..a4e712f5 100644 --- a/applications/infection_classification/Infection_classification_25DModel.py +++ b/applications/infection_classification/Infection_classification_25DModel.py @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from viscy.data.hcs import HCSDataModule -from viscy.preprocessing import calculate_pixel_ratio +from viscy.preprocessing.pixel_ratio import sematic_class_weights from viscy.transforms import NormalizeSampled, RandWeightedCropd # %% Create a dataloader and visualize the batches. @@ -48,7 +48,7 @@ ], ) -pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/Infection_classification_2Dmodel.py b/applications/infection_classification/Infection_classification_2Dmodel.py index 055c232d..333718aa 100644 --- a/applications/infection_classification/Infection_classification_2Dmodel.py +++ b/applications/infection_classification/Infection_classification_2Dmodel.py @@ -9,7 +9,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from viscy.data.hcs import HCSDataModule -from viscy.preprocessing import calculate_pixel_ratio +from viscy.preprocessing.pixel_ratio import sematic_class_weights from viscy.transforms import ( NormalizeSampled, RandGaussianSmoothd, @@ -63,7 +63,7 @@ ), ], ) -pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() diff --git a/applications/infection_classification/Infection_classification_covnextModel.py b/applications/infection_classification/Infection_classification_covnextModel.py index 7fea6e76..bfe20362 100644 --- a/applications/infection_classification/Infection_classification_covnextModel.py +++ b/applications/infection_classification/Infection_classification_covnextModel.py @@ -11,7 +11,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from viscy.data.hcs import HCSDataModule -from viscy.preprocessing import calculate_pixel_ratio +from viscy.preprocessing.pixel_ratio import sematic_class_weights from viscy.transforms import NormalizeSampled, RandWeightedCropd # %% Create a dataloader and visualize the batches. @@ -49,7 +49,7 @@ ) ], ) -pixel_ratio = calculate_pixel_ratio(dataset_path, target_channel="Inf_mask") +pixel_ratio = sematic_class_weights(dataset_path, target_channel="Inf_mask") # Prepare the data data_module.prepare_data() From 9b46035a974f75f0f67c1ddbd5f63804a0dc9547 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:05:58 -0700 Subject: [PATCH 88/92] black --- viscy/preprocessing/pixel_ratio.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 8d3287c6..9d511985 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -15,9 +15,7 @@ def sematic_class_weights(dataset_path: str, target_channel: str) -> NDArray: """ dataset = open_ome_zarr(dataset_path) arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] - imgs = da.stack(arrays, axis=0)[ - :, :, dataset.get_channel_index(target_channel) - ] + imgs = da.stack(arrays, axis=0)[:, :, dataset.get_channel_index(target_channel)] ratio, _ = da.histogram(imgs, bins=range(4), density=True) weights = 1 / ratio return weights.compute() From 173a5dbce68214d2ffbdbb11873a68f5445cf3e7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:11:59 -0700 Subject: [PATCH 89/92] fix float comp --- tests/preprocessing/test_pixel_ratio.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py index 7e8354d6..4254d980 100644 --- a/tests/preprocessing/test_pixel_ratio.py +++ b/tests/preprocessing/test_pixel_ratio.py @@ -1,4 +1,5 @@ from iohub.ngff import open_ome_zarr +from numpy.testing import assert_allclose from viscy.preprocessing.pixel_ratio import sematic_class_weights @@ -6,7 +7,7 @@ def test_sematic_class_weights(small_hcs_dataset): weights = sematic_class_weights(small_hcs_dataset, "GFP") assert weights.shape == (3,) - assert weights[0] == 1.0 + assert_allclose(weights[0], 1.0) # infinity assert weights[1] > 1.0 assert weights[2] > 1.0 From 19cf4e6b695e83cfb349526db63d6627275d309d Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:15:44 -0700 Subject: [PATCH 90/92] clean up headers --- viscy/preprocessing/pixel_ratio.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 9d511985..8f7904be 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -1,5 +1,3 @@ -""" compute the pixel ratio of background (0), uninfected (1) and infected (2) pixels in the zarr dataset""" - import dask.array as da from iohub.ngff import open_ome_zarr from numpy.typing import NDArray From 4b36875f73fa87739d7b473a4bd7e7cb53261dd7 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:16:04 -0700 Subject: [PATCH 91/92] clean up import --- tests/preprocessing/test_pixel_ratio.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py index 4254d980..cc53f918 100644 --- a/tests/preprocessing/test_pixel_ratio.py +++ b/tests/preprocessing/test_pixel_ratio.py @@ -1,4 +1,3 @@ -from iohub.ngff import open_ome_zarr from numpy.testing import assert_allclose from viscy.preprocessing.pixel_ratio import sematic_class_weights From 37ab0aa014cecac8714541108e4f8bd5a60ec88d Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 10 Jul 2024 15:18:48 -0700 Subject: [PATCH 92/92] add argument to change number of classes --- tests/preprocessing/test_pixel_ratio.py | 3 +++ viscy/preprocessing/pixel_ratio.py | 7 +++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/preprocessing/test_pixel_ratio.py b/tests/preprocessing/test_pixel_ratio.py index cc53f918..2dce7afe 100644 --- a/tests/preprocessing/test_pixel_ratio.py +++ b/tests/preprocessing/test_pixel_ratio.py @@ -10,3 +10,6 @@ def test_sematic_class_weights(small_hcs_dataset): # infinity assert weights[1] > 1.0 assert weights[2] > 1.0 + assert sematic_class_weights( + small_hcs_dataset, "GFP", num_classes=2 + ).shape == (2,) diff --git a/viscy/preprocessing/pixel_ratio.py b/viscy/preprocessing/pixel_ratio.py index 8f7904be..29c2ed41 100644 --- a/viscy/preprocessing/pixel_ratio.py +++ b/viscy/preprocessing/pixel_ratio.py @@ -3,17 +3,20 @@ from numpy.typing import NDArray -def sematic_class_weights(dataset_path: str, target_channel: str) -> NDArray: +def sematic_class_weights( + dataset_path: str, target_channel: str, num_classes: int = 3 +) -> NDArray: """Computes class balancing weights for semantic segmentation. The weights can be used for cross-entropy loss. :param str dataset_path: HCS OME-Zarr dataset path :param str target_channel: target channel name + :param int num_classes: number of classes :return NDArray: inverted ratio of background, uninfected and infected pixels """ dataset = open_ome_zarr(dataset_path) arrays = [da.from_zarr(pos["0"]) for _, pos in dataset.positions()] imgs = da.stack(arrays, axis=0)[:, :, dataset.get_channel_index(target_channel)] - ratio, _ = da.histogram(imgs, bins=range(4), density=True) + ratio, _ = da.histogram(imgs, bins=range(num_classes + 1), density=True) weights = 1 / ratio return weights.compute()