Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix offset #8

Merged
merged 11 commits into from
May 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/examples/advanced_example_pytorch_inference.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)
```

Create a dataset from the list of filenames. All those files should be stored within their respective group "0".
Create a dataset from the list of filenames. All those files should be stored within their respective group "4", which in this case it correspond to a downsampled version of the full resolution image by a factor of 16.

Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly

Expand Down
58 changes: 38 additions & 20 deletions zarrdataset/_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,26 +142,32 @@ def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray

return corners

def _compute_reference_indices(self, reference_coordinates: np.ndarray
def _compute_reference_indices(self, reference_coordinates: np.ndarray,
reference_axes_sizes: np.ndarray
) -> Tuple[List[np.ndarray],
List[Tuple[int]]]:
reference_per_axis = list(map(
lambda coords: np.append(np.full((1, ), fill_value=-float("inf")),
np.unique(coords)),
reference_coordinates.T
lambda coords, axis_size: np.concatenate((
np.full((1, ), fill_value=-float("inf")),
np.unique(coords),
np.full((1, ), fill_value=np.max(coords) + axis_size))),
reference_coordinates.T,
reference_axes_sizes
))

reference_idx = map(
lambda coord_axis, ref_axis:
np.argmax(ref_axis[None, ...]
* (coord_axis[..., None] >= ref_axis[None, ...]),
axis=-1),
np.max(np.arange(ref_axis.size)
* (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
axis=1),
reference_coordinates.T,
reference_per_axis
)
reference_idx = np.stack(tuple(reference_idx), axis=-1)
reference_idx = reference_idx.reshape(reference_coordinates.T.shape)

reference_idx = [
tuple(tls_coord - 1)
tuple(tls_coord)
for tls_coord in reference_idx.reshape(-1, len(reference_per_axis))
]

Expand All @@ -172,13 +178,14 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
np.ndarray]:
tls_idx = map(
lambda coord_axis, ref_axis:
np.argmax(ref_axis[None, None, ...]
* (coord_axis[..., None] >= ref_axis[None, None, ...]),
axis=-1),
np.max(np.arange(ref_axis.size)
* (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
axis=1),
np.moveaxis(corners_coordinates, -1, 0),
reference_per_axis
)
tls_idx = np.stack(tuple(tls_idx), axis=-1)
tls_idx = tls_idx.reshape(corners_coordinates.shape)

tls_coordinates = map(
lambda tls_coord, ref_axis: ref_axis[tls_coord],
Expand All @@ -192,11 +199,12 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
dist2cut = np.fabs(corners_coordinates - corners_cut[None])
coverage = np.prod(dist2cut, axis=-1)

return coverage, tls_idx - 1
return coverage, tls_idx

def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
patch_size: dict,
image_size: dict,
min_area: float,
allow_incomplete_patches: bool = False):
mask_scale = np.array([mask.scale.get(ax, 1)
for ax in self.spatial_axes],
Expand All @@ -223,7 +231,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
]

if min(image_blocks) == 0:
return []
return []

image_scale = np.array([patch_size.get(ax, 1)
for ax in self.spatial_axes],
Expand Down Expand Up @@ -254,6 +262,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
for ax in self.spatial_axes],
dtype=np.float32
)

chunk_br_coordinates = np.array(
[chunk_tlbr[ax].stop
if chunk_tlbr[ax].stop is not None
Expand All @@ -271,32 +280,39 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
)
mask_coordinates = mask_coordinates[in_chunk]

# Translate the mask coordinates to the origin for comparison with
# image coordinates.
mask_coordinates -= chunk_tl_coordinates

if all(map(operator.ge, image_scale, mask_scale)):
mask_corners = self._compute_corners(mask_coordinates, mask_scale)

(reference_per_axis,
reference_idx) =\
self._compute_reference_indices(image_coordinates)
self._compute_reference_indices(image_coordinates, image_scale)

(coverage,
corners_idx) = self._compute_overlap(mask_corners,
reference_per_axis)

covered_indices = [
reference_idx.index(tuple(idx))
if tuple(idx) in reference_idx else len(reference_idx)
for idx in corners_idx.reshape(-1, len(self.spatial_axes))
]

patches_coverage = np.bincount(covered_indices,
weights=coverage.flatten(),
minlength=np.prod(image_blocks))
minlength=len(reference_idx) + 1)
patches_coverage = patches_coverage[:-1]

else:
image_corners = self._compute_corners(image_coordinates,
image_scale)

(reference_per_axis,
reference_idx) = self._compute_reference_indices(mask_coordinates)
reference_idx) = self._compute_reference_indices(mask_coordinates,
mask_scale)

(coverage,
corners_idx) = self._compute_overlap(image_corners,
Expand All @@ -309,10 +325,6 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,

patches_coverage = np.sum(covered_indices * coverage, axis=0)

min_area = self._min_area
if min_area < 1:
min_area *= np.prod(list(patch_size.values()))

minimum_covered_tls = image_coordinates[patches_coverage > min_area]
minimum_covered_tls = minimum_covered_tls.astype(np.int64)

Expand Down Expand Up @@ -396,6 +408,7 @@ def compute_chunks(self,
mask,
self._max_chunk_size,
image_size,
min_area=1,
allow_incomplete_patches=True
)

Expand Down Expand Up @@ -428,11 +441,16 @@ def compute_patches(self, image_collection: ImageCollection,
for ax in self.spatial_axes
}

min_area = self._min_area
if min_area < 1:
min_area *= np.prod(list(patch_size.values()))

valid_mask_toplefts = self._compute_valid_toplefts(
chunk_tlbr,
mask,
stride,
image_size=image_size,
min_area=min_area,
allow_incomplete_patches=self._allow_incomplete_patches
)

Expand Down
Loading