Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing boolean array in delayed_subsample #9

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions geoutils/raster/delayed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Module for dask-delayed functions for out-of-memory raster operations.
"""

from __future__ import annotations

import warnings
Expand All @@ -15,7 +16,7 @@
from dask.utils import cached_cumsum
from scipy.interpolate import interpn

from geoutils._typing import NDArrayNum
from geoutils._typing import NDArrayBool, NDArrayNum
from geoutils.projtools import _get_bounds_projected, _get_footprint_projected

# 1/ SUBSAMPLING
Expand Down Expand Up @@ -96,18 +97,22 @@ def _get_indices_block_per_subsample(


@dask.delayed  # type: ignore
def _delayed_nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum:
    """
    Count number of valid values per block.

    For a boolean chunk, the valid values are its True elements. For any other chunk, they are its finite elements.

    :param arr_chunk: Array chunk, either numerical or boolean.

    :returns: Array of shape (1, 1) holding the number of valid values of the chunk.
    """
    # True elements of a boolean chunk are counted directly; otherwise count the finite values
    valid_mask = arr_chunk if arr_chunk.dtype == "bool" else np.isfinite(arr_chunk)

    return np.array([np.count_nonzero(valid_mask)]).reshape((1, 1))


@dask.delayed  # type: ignore
def _delayed_subsample_block(
    arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum
) -> NDArrayNum | NDArrayBool:
    """
    Subsample the valid values at the corresponding 1D valid indices per block.

    For a boolean chunk, the valid values are its True elements. For any other chunk, they are its finite elements.

    :param arr_chunk: Array chunk, either numerical or boolean.
    :param subsample_indices: Indices to pick within the 1D sequence of valid values of the chunk.

    :returns: Valid values of the chunk at the subsample indices.
    """
    # The boolean case must be checked first: np.isfinite is True for every boolean element,
    # so the finite-value path would also select the False elements
    if arr_chunk.dtype == "bool":
        return arr_chunk[arr_chunk][subsample_indices]  # type: ignore

    return arr_chunk[np.isfinite(arr_chunk)][subsample_indices]


@dask.delayed # type: ignore
Expand All @@ -116,8 +121,13 @@ def _delayed_subsample_indices_block(
) -> NDArrayNum:
"""Return 2D indices from the subsampled 1D valid indices per block."""

# Unravel indices of valid data to the shape of the block
ix, iy = np.unravel_index(np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape)
if arr_chunk.dtype == "bool":
ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape)
else:
# Unravel indices of valid data to the shape of the block
ix, iy = np.unravel_index(
np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape
)

# Convert to full-array indexes by adding the row and column starting indexes for this block
ix += block_id["xstart"]
Expand Down Expand Up @@ -722,9 +732,11 @@ def delayed_reproject(
# transform of each tuples of source blocks
src_block_ids = np.array(src_geotiling.get_block_locations())
meta_params = [
_combined_blocks_shape_transform(sub_block_ids=src_block_ids[sbid], src_geogrid=src_geogrid)
if len(sbid) > 0
else ({}, [])
(
_combined_blocks_shape_transform(sub_block_ids=src_block_ids[sbid], src_geogrid=src_geogrid)
if len(sbid) > 0
else ({}, [])
)
for sbid in dest2source
]
# We also add the output transform/shape for this destination chunk in the combined meta
Expand Down