Skip to content

Commit 9ae582f

Browse files
edyoshikunziw-liu
andcommitted
Add the scale metadata to the output_stores (#89)
* add the scale metadata from the input. * add change to config file * adding a try except * passing None for default behaviour and letting iohub handle the scale default. * making default metadta_store to none and adding letting iohub handle the exceptions * fix docstring * fix type hint * read input store directly * revert change to the example config --------- Co-authored-by: Ziwen Liu <[email protected]>
1 parent c56b9d0 commit 9ae582f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

viscy/light/predict_writer.py

+16
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import logging
22
import os
3+
from pathlib import Path
34
from typing import Literal, Optional, Sequence
45

56
import numpy as np
67
import torch
78
from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr
9+
from iohub.ngff_meta import TransformationMeta
810
from lightning.pytorch import LightningModule, Trainer
911
from lightning.pytorch.callbacks import BasePredictionWriter
1012
from numpy.typing import DTypeLike, NDArray
@@ -88,9 +90,22 @@ def __init__(
8890
super().__init__(write_interval)
8991
self.output_store = output_store
9092
self.write_input = write_input
93+
self._dataset_scale = None
94+
95+
def _get_scale_metadata(self, metadata_store: Path) -> None:
96+
# Update the scale metadata
97+
if metadata_store is not None:
98+
with open_ome_zarr(metadata_store, mode="r", layout="hcs") as meta_plate:
99+
for _, pos in meta_plate.positions():
100+
self._dataset_scale = [
101+
TransformationMeta(type="scale", scale=pos.scale)
102+
]
103+
_logger.debug(f"Dataset scale {pos.scale}.")
104+
break
91105

92106
def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
93107
dm: HCSDataModule = trainer.datamodule
108+
self._get_scale_metadata(dm.data_path)
94109
self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0
95110
_logger.debug(f"Setting Z padding to {self.z_padding}")
96111
source_channel = dm.source_channel
@@ -184,4 +199,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
184199
shape=shape,
185200
dtype=dtype,
186201
chunks=_pad_shape(tuple(shape[-2:]), 5),
202+
transform=self._dataset_scale,
187203
)

0 commit comments

Comments
 (0)