@@ -1,10 +1,12 @@
 import logging
 import os
+from pathlib import Path
 from typing import Literal, Optional, Sequence

 import numpy as np
 import torch
 from iohub.ngff import ImageArray, _pad_shape, open_ome_zarr
+from iohub.ngff_meta import TransformationMeta
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from numpy.typing import DTypeLike, NDArray
@@ -88,9 +90,22 @@ def __init__(
         super().__init__(write_interval)
         self.output_store = output_store
         self.write_input = write_input
+        self._dataset_scale = None
+
+    def _get_scale_metadata(self, metadata_store: Optional[Path]) -> None:
+        # Read the scale transform from the first position of the input store
+        if metadata_store is not None:
+            with open_ome_zarr(metadata_store, mode="r", layout="hcs") as meta_plate:
+                for _, pos in meta_plate.positions():
+                    self._dataset_scale = [
+                        TransformationMeta(type="scale", scale=pos.scale)
+                    ]
+                    _logger.debug(f"Dataset scale: {pos.scale}")
+                    break

     def on_predict_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
         dm: HCSDataModule = trainer.datamodule
+        self._get_scale_metadata(dm.data_path)
         self.z_padding = dm.z_window_size // 2 if dm.target_2d else 0
         _logger.debug(f"Setting Z padding to {self.z_padding}")
         source_channel = dm.source_channel
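Outside the diff, a minimal sketch of what `_get_scale_metadata` reads, assuming iohub's HCS plate API (`open_ome_zarr`, `Plate.positions`, `Position.scale`); the store path is hypothetical:

```python
from iohub.ngff import open_ome_zarr
from iohub.ngff_meta import TransformationMeta

# "input_plate.zarr" is a hypothetical HCS OME-Zarr store.
with open_ome_zarr("input_plate.zarr", mode="r", layout="hcs") as plate:
    # positions() yields (name, Position) pairs; the first FOV's
    # (T, C, Z, Y, X) scale stands in for the whole dataset.
    _, first_pos = next(plate.positions())
    dataset_scale = [TransformationMeta(type="scale", scale=first_pos.scale)]

print(dataset_scale[0].scale)  # e.g. [1.0, 1.0, 2.0, 0.325, 0.325]
```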
@@ -184,4 +199,5 @@ def _create_image(self, img_name: str, shape: tuple[int], dtype: DTypeLike):
             shape=shape,
             dtype=dtype,
             chunks=_pad_shape(tuple(shape[-2:]), 5),
+            transform=self._dataset_scale,
         )
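For context, a sketch of how a transform list like this lands on a new image, assuming iohub's `Position.create_zeros` takes a `transform` argument the way `_create_image` forwards it; the output path, plate coordinates, shape, and scale values are all hypothetical:

```python
import numpy as np
from iohub.ngff import open_ome_zarr
from iohub.ngff_meta import TransformationMeta

# Hypothetical (T, C, Z, Y, X) scale carried over from the input dataset.
dataset_scale = [
    TransformationMeta(type="scale", scale=[1.0, 1.0, 2.0, 0.325, 0.325])
]

with open_ome_zarr(
    "prediction.zarr", mode="a", layout="hcs", channel_names=["Prediction"]
) as plate:
    pos = plate.create_position("A", "1", "0")
    # The transform list is written into the new image's multiscale
    # metadata, so viewers pick up the source dataset's voxel size.
    pos.create_zeros(
        "0",
        shape=(1, 1, 8, 256, 256),
        dtype=np.float32,
        transform=dataset_scale,
    )
```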