Skip to content

Commit

Permalink
Merge pull request #5 from fractal-analytics-platform/task_tools
Browse files Browse the repository at this point in the history
Add more task tools
  • Loading branch information
lorenzocerrone authored Feb 9, 2025
2 parents e53b3a6 + 85c21b1 commit 4ca1bd5
Show file tree
Hide file tree
Showing 11 changed files with 402 additions and 32 deletions.
8 changes: 4 additions & 4 deletions src/fractal_converters_tools/stitching.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def sort_tiles_by_distance(
raise ValueError("Tiles are not coplanar")

min_point = _min_point(tiles)
return sorted(tiles, key=lambda x: (x.top_l - min_point).length())
return sorted(tiles, key=lambda x: (x.top_l - min_point).lengthXY())


def remove_tiles_offset(tiles: list[Tile]) -> list[Tile]:
Expand Down Expand Up @@ -115,7 +115,7 @@ def _remove_tile_XY_overalap(
moved_bbox = query_tile.move_by(vec)
iou = moved_bbox.iouXY(ref_tile)
if iou < eps:
lengths.append(vec.length())
lengths.append(vec.lengthXY())
vectors.append(vec)

min_idx = np.argmin(lengths)
Expand Down Expand Up @@ -166,7 +166,7 @@ def resolve_grid_tiles_overlap(tiles: list[Tile], grid_setup: GridSetup) -> list

# Find if a bounding box is close to the (x_in, y_in) position
point = Point(x_in, y_in, z=z, c=c, t=t)
distances = [(point - bbox.top_l).length() for bbox in tiles]
distances = [(point - bbox.top_l).lengthXY() for bbox in tiles]
min_dist = np.min(distances)
closest_bbox = tiles[np.argmin(distances)]

Expand Down Expand Up @@ -246,7 +246,7 @@ def remove_pixel_gaps(tiles: list[Tile], max_gap: int = 1) -> list[Tile]:
y_in = j * offset_y

point = Point(x_in, y_in, z=z, c=c, t=t)
distances = [(point - bbox.top_l).length() for bbox in tiles]
distances = [(point - bbox.top_l).lengthXY() for bbox in tiles]
min_dist = np.min(distances)
closest_bbox = tiles[np.argmin(distances)]

Expand Down
48 changes: 48 additions & 0 deletions src/fractal_converters_tools/task_common_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Shared models for the fractal_converters_tools tasks."""

from typing import Literal

from pydantic import BaseModel, Field


class AdvancedComputeOptions(BaseModel):
"""Advanced options for the conversion.
Attributes:
num_levels (int): The number of resolution levels in the pyramid.
tiling_mode (Literal["auto", "grid", "free", "none"]): Specify the tiling mode.
"auto" will automatically determine the tiling mode.
"grid" if the input data is a grid, it will be tiled using snap-to-grid.
"free" will remove any overlap between tiles using a snap-to-corner
approach.
"none" will write the positions as is, using the microscope metadata.
swap_xy (bool): Swap x and y axes coordinates in the metadata. This is sometimes
necessary to ensure correct image tiling and registration.
invert_x (bool): Invert x axis coordinates in the metadata. This is
sometimes necessary to ensure correct image tiling and registration.
invert_y (bool): Invert y axis coordinates in the metadata. This is
sometimes necessary to ensure correct image tiling and registration.
max_xy_chunk (int): XY chunk size is set as the minimum of this value and the
microscope tile size.
z_chunk (int): Z chunk size.
c_chunk (int): C chunk size.
t_chunk (int): T chunk size.
"""

num_levels: int = Field(default=5, ge=1)
tiling_mode: Literal["auto", "grid", "free", "none"] = "auto"
swap_xy: bool = False
invert_x: bool = False
invert_y: bool = False
max_xy_chunk: int = Field(default=4096, ge=1)
z_chunk: int = Field(default=10, ge=1)
c_chunk: int = Field(default=1, ge=1)
t_chunk: int = Field(default=1, ge=1)


class ConvertParallelInitArgs(BaseModel):
"""Arguments for the compute task."""

tiled_image_pickled_path: str
overwrite: bool
advanced_compute_options: AdvancedComputeOptions
94 changes: 94 additions & 0 deletions src/fractal_converters_tools/task_compute_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""A generic task to convert a LIF plate to OME-Zarr."""

import logging
import pickle
from functools import partial
from pathlib import Path

from fractal_converters_tools.omezarr_image_writers import write_tiled_image
from fractal_converters_tools.stitching import standard_stitching_pipe
from fractal_converters_tools.task_common_models import ConvertParallelInitArgs
from fractal_converters_tools.tiled_image import PlatePathBuilder

logger = logging.getLogger(__name__)


def _clean_up_pickled_file(pickle_path: Path):
"""Clean up the pickled file and the directory if it is empty.
Args:
pickle_path (Path): Path to the pickled file.
"""
try:
pickle_path.unlink()
if not list(pickle_path.parent.iterdir()):
pickle_path.parent.rmdir()
except Exception as e:
# This path is not tested
# But if multiple processes are trying to clean up the same file
# it might raise an exception.
logger.error(f"An error occurred while cleaning up the pickled file: {e}")


def generic_compute_task(
*,
# Fractal parameters
zarr_url: str,
init_args: ConvertParallelInitArgs,
):
"""Initialize the task to convert a LIF plate to OME-Zarr.
Args:
zarr_url (str): URL to the OME-Zarr file.
init_args (ConvertScanrInitArgs): Arguments for the initialization task.
"""
pickle_path = Path(init_args.tiled_image_pickled_path)
if not pickle_path.exists():
logger.error(f"Pickled file {pickle_path} does not exist.")
raise FileNotFoundError(f"Pickled file {pickle_path} does not exist.")

with open(pickle_path, "rb") as f:
tiled_image = pickle.load(f)

try:
stitching_pipe = partial(
standard_stitching_pipe,
mode=init_args.advanced_compute_options.tiling_mode,
swap_xy=init_args.advanced_compute_options.swap_xy,
invert_x=init_args.advanced_compute_options.invert_x,
invert_y=init_args.advanced_compute_options.invert_y,
)

new_zarr_url, is_3d, is_time_series = write_tiled_image(
zarr_dir=zarr_url,
tiled_image=tiled_image,
stiching_pipe=stitching_pipe,
num_levels=init_args.advanced_compute_options.num_levels,
max_xy_chunk=init_args.advanced_compute_options.max_xy_chunk,
z_chunk=init_args.advanced_compute_options.z_chunk,
c_chunk=init_args.advanced_compute_options.c_chunk,
t_chunk=init_args.advanced_compute_options.t_chunk,
overwrite=init_args.overwrite,
)
except Exception as e:
logger.error(f"An error occurred while processing {tiled_image}.")
_clean_up_pickled_file(pickle_path)
raise e

p_types = {"is_3D": is_3d}

if isinstance(tiled_image.path_builder, PlatePathBuilder):
attributes = {
"well": f"{tiled_image.path_builder.row}{tiled_image.path_builder.column}",
"plate": tiled_image.path_builder.plate_path,
}
else:
attributes = {}

_clean_up_pickled_file(pickle_path)

return {
"image_list_updates": [
{"zarr_url": new_zarr_url, "types": p_types, "attributes": attributes}
]
}
50 changes: 50 additions & 0 deletions src/fractal_converters_tools/task_init_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Tools to initialize a conversion tasks."""

import pickle
from uuid import uuid4

from fractal_converters_tools.task_common_models import (
AdvancedComputeOptions,
ConvertParallelInitArgs,
)
from fractal_converters_tools.tiled_image import TiledImage


def build_parallelization_list(
zarr_dir: str,
tiled_images: list[TiledImage],
overwrite: bool,
advanced_compute_options: AdvancedComputeOptions,
tmp_dir_name: str | None = None,
) -> list[dict]:
"""Build a list of dictionaries to parallelize the conversion.
Args:
zarr_dir (str): The path to the zarr directory.
tiled_images (list[TiledImage]): A list of tiled images objects to convert.
overwrite (bool): Overwrite the existing zarr directory.
advanced_compute_options (AdvancedComputeOptions): The advanced compute options.
tmp_dir_name (str, optional): The name of the temporary directory to store the
pickled tiled images.
"""
parallelization_list = []

tmp_dir_name = tmp_dir_name if tmp_dir_name else "_tmp_coverter_dir"
pickle_dir = zarr_dir / tmp_dir_name
pickle_dir.mkdir(parents=True, exist_ok=True)

for tile in tiled_images:
tile_pickle_path = pickle_dir / f"{uuid4()}.pkl"
with open(tile_pickle_path, "wb") as f:
pickle.dump(tile, f)
parallelization_list.append(
{
"zarr_url": str(zarr_dir),
"init_args": ConvertParallelInitArgs(
tiled_image_pickled_path=str(tile_pickle_path),
overwrite=overwrite,
advanced_compute_options=advanced_compute_options,
).model_dump(),
}
)
return parallelization_list
18 changes: 9 additions & 9 deletions src/fractal_converters_tools/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,20 @@ def __mul__(self, scalar: float) -> "Vector":
int(self.t * scalar),
)

def normalize(self) -> "Vector":
def normalizeXY(self) -> "Vector":
"""Normalize the vector."""
length = self.length()
length = self.lengthXY()
return Vector(
self.x / length,
self.y / length,
self.z / length,
int(self.c / length),
int(self.t / length),
self.z,
self.c,
self.t,
)

def length(self) -> float:
def lengthXY(self) -> float:
"""Compute the length of the vector."""
return (self.x**2 + self.y**2 + self.z**2 + self.c**2 + self.t**2) ** 0.5
return (self.x**2 + self.y**2) ** 0.5

def to_pixel_space(self, pixel_size: PixelSize) -> "Vector":
"""Convert the vector to pixel space."""
Expand Down Expand Up @@ -261,10 +261,10 @@ def __eq__(self, value: "Tile") -> bool:
else:
value = value.to_pixel_space()

if (self.top_l - value.top_l).length() > 1e-9:
if (self.top_l - value.top_l).lengthXY() > 1e-9:
return False

if (self.diag - value.diag).length() > 1e-9:
if (self.diag - value.diag).lengthXY() > 1e-9:
return False

return True
Expand Down
11 changes: 4 additions & 7 deletions src/fractal_converters_tools/tiled_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def __init__(
path_builder: PathBuilder,
channel_names: list[str] | None = None,
wavelength_ids: list[int] | None = None,
num_levels: int = 5,
):
"""Initialize the acquisition."""
self._name = name
Expand All @@ -105,7 +104,10 @@ def __init__(

self._channel_names = channel_names
self._wavelength_ids = wavelength_ids
self._num_levels = num_levels

def __repr__(self) -> str:
"""Return the string representation of the object."""
return f"TiledImage(name={self._name}, path={self.path})"

@property
def tiles(self) -> list[Tile]:
Expand Down Expand Up @@ -142,8 +144,3 @@ def pixel_size(self) -> PixelSize | None:
if len(self.tiles) == 0:
return None
return self.tiles[0].pixel_size

@property
def num_levels(self) -> int:
"""Return the number of levels."""
return self._num_levels
80 changes: 80 additions & 0 deletions tests/test_compute_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from pathlib import Path

import pytest
from ngio.utils import NgioFileExistsError
from utils import generate_tiled_image

from fractal_converters_tools.task_common_models import (
AdvancedComputeOptions,
ConvertParallelInitArgs,
)
from fractal_converters_tools.task_compute_tools import (
_clean_up_pickled_file,
generic_compute_task,
)
from fractal_converters_tools.task_init_tools import build_parallelization_list


def test_compute(tmp_path):
images_path = tmp_path / "test_write_images"

tiled_images = [
generate_tiled_image(
plate_name="plate_1",
row="A",
column=0,
acquisition_id=0,
tiled_image_name="image_1",
)
]

adv_comp_model = AdvancedComputeOptions()

par_args = build_parallelization_list(
zarr_dir=images_path,
tiled_images=tiled_images,
overwrite=False,
advanced_compute_options=adv_comp_model,
)[0]

zarr_url = par_args["zarr_url"]
init_args = ConvertParallelInitArgs(**par_args["init_args"])
image_list_updates = generic_compute_task(zarr_url=zarr_url, init_args=init_args)

assert "image_list_updates" in image_list_updates
assert len(image_list_updates["image_list_updates"]) == 1

new_zarr_url = image_list_updates["image_list_updates"][0]["zarr_url"]
p_types = image_list_updates["image_list_updates"][0]["types"]
attributes = image_list_updates["image_list_updates"][0]["attributes"]

assert Path(new_zarr_url).exists()
assert p_types == {"is_3D": False}
assert attributes == {"well": "A0", "plate": "plate_1.zarr"}

# Test if overwrite is working
par_args = build_parallelization_list(
zarr_dir=images_path,
tiled_images=tiled_images,
overwrite=False,
advanced_compute_options=adv_comp_model,
)[0]
zarr_url = par_args["zarr_url"]
init_args = ConvertParallelInitArgs(**par_args["init_args"])
with pytest.raises(NgioFileExistsError):
generic_compute_task(zarr_url=zarr_url, init_args=init_args)

# This should not raise an error since the the pickle is removed
# after the first run failed
with pytest.raises(FileNotFoundError):
generic_compute_task(zarr_url=zarr_url, init_args=init_args)


def test_pickle_cleanup(tmp_path):
pickle_path = tmp_path / "pickle_dir" / "test.pkl"
pickle_path.parent.mkdir(parents=True, exist_ok=True)
pickle_path.touch()
assert pickle_path.exists()
_clean_up_pickled_file(pickle_path)
assert not pickle_path.exists()
assert not pickle_path.parent.exists()
Loading

0 comments on commit 4ca1bd5

Please sign in to comment.