Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for the get_transform analog to set_transform #264

Merged
merged 13 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 99 additions & 12 deletions iohub/ngff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,19 +969,9 @@ def scale(self) -> list[float]:
Helper function for scale transform metadata of
highest resolution scale.
"""
scale = [1] * self.data.ndim
transforms = (
self.metadata.multiscales[0].datasets[0].coordinate_transformations
return self.get_effective_scale(
self.metadata.multiscales[0].datasets[0].path
)
for trans in transforms:
if trans.type == "scale":
if len(trans.scale) != len(scale):
raise RuntimeError(
f"Length of scale transformation {len(trans.scale)} "
f"does not match data dimension {len(scale)}."
)
scale = [s1 * s2 for s1, s2 in zip(scale, trans.scale)]
return scale

@property
def axis_names(self) -> list[str]:
Expand Down Expand Up @@ -1010,6 +1000,103 @@ def get_axis_index(self, axis_name: str) -> int:
"""
return self.axis_names.index(axis_name.lower())

def _get_all_transforms(
self, image: str | Literal["*"]
) -> list[TransformationMeta]:
"""Get all transforms metadata
for one image array or the whole FOV.

Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV

Returns
-------
list[TransformationMeta]
All transforms applicable to this image or FOV.
"""
ziw-liu marked this conversation as resolved.
Show resolved Hide resolved
transforms: list[TransformationMeta] = (
[
t
for t in self.metadata.multiscales[
0
].coordinate_transformations
]
if self.metadata.multiscales[0].coordinate_transformations
is not None
else []
)
if image != "*" and image in self:
for i, dataset_meta in enumerate(
self.metadata.multiscales[0].datasets
):
if dataset_meta.path == image:
transforms.extend(
self.metadata.multiscales[0]
.datasets[i]
.coordinate_transformations
)
elif image != "*":
raise ValueError(f"Key {image} not recognized.")
return transforms

def get_effective_scale(
self,
image: str | Literal["*"],
) -> list[float]:
"""Get the effective coordinate scale metadata
for one image array or the whole FOV.

Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV

Returns
-------
list[float]
A list of floats representing the total scale
for the image or FOV for each axis.
"""
transforms = self._get_all_transforms(image)

full_scale = np.ones(len(self.axes), dtype=float)
for transform in transforms:
if transform.type == "scale":
full_scale *= np.array(transform.scale)

return [float(x) for x in full_scale]

def get_effective_translation(
self,
image: str | Literal["*"],
) -> TransformationMeta:
"""Get the effective coordinate translation metadata
for one image array or the whole FOV.

Parameters
----------
image : str | Literal["*"]
Name of one image array (e.g. "0") to query,
or "*" for the whole FOV

Returns
-------
list[float]
A list of floats representing the total translation
for the image or FOV for each axis.
"""
transforms = self._get_all_transforms(image)
full_translation = np.zeros(len(self.axes), dtype=float)
for transform in transforms:
if transform.type == "translation":
full_translation += np.array(transform.translation)

return [float(x) for x in full_translation]

def set_transform(
self,
image: str | Literal["*"],
Expand Down
94 changes: 94 additions & 0 deletions tests/ngff/test_ngff.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,100 @@ def test_set_transform_image(ch_shape_dtype, arr_name):
]


input_transformations = [
([TransformationMeta(type="identity")], []),
([TransformationMeta(type="scale", scale=(1.0, 2.0, 3.0, 4.0, 5.0))], []),
(
[
TransformationMeta(
type="translation", translation=(1.0, 2.0, 3.0, 4.0, 5.0)
)
],
[],
),
(
[
TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)),
TransformationMeta(
type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0)
),
],
[
TransformationMeta(type="scale", scale=(2.0, 2.0, 2.0, 2.0, 2.0)),
TransformationMeta(
type="translation", translation=(1.0, 1.0, 1.0, 1.0, 1.0)
),
],
),
]
target_scales = [
[1.0, 1.0, 1.0, 1.0, 1.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
[1.0, 1.0, 1.0, 1.0, 1.0],
[4.0, 4.0, 4.0, 4.0, 4.0],
]
target_translations = [
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
[2.0, 2.0, 2.0, 2.0, 2.0],
]


@pytest.mark.parametrize(
"transforms",
[
(saved, target)
for saved, target in zip(input_transformations, target_scales)
],
)
@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
)
def test_get_effective_scale_image(transforms, ch_shape_dtype, arr_name):
"""Test `iohub.ngff.Position.get_effective_scale()`"""
(fov_transform, img_transform), expected_scale = transforms
channel_names, shape, dtype = ch_shape_dtype
with TemporaryDirectory() as temp_dir:
store_path = os.path.join(temp_dir, "ome.zarr")
with open_ome_zarr(
store_path, layout="fov", mode="w-", channel_names=channel_names
) as dataset:
dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype)
dataset.set_transform(image="*", transform=fov_transform)
dataset.set_transform(image=arr_name, transform=img_transform)
scale = dataset.get_effective_scale(image=arr_name)
assert scale == expected_scale


@pytest.mark.parametrize(
"transforms",
[
(saved, target)
for saved, target in zip(input_transformations, target_translations)
],
)
@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
)
def test_get_effective_translation_image(transforms, ch_shape_dtype, arr_name):
"""Test `iohub.ngff.Position.get_effective_translation()`"""
(fov_transform, img_transform), expected_translation = transforms
channel_names, shape, dtype = ch_shape_dtype
with TemporaryDirectory() as temp_dir:
store_path = os.path.join(temp_dir, "ome.zarr")
with open_ome_zarr(
store_path, layout="fov", mode="w-", channel_names=channel_names
) as dataset:
dataset.create_zeros(name=arr_name, shape=shape, dtype=dtype)
dataset.set_transform(image="*", transform=fov_transform)
dataset.set_transform(image=arr_name, transform=img_transform)
translation = dataset.get_effective_translation(image=arr_name)
assert translation == expected_translation


@given(
ch_shape_dtype=_channels_and_random_5d_shape_and_dtype(),
arr_name=short_alpha_numeric,
Expand Down
Loading