From 22bfc71eb8578d89a3b8cb701bf83c304b39d591 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com> Date: Wed, 12 Jun 2024 17:51:00 -0700 Subject: [PATCH 1/9] add the scale metadata from the input. --- viscy/light/predict_writer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 7a58009c..e4b3ae17 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -5,6 +5,7 @@ import numpy as np import torch from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr +from iohub.ngff_meta import TransformationMeta from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import BasePredictionWriter from numpy.typing import DTypeLike, NDArray @@ -47,6 +48,7 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None: class HCSPredictionWriter(BasePredictionWriter): """Callback to store virtual staining predictions as HCS OME-Zarr. + :param str metadata_store: Path to the zarr store input :param str output_store: Path to the zarr store to store output :param bool write_input: Write the source and target channels too (must be writing to a new store), @@ -57,6 +59,7 @@ class HCSPredictionWriter(BasePredictionWriter): def __init__( self, + metadata_store: str, output_store: str, write_input: bool = False, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", @@ -64,6 +67,16 @@ def __init__( super().__init__(write_interval) self.output_store = output_store self.write_input = write_input + self.metadata_store = metadata_store + self._dataset_scale = (1, 1, 1, 1, 1) + self._get_scale_metadata(metadata_store) + + def _get_scale_metadata(self, metadata_store: str) -> None: + # Update the scale metadata + with open_ome_zarr(metadata_store, mode="r") as meta_plate: + for _, pos in meta_plate.positions(): + self._dataset_scale = pos.scale + break def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: dm: HCSDataModule = trainer.datamodule @@ -128,7 +141,9 @@ def write_sample( z_index += self.z_padding z_slice = slice(z_index, z_index + sample_prediction.shape[-3]) image = self._create_image( - img_name, sample_prediction.shape, sample_prediction.dtype + img_name, + sample_prediction.shape, + sample_prediction.dtype, ) _resize_image(image, t_index, z_slice) if self.write_input: @@ -160,4 +175,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike): shape=shape, dtype=dtype, chunks=_pad_shape(tuple(shape[-2:]), 5), + transform=[TransformationMeta(type="scale", scale=self._dataset_scale)], ) From dfcb98a29739384980b01b8bc53737a94af3f8f1 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com> Date: Wed, 12 Jun 2024 17:57:00 -0700 Subject: [PATCH 2/9] add change to config file --- examples/configs/predict_example.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index b2556139..1bb296c3 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -10,6 +10,7 @@ predict: callbacks: - class_path: viscy.light.predict_writer.HCSPredictionWriter init_args: + metadata_store: null output_store: null write_input: false write_interval: batch From 29b24b7ef9736e2ad5c5c39b6a2c5c9e8c893331 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com> Date: Wed, 12 Jun 2024 18:01:44 -0700 Subject: [PATCH 3/9] adding a try except --- viscy/light/predict_writer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index e4b3ae17..40ad56f8 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -73,10 +73,16 @@ def __init__( def _get_scale_metadata(self, metadata_store: str) -> None: # Update the scale metadata - with open_ome_zarr(metadata_store, mode="r") as meta_plate: - for _, pos in meta_plate.positions(): - self._dataset_scale = pos.scale - break + try: + with open_ome_zarr(metadata_store, mode="r") as meta_plate: + for _, pos in meta_plate.positions(): + self._dataset_scale = pos.scale + break + except IOError: + _logger.warning( + f"Could not read scale metadata from '{metadata_store}'. " + "Using default scale (1, 1, 1, 1, 1)." + ) def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: dm: HCSDataModule = trainer.datamodule From ca1d89723603ba6470fb2a69b03b35a4925def3d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com> Date: Wed, 12 Jun 2024 18:08:01 -0700 Subject: [PATCH 4/9] passing None for default behaviour and letting iohub handle the scale default. --- viscy/light/predict_writer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 40ad56f8..96318f88 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -68,7 +68,7 @@ def __init__( self.output_store = output_store self.write_input = write_input self.metadata_store = metadata_store - self._dataset_scale = (1, 1, 1, 1, 1) + self._dataset_scale = None self._get_scale_metadata(metadata_store) def _get_scale_metadata(self, metadata_store: str) -> None: @@ -76,7 +76,10 @@ def _get_scale_metadata(self, metadata_store: str) -> None: try: with open_ome_zarr(metadata_store, mode="r") as meta_plate: for _, pos in meta_plate.positions(): - self._dataset_scale = pos.scale + self._dataset_scale = [ + TransformationMeta(type="scale", scale=pos.scale) + ] + _logger.debug(f"Dataset scale {pos.scale}.") break except IOError: _logger.warning( @@ -181,5 +184,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike): shape=shape, dtype=dtype, chunks=_pad_shape(tuple(shape[-2:]), 5), - transform=[TransformationMeta(type="scale", scale=self._dataset_scale)], + transform=self._dataset_scale, ) From 7d3e7b82782ce4b8c02887c041b44af4a20979e6 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com> Date: Thu, 13 Jun 2024 12:09:04 -0700 Subject: [PATCH 5/9] making default metadta_store to none and adding letting iohub handle the exceptions --- viscy/light/predict_writer.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 96318f88..6fc93cf5 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -59,8 +59,8 @@ class HCSPredictionWriter(BasePredictionWriter): def __init__( self, - metadata_store: str, output_store: str, + metadata_store: str | None = None, write_input: bool = False, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", ) -> None: @@ -73,7 +73,7 @@ def __init__( def _get_scale_metadata(self, metadata_store: str) -> None: # Update the scale metadata - try: + if metadata_store is not None: with open_ome_zarr(metadata_store, mode="r") as meta_plate: for _, pos in meta_plate.positions(): self._dataset_scale = [ @@ -81,11 +81,6 @@ def _get_scale_metadata(self, metadata_store: str) -> None: ] _logger.debug(f"Dataset scale {pos.scale}.") break - except IOError: - _logger.warning( - f"Could not read scale metadata from '{metadata_store}'. " - "Using default scale (1, 1, 1, 1, 1)." - ) def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: dm: HCSDataModule = trainer.datamodule @@ -150,9 +145,7 @@ def write_sample( z_index += self.z_padding z_slice = slice(z_index, z_index + sample_prediction.shape[-3]) image = self._create_image( - img_name, - sample_prediction.shape, - sample_prediction.dtype, + img_name, sample_prediction.shape, sample_prediction.dtype ) _resize_image(image, t_index, z_slice) if self.write_input: From 8de699c6255ae2b3baee1b0783c49871508869ca Mon Sep 17 00:00:00 2001 From: Ziwen Liu <ziwen.liu@czbiohub.org> Date: Thu, 13 Jun 2024 13:19:18 -0700 Subject: [PATCH 6/9] fix docstring --- viscy/light/predict_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 6fc93cf5..76c868dc 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -48,8 +48,8 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None: class HCSPredictionWriter(BasePredictionWriter): """Callback to store virtual staining predictions as HCS OME-Zarr. - :param str metadata_store: Path to the zarr store input :param str output_store: Path to the zarr store to store output + :param str metadata_store: Path to the OME-Zarr dataset to copy scale metadata from :param bool write_input: Write the source and target channels too (must be writing to a new store), defaults to False From 7bba528732fb11cc1a796bfcafe6315c86c64bc0 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <ziwen.liu@czbiohub.org> Date: Thu, 13 Jun 2024 13:20:10 -0700 Subject: [PATCH 7/9] fix type hint --- viscy/light/predict_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 76c868dc..95c2a241 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -31,7 +31,7 @@ def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None: ) -def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None: +def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray: if z_slice.start == 0: return new_stack depth = z_slice.stop - z_slice.start From ea0edfbbef9667f2ec8c30f7116b9391046dbc22 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <ziwen.liu@czbiohub.org> Date: Thu, 13 Jun 2024 13:39:07 -0700 Subject: [PATCH 8/9] read input store directly --- viscy/light/predict_writer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py index 95c2a241..87d4b74d 100644 --- a/viscy/light/predict_writer.py +++ b/viscy/light/predict_writer.py @@ -1,5 +1,6 @@ import logging import os +from pathlib import Path from typing import Literal, Optional, Sequence import numpy as np @@ -49,7 +50,6 @@ class HCSPredictionWriter(BasePredictionWriter): """Callback to store virtual staining predictions as HCS OME-Zarr. :param str output_store: Path to the zarr store to store output - :param str metadata_store: Path to the OME-Zarr dataset to copy scale metadata from :param bool write_input: Write the source and target channels too (must be writing to a new store), defaults to False @@ -60,21 +60,18 @@ class HCSPredictionWriter(BasePredictionWriter): def __init__( self, output_store: str, - metadata_store: str | None = None, write_input: bool = False, write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch", ) -> None: super().__init__(write_interval) self.output_store = output_store self.write_input = write_input - self.metadata_store = metadata_store self._dataset_scale = None - self._get_scale_metadata(metadata_store) - def _get_scale_metadata(self, metadata_store: str) -> None: + def _get_scale_metadata(self, metadata_store: Path) -> None: # Update the scale metadata if metadata_store is not None: - with open_ome_zarr(metadata_store, mode="r") as meta_plate: + with open_ome_zarr(metadata_store, mode="r", layout="hcs") as meta_plate: for _, pos in meta_plate.positions(): self._dataset_scale = [ TransformationMeta(type="scale", scale=pos.scale) @@ -84,6 +81,7 @@ def _get_scale_metadata(self, metadata_store: str) -> None: def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None: dm: HCSDataModule = trainer.datamodule + self._get_scale_metadata(dm.data_path) self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0 _logger.debug(f"Setting Z padding to {self.z_padding}") source_channel = dm.source_channel From dd6dabdf6958bde2ce6e773ba93979bb3ab62e14 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <ziwen.liu@czbiohub.org> Date: Thu, 13 Jun 2024 13:39:51 -0700 Subject: [PATCH 9/9] revert change to the example config --- examples/configs/predict_example.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml index 1bb296c3..b2556139 100644 --- a/examples/configs/predict_example.yml +++ b/examples/configs/predict_example.yml @@ -10,7 +10,6 @@ predict: callbacks: - class_path: viscy.light.predict_writer.HCSPredictionWriter init_args: - metadata_store: null output_store: null write_input: false write_interval: batch