diff --git a/docs/source/examples/advanced_example_pytorch_inference.md b/docs/source/examples/advanced_example_pytorch_inference.md
index d3a6cd9..29217cc 100644
--- a/docs/source/examples/advanced_example_pytorch_inference.md
+++ b/docs/source/examples/advanced_example_pytorch_inference.md
@@ -58,7 +58,7 @@
 pad = dict(Y=16, X=16)
 patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad, allow_incomplete_patches=True)
 ```
 
-Create a dataset from the list of filenames. All those files should be stored within their respective group "0".
+Create a dataset from the list of filenames. All those files should be stored within their respective group "4", which in this case corresponds to a version of the full-resolution image downsampled by a factor of 16.
 
 Also, specify that the axes order in the image is Time-Channel-Depth-Height-Width (TCZYX), so the data can be handled correctly
diff --git a/zarrdataset/_samplers.py b/zarrdataset/_samplers.py
index 52df198..67bba1d 100644
--- a/zarrdataset/_samplers.py
+++ b/zarrdataset/_samplers.py
@@ -142,26 +142,32 @@ def _compute_corners(self, coordinates: np.ndarray, scale: np.ndarray
 
         return corners
 
-    def _compute_reference_indices(self, reference_coordinates: np.ndarray
+    def _compute_reference_indices(self, reference_coordinates: np.ndarray,
+                                   reference_axes_sizes: np.ndarray
                                    ) -> Tuple[List[np.ndarray], List[Tuple[int]]]:
        reference_per_axis = list(map(
-            lambda coords: np.append(np.full((1, ), fill_value=-float("inf")),
-                                     np.unique(coords)),
-            reference_coordinates.T
+            lambda coords, axis_size: np.concatenate((
+                np.full((1, ), fill_value=-float("inf")),
+                np.unique(coords),
+                np.full((1, ), fill_value=np.max(coords) + axis_size))),
+            reference_coordinates.T,
+            reference_axes_sizes
        ))
 
         reference_idx = map(
             lambda coord_axis, ref_axis:
-            np.argmax(ref_axis[None, ...]
-                      * (coord_axis[..., None] >= ref_axis[None, ...]),
-                      axis=-1),
+            np.max(np.arange(ref_axis.size)
+                   * (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
+                   axis=1),
             reference_coordinates.T,
             reference_per_axis
         )
 
         reference_idx = np.stack(tuple(reference_idx), axis=-1)
+        reference_idx = reference_idx.reshape(reference_coordinates.T.shape)
+
         reference_idx = [
-            tuple(tls_coord - 1)
+            tuple(tls_coord)
             for tls_coord in reference_idx.reshape(-1, len(reference_per_axis))
         ]
 
@@ -172,13 +178,14 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
                                                                  np.ndarray]:
         tls_idx = map(
             lambda coord_axis, ref_axis:
-            np.argmax(ref_axis[None, None, ...]
-                      * (coord_axis[..., None] >= ref_axis[None, None, ...]),
-                      axis=-1),
+            np.max(np.arange(ref_axis.size)
+                   * (coord_axis.reshape(-1, 1) >= ref_axis[None, ...]),
+                   axis=1),
             np.moveaxis(corners_coordinates, -1, 0),
             reference_per_axis
         )
         tls_idx = np.stack(tuple(tls_idx), axis=-1)
+        tls_idx = tls_idx.reshape(corners_coordinates.shape)
 
         tls_coordinates = map(
             lambda tls_coord, ref_axis: ref_axis[tls_coord],
@@ -192,11 +199,12 @@ def _compute_overlap(self, corners_coordinates: np.ndarray,
         dist2cut = np.fabs(corners_coordinates - corners_cut[None])
         coverage = np.prod(dist2cut, axis=-1)
 
-        return coverage, tls_idx - 1
+        return coverage, tls_idx
 
     def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
                       patch_size: dict,
                       image_size: dict,
+                      min_area: float,
                       allow_incomplete_patches: bool = False):
 
         mask_scale = np.array([mask.scale.get(ax, 1)
@@ -223,7 +231,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
         ]
 
         if min(image_blocks) == 0:
-            return []
+            return []
 
         image_scale = np.array([patch_size.get(ax, 1)
                                 for ax in self.spatial_axes],
@@ -254,6 +262,7 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
                                 for ax in self.spatial_axes],
                                dtype=np.float32
         )
+
         chunk_br_coordinates = np.array(
             [chunk_tlbr[ax].stop
              if chunk_tlbr[ax].stop is not None
@@ -271,12 +280,16 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
         )
         mask_coordinates = mask_coordinates[in_chunk]
 
+        # Translate the mask coordinates to the origin for comparison with
+        # image coordinates.
+        mask_coordinates -= chunk_tl_coordinates
+
         if all(map(operator.ge, image_scale, mask_scale)):
             mask_corners = self._compute_corners(mask_coordinates,
                                                  mask_scale)
             (reference_per_axis,
              reference_idx) =\
-                self._compute_reference_indices(image_coordinates)
+                self._compute_reference_indices(image_coordinates, image_scale)
 
             (coverage,
              corners_idx) = self._compute_overlap(mask_corners,
@@ -284,19 +297,22 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
 
             covered_indices = [
                 reference_idx.index(tuple(idx))
+                if tuple(idx) in reference_idx else len(reference_idx)
                 for idx in corners_idx.reshape(-1, len(self.spatial_axes))
             ]
 
             patches_coverage = np.bincount(covered_indices,
                                            weights=coverage.flatten(),
-                                           minlength=np.prod(image_blocks))
+                                           minlength=len(reference_idx) + 1)
+            patches_coverage = patches_coverage[:-1]
 
         else:
             image_corners = self._compute_corners(image_coordinates,
                                                   image_scale)
 
             (reference_per_axis,
-             reference_idx) = self._compute_reference_indices(mask_coordinates)
+             reference_idx) = self._compute_reference_indices(mask_coordinates,
+                                                              mask_scale)
 
             (coverage,
              corners_idx) = self._compute_overlap(image_corners,
@@ -309,10 +325,6 @@ def _compute_grid(self, chunk_tlbr: dict, mask: ImageBase,
 
             patches_coverage = np.sum(covered_indices * coverage, axis=0)
 
-        min_area = self._min_area
-        if min_area < 1:
-            min_area *= np.prod(list(patch_size.values()))
-
         minimum_covered_tls = image_coordinates[patches_coverage > min_area]
         minimum_covered_tls = minimum_covered_tls.astype(np.int64)
 
@@ -396,6 +408,7 @@ def compute_chunks(self,
             mask,
             self._max_chunk_size,
             image_size,
+            min_area=1,
             allow_incomplete_patches=True
         )
 
@@ -428,11 +441,16 @@ def compute_patches(self, image_collection: ImageCollection,
             for ax in self.spatial_axes
         }
 
+        min_area = self._min_area
+        if min_area < 1:
+            min_area *= np.prod(list(patch_size.values()))
+
         valid_mask_toplefts = self._compute_valid_toplefts(
             chunk_tlbr,
             mask,
             stride,
             image_size=image_size,
+            min_area=min_area,
             allow_incomplete_patches=self._allow_incomplete_patches
         )
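
For context, the documentation change amounts to a dataset set-up like the following sketch. This is an assumption-laden reconstruction of the surrounding example, not part of the patch: `filenames` and `patch_size` stand in for values defined earlier in `advanced_example_pytorch_inference.md`, and the `ImagesDatasetSpecs`/`ZarrDataset` usage follows zarrdataset's documented API as best understood here.

```python
import zarrdataset as zds

# Hypothetical OME-Zarr image paths; each file is expected to expose a
# multiscale pyramid in which group "4" holds the level downsampled by a
# factor of 16 from the full-resolution image.
filenames = ["https://example.org/image.ome.zarr"]

patch_size = dict(Y=128, X=128)  # example values, not from this patch
pad = dict(Y=16, X=16)
patch_sampler = zds.PatchSampler(patch_size=patch_size, pad=pad,
                                 allow_incomplete_patches=True)

# Read from group "4" and declare the Time-Channel-Depth-Height-Width
# (TCZYX) axes order so the data is handled correctly.
image_specs = zds.ImagesDatasetSpecs(
    filenames=filenames,
    data_group="4",
    source_axes="TCZYX",
)

dataset = zds.ZarrDataset([image_specs], patch_sampler=patch_sampler)
```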
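The index arithmetic that replaces `np.argmax` in `_compute_reference_indices` and `_compute_overlap` can be checked in isolation. Each axis's sorted reference coordinates are bracketed by a `-inf` sentinel below and a `max + axis_size` sentinel above; multiplying the boolean comparison matrix by `np.arange(ref_axis.size)` and taking the row-wise maximum then yields, for every query coordinate, the index of the rightmost reference that does not exceed it, with out-of-grid queries falling into the sentinel bins instead of relying on the old `- 1` correction. A self-contained NumPy sketch with made-up values:

```python
import numpy as np

# Top-left coordinates of a one-axis patch grid (patch size 16).
patch_size = 16
ref_coords = np.array([0.0, 16.0, 32.0])

# Bracket the references with sentinels, as the patched
# _compute_reference_indices does per axis.
ref_axis = np.concatenate((
    np.full((1, ), fill_value=-float("inf")),
    np.unique(ref_coords),
    np.full((1, ), fill_value=np.max(ref_coords) + patch_size),
))
# ref_axis is now [-inf, 0, 16, 32, 48]

# Query coordinates, some of them outside the grid on either side.
coords = np.array([-3.0, 0.0, 5.0, 31.9, 40.0, 100.0])

# Row-wise max of index * mask picks the rightmost reference <= query.
idx = np.max(np.arange(ref_axis.size)
             * (coords.reshape(-1, 1) >= ref_axis[None, ...]),
             axis=1)

print(idx)  # [0 1 1 2 3 4]: bins 0 and 4 are the sentinel bins
```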
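Finally, the `min_area` normalization moves out of `_compute_grid`: `compute_patches` now scales it once and passes it down, while `compute_chunks` pins `min_area=1` so chunk extraction keeps any chunk that touches the mask at all. As a worked example of the scaling (numbers hypothetical):

```python
import numpy as np

patch_size = dict(Y=128, X=128)
min_area = 0.25  # values below 1 are read as a fraction of the patch area

if min_area < 1:
    min_area *= np.prod(list(patch_size.values()))

print(min_area)  # 4096.0: a patch must overlap the mask by 4096 pixels
```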