From 22bfc71eb8578d89a3b8cb701bf83c304b39d591 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
Date: Wed, 12 Jun 2024 17:51:00 -0700
Subject: [PATCH 1/9] add the scale metadata from the input store.

---
 viscy/light/predict_writer.py | 18 +++++++++++++++++-
 1 file changed, 17 insertions(+), 1 deletion(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 7a58009c..e4b3ae17 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr
+from iohub.ngff_meta import TransformationMeta
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from numpy.typing import DTypeLike, NDArray
@@ -47,6 +48,7 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None:
 class HCSPredictionWriter(BasePredictionWriter):
     """Callback to store virtual staining predictions as HCS OME-Zarr.
 
+    :param str metadata_store: Path to the zarr store input
     :param str output_store: Path to the zarr store to store output
     :param bool write_input: Write the source and target channels too
         (must be writing to a new store),
@@ -57,6 +59,7 @@ class HCSPredictionWriter(BasePredictionWriter):
 
     def __init__(
         self,
+        metadata_store: str,
         output_store: str,
         write_input: bool = False,
         write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch",
@@ -64,6 +67,16 @@ def __init__(
         super().__init__(write_interval)
         self.output_store = output_store
         self.write_input = write_input
+        self.metadata_store = metadata_store
+        self._dataset_scale = (1, 1, 1, 1, 1)
+        self._get_scale_metadata(metadata_store)
+
+    def _get_scale_metadata(self, metadata_store: str) -> None:
+        # Update the scale metadata
+        with open_ome_zarr(metadata_store, mode="r") as meta_plate:
+            for _, pos in meta_plate.positions():
+                self._dataset_scale = pos.scale
+                break
 
     def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         dm: HCSDataModule = trainer.datamodule
@@ -128,7 +141,9 @@ def write_sample(
         z_index += self.z_padding
         z_slice = slice(z_index, z_index + sample_prediction.shape[-3])
         image = self._create_image(
-            img_name, sample_prediction.shape, sample_prediction.dtype
+            img_name,
+            sample_prediction.shape,
+            sample_prediction.dtype,
         )
         _resize_image(image, t_index, z_slice)
         if self.write_input:
@@ -160,4 +175,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
             shape=shape,
             dtype=dtype,
             chunks=_pad_shape(tuple(shape[-2:]), 5),
+            transform=[TransformationMeta(type="scale", scale=self._dataset_scale)],
         )
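
The net effect of this patch, as a standalone sketch (paths are placeholders; it assumes only the iohub calls already used in the hunks above):

    from iohub.ngff import open_ome_zarr
    from iohub.ngff_meta import TransformationMeta

    # Read the (T, C, Z, Y, X) scale of the first position in the input plate.
    with open_ome_zarr("input_plate.zarr", mode="r") as meta_plate:
        for _, pos in meta_plate.positions():
            dataset_scale = pos.scale
            break

    # Attach it to the output arrays as an OME-NGFF "scale" transformation,
    # which is what _create_image now forwards via its transform argument.
    transform = [TransformationMeta(type="scale", scale=dataset_scale)]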

From dfcb98a29739384980b01b8bc53737a94af3f8f1 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
Date: Wed, 12 Jun 2024 17:57:00 -0700
Subject: [PATCH 2/9] add metadata_store to the example config file

---
 examples/configs/predict_example.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml
index b2556139..1bb296c3 100644
--- a/examples/configs/predict_example.yml
+++ b/examples/configs/predict_example.yml
@@ -10,6 +10,7 @@ predict:
     callbacks:
       - class_path: viscy.light.predict_writer.HCSPredictionWriter
         init_args:
+          metadata_store: null
           output_store: null
           write_input: false
           write_interval: batch
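
A hedged sketch of how these init_args could be filled when constructing the callback directly in Python (paths are placeholders; later patches in this series make metadata_store optional and then drop it again):

    from viscy.light.predict_writer import HCSPredictionWriter

    writer = HCSPredictionWriter(
        metadata_store="input_plate.zarr",  # placeholder: plate to copy scale metadata from
        output_store="predictions.zarr",    # placeholder: plate to write predictions to
        write_input=False,
        write_interval="batch",
    )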

From 29b24b7ef9736e2ad5c5c39b6a2c5c9e8c893331 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
Date: Wed, 12 Jun 2024 18:01:44 -0700
Subject: [PATCH 3/9] add a try/except around the scale metadata read

---
 viscy/light/predict_writer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index e4b3ae17..40ad56f8 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -73,10 +73,16 @@ def __init__(
 
     def _get_scale_metadata(self, metadata_store: str) -> None:
         # Update the scale metadata
-        with open_ome_zarr(metadata_store, mode="r") as meta_plate:
-            for _, pos in meta_plate.positions():
-                self._dataset_scale = pos.scale
-                break
+        try:
+            with open_ome_zarr(metadata_store, mode="r") as meta_plate:
+                for _, pos in meta_plate.positions():
+                    self._dataset_scale = pos.scale
+                    break
+        except IOError:
+            _logger.warning(
+                f"Could not read scale metadata from '{metadata_store}'. "
+                "Using default scale (1, 1, 1, 1, 1)."
+            )
 
     def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         dm: HCSDataModule = trainer.datamodule
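
The same fallback pattern as a standalone sketch (the function name is hypothetical; a failed read keeps the unit scale and logs a warning instead of aborting prediction):

    import logging

    from iohub.ngff import open_ome_zarr

    _logger = logging.getLogger(__name__)

    def first_position_scale(metadata_store: str, default=(1, 1, 1, 1, 1)):
        # Mirror of the hunk above: return the scale of the first position,
        # or the default unit scale when the input store cannot be read.
        try:
            with open_ome_zarr(metadata_store, mode="r") as meta_plate:
                for _, pos in meta_plate.positions():
                    return pos.scale
        except IOError:
            _logger.warning(f"Could not read scale metadata from '{metadata_store}'.")
        return default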

From ca1d89723603ba6470fb2a69b03b35a4925def3d Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
Date: Wed, 12 Jun 2024 18:08:01 -0700
Subject: [PATCH 4/9] pass None for default behaviour and let iohub handle
 the scale default.

---
 viscy/light/predict_writer.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 40ad56f8..96318f88 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -68,7 +68,7 @@ def __init__(
         self.output_store = output_store
         self.write_input = write_input
         self.metadata_store = metadata_store
-        self._dataset_scale = (1, 1, 1, 1, 1)
+        self._dataset_scale = None
         self._get_scale_metadata(metadata_store)
 
     def _get_scale_metadata(self, metadata_store: str) -> None:
@@ -76,7 +76,10 @@ def _get_scale_metadata(self, metadata_store: str) -> None:
         try:
             with open_ome_zarr(metadata_store, mode="r") as meta_plate:
                 for _, pos in meta_plate.positions():
-                    self._dataset_scale = pos.scale
+                    self._dataset_scale = [
+                        TransformationMeta(type="scale", scale=pos.scale)
+                    ]
+                    _logger.debug(f"Dataset scale {pos.scale}.")
                     break
         except IOError:
             _logger.warning(
@@ -181,5 +184,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
             shape=shape,
             dtype=dtype,
             chunks=_pad_shape(tuple(shape[-2:]), 5),
-            transform=[TransformationMeta(type="scale", scale=self._dataset_scale)],
+            transform=self._dataset_scale,
         )
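
A compact sketch of the new default behaviour (the helper name is hypothetical; it assumes, as the commit message states, that iohub applies its own default scale when transform is None):

    from iohub.ngff_meta import TransformationMeta

    def scale_transform(pos_scale: list[float] | None):
        # With scale metadata from the input store: wrap it for iohub.
        if pos_scale is not None:
            return [TransformationMeta(type="scale", scale=pos_scale)]
        # Without it: return None so _create_image passes transform=None
        # and iohub falls back to its own default scale.
        return None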

From 7d3e7b82782ce4b8c02887c041b44af4a20979e6 Mon Sep 17 00:00:00 2001
From: Eduardo Hirata-Miyasaki <edhiratam@gmail.com>
Date: Thu, 13 Jun 2024 12:09:04 -0700
Subject: [PATCH 5/9] make metadata_store default to None and let iohub
 handle the exceptions

---
 viscy/light/predict_writer.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 96318f88..6fc93cf5 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -59,8 +59,8 @@ class HCSPredictionWriter(BasePredictionWriter):
 
     def __init__(
         self,
-        metadata_store: str,
         output_store: str,
+        metadata_store: str | None = None,
         write_input: bool = False,
         write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch",
     ) -> None:
@@ -73,7 +73,7 @@ def __init__(
 
     def _get_scale_metadata(self, metadata_store: str) -> None:
         # Update the scale metadata
-        try:
+        if metadata_store is not None:
             with open_ome_zarr(metadata_store, mode="r") as meta_plate:
                 for _, pos in meta_plate.positions():
                     self._dataset_scale = [
@@ -81,11 +81,6 @@ def _get_scale_metadata(self, metadata_store: str) -> None:
                     ]
                     _logger.debug(f"Dataset scale {pos.scale}.")
                     break
-        except IOError:
-            _logger.warning(
-                f"Could not read scale metadata from '{metadata_store}'. "
-                "Using default scale (1, 1, 1, 1, 1)."
-            )
 
     def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         dm: HCSDataModule = trainer.datamodule
@@ -150,9 +145,7 @@ def write_sample(
         z_index += self.z_padding
         z_slice = slice(z_index, z_index + sample_prediction.shape[-3])
         image = self._create_image(
-            img_name,
-            sample_prediction.shape,
-            sample_prediction.dtype,
+            img_name, sample_prediction.shape, sample_prediction.dtype
         )
         _resize_image(image, t_index, z_slice)
         if self.write_input:

From 8de699c6255ae2b3baee1b0783c49871508869ca Mon Sep 17 00:00:00 2001
From: Ziwen Liu <ziwen.liu@czbiohub.org>
Date: Thu, 13 Jun 2024 13:19:18 -0700
Subject: [PATCH 6/9] fix docstring

---
 viscy/light/predict_writer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 6fc93cf5..76c868dc 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -48,8 +48,8 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None:
 class HCSPredictionWriter(BasePredictionWriter):
     """Callback to store virtual staining predictions as HCS OME-Zarr.
 
-    :param str metadata_store: Path to the zarr store input
     :param str output_store: Path to the zarr store to store output
+    :param str metadata_store: Path to the OME-Zarr dataset to copy scale metadata from
     :param bool write_input: Write the source and target channels too
         (must be writing to a new store),
         defaults to False

From 7bba528732fb11cc1a796bfcafe6315c86c64bc0 Mon Sep 17 00:00:00 2001
From: Ziwen Liu <ziwen.liu@czbiohub.org>
Date: Thu, 13 Jun 2024 13:20:10 -0700
Subject: [PATCH 7/9] fix type hint

---
 viscy/light/predict_writer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 76c868dc..95c2a241 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -31,7 +31,7 @@ def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None:
         )
 
 
-def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None:
+def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray:
     if z_slice.start == 0:
         return new_stack
     depth = z_slice.stop - z_slice.start

From ea0edfbbef9667f2ec8c30f7116b9391046dbc22 Mon Sep 17 00:00:00 2001
From: Ziwen Liu <ziwen.liu@czbiohub.org>
Date: Thu, 13 Jun 2024 13:39:07 -0700
Subject: [PATCH 8/9] read input store directly

---
 viscy/light/predict_writer.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/viscy/light/predict_writer.py b/viscy/light/predict_writer.py
index 95c2a241..87d4b74d 100644
--- a/viscy/light/predict_writer.py
+++ b/viscy/light/predict_writer.py
@@ -1,5 +1,6 @@
 import logging
 import os
+from pathlib import Path
 from typing import Literal, Optional, Sequence
 
 import numpy as np
@@ -49,7 +50,6 @@ class HCSPredictionWriter(BasePredictionWriter):
     """Callback to store virtual staining predictions as HCS OME-Zarr.
 
     :param str output_store: Path to the zarr store to store output
-    :param str metadata_store: Path to the OME-Zarr dataset to copy scale metadata from
     :param bool write_input: Write the source and target channels too
         (must be writing to a new store),
         defaults to False
@@ -60,21 +60,18 @@ class HCSPredictionWriter(BasePredictionWriter):
     def __init__(
         self,
         output_store: str,
-        metadata_store: str | None = None,
         write_input: bool = False,
         write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch",
     ) -> None:
         super().__init__(write_interval)
         self.output_store = output_store
         self.write_input = write_input
-        self.metadata_store = metadata_store
         self._dataset_scale = None
-        self._get_scale_metadata(metadata_store)
 
-    def _get_scale_metadata(self, metadata_store: str) -> None:
+    def _get_scale_metadata(self, metadata_store: Path) -> None:
         # Update the scale metadata
         if metadata_store is not None:
-            with open_ome_zarr(metadata_store, mode="r") as meta_plate:
+            with open_ome_zarr(metadata_store, mode="r", layout="hcs") as meta_plate:
                 for _, pos in meta_plate.positions():
                     self._dataset_scale = [
                         TransformationMeta(type="scale", scale=pos.scale)
@@ -84,6 +81,7 @@ def _get_scale_metadata(self, metadata_store: str) -> None:
 
     def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         dm: HCSDataModule = trainer.datamodule
+        self._get_scale_metadata(dm.data_path)
         self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0
         _logger.debug(f"Setting Z padding to {self.z_padding}")
         source_channel = dm.source_channel
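
After this patch no extra configuration is needed: the writer resolves the input plate from the datamodule at predict time. A hedged usage sketch (placeholder paths; the HCSDataModule construction and the model are elided since they depend on the experiment):

    from lightning.pytorch import Trainer
    from viscy.light.predict_writer import HCSPredictionWriter

    # dm = HCSDataModule(data_path="input_plate.zarr", ...)  # configured elsewhere
    writer = HCSPredictionWriter(output_store="predictions.zarr")
    trainer = Trainer(callbacks=[writer])
    # on_predict_start() opens dm.data_path with layout="hcs" and copies the
    # first position's scale into every array the writer creates.
    # trainer.predict(model, datamodule=dm)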

From dd6dabdf6958bde2ce6e773ba93979bb3ab62e14 Mon Sep 17 00:00:00 2001
From: Ziwen Liu <ziwen.liu@czbiohub.org>
Date: Thu, 13 Jun 2024 13:39:51 -0700
Subject: [PATCH 9/9] revert change to the example config

---
 examples/configs/predict_example.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/configs/predict_example.yml b/examples/configs/predict_example.yml
index 1bb296c3..b2556139 100644
--- a/examples/configs/predict_example.yml
+++ b/examples/configs/predict_example.yml
@@ -10,7 +10,6 @@ predict:
     callbacks:
       - class_path: viscy.light.predict_writer.HCSPredictionWriter
         init_args:
-          metadata_store: null
           output_store: null
           write_input: false
           write_interval: batch