Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the scale metadata to the output_stores #89

Merged
merged 9 commits into from
Jun 13, 2024
Merged
1 change: 1 addition & 0 deletions examples/configs/predict_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ predict:
callbacks:
- class_path: viscy.light.predict_writer.HCSPredictionWriter
init_args:
metadata_store: null
output_store: null
write_input: false
write_interval: batch
Expand Down
27 changes: 26 additions & 1 deletion viscy/light/predict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr
from iohub.ngff_meta import TransformationMeta
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from numpy.typing import DTypeLike, NDArray
Expand Down Expand Up @@ -47,6 +48,7 @@ def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> None:
class HCSPredictionWriter(BasePredictionWriter):
"""Callback to store virtual staining predictions as HCS OME-Zarr.

:param str metadata_store: Path to the input zarr store to read scale metadata from
:param str output_store: Path to the zarr store to store output
:param bool write_input: Write the source and target channels too
(must be writing to a new store),
Expand All @@ -57,13 +59,33 @@ class HCSPredictionWriter(BasePredictionWriter):

def __init__(
    self,
    metadata_store: str,
    output_store: str,
    write_input: bool = False,
    write_interval: Literal["batch", "epoch", "batch_and_epoch"] = "batch",
) -> None:
    """Configure the prediction writer and cache the dataset scale.

    :param str metadata_store: path of the input zarr store whose scale
        metadata is copied onto the output images
    :param str output_store: path of the zarr store that receives predictions
    :param bool write_input: also write the source/target channels
        (requires writing to a new store)
    :param write_interval: when the writer is invoked by Lightning
    """
    super().__init__(write_interval)
    # Remember the target store and input-passthrough flag for later hooks.
    self.metadata_store = metadata_store
    self.output_store = output_store
    self.write_input = write_input
    # Stays None (iohub default scale) unless the lookup below succeeds.
    self._dataset_scale = None
    self._get_scale_metadata(metadata_store)

def _get_scale_metadata(self, metadata_store: str) -> None:
    """Cache the scale transform of the input dataset.

    Reads the scale from the first position of the input plate and stores
    it in ``self._dataset_scale`` so it can be applied to output images.
    On any failure ``self._dataset_scale`` is left as ``None``, which makes
    iohub fall back to the default scale of (1, 1, 1, 1, 1).

    :param str metadata_store: path to the input OME-Zarr store;
        may be ``None``/empty (e.g. ``metadata_store: null`` in the
        example config), in which case the lookup is skipped
    """
    if not metadata_store:
        # Guard: open_ome_zarr(None) raises TypeError, which the except
        # clause below would not catch — skip the lookup instead.
        _logger.warning(
            "No metadata store provided. "
            "Using default scale (1, 1, 1, 1, 1)."
        )
        return
    try:
        with open_ome_zarr(metadata_store, mode="r") as meta_plate:
            # Only the first position is inspected — presumably all
            # positions in the plate share one scale (TODO confirm).
            for _, pos in meta_plate.positions():
                self._dataset_scale = [
                    TransformationMeta(type="scale", scale=pos.scale)
                ]
                _logger.debug(f"Dataset scale {pos.scale}.")
                break
    except OSError:
        # OSError is the canonical name (IOError is its alias) and covers
        # missing or unreadable stores; fall back to the default scale.
        _logger.warning(
            f"Could not read scale metadata from '{metadata_store}'. "
            "Using default scale (1, 1, 1, 1, 1)."
        )

def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
dm: HCSDataModule = trainer.datamodule
Expand Down Expand Up @@ -128,7 +150,9 @@ def write_sample(
z_index += self.z_padding
z_slice = slice(z_index, z_index + sample_prediction.shape[-3])
image = self._create_image(
img_name, sample_prediction.shape, sample_prediction.dtype
img_name,
sample_prediction.shape,
sample_prediction.dtype,
)
_resize_image(image, t_index, z_slice)
if self.write_input:
Expand Down Expand Up @@ -160,4 +184,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
shape=shape,
dtype=dtype,
chunks=_pad_shape(tuple(shape[-2:]), 5),
transform=self._dataset_scale,
)
Loading