feat: separate functions from dask
vschaffn committed Feb 25, 2025
1 parent da3e0f3 commit 3fe47ef
Showing 2 changed files with 112 additions and 78 deletions.
87 changes: 10 additions & 77 deletions geoutils/raster/distributed_computing/delayed_dask.py
@@ -26,7 +26,6 @@
import numpy as np
import rasterio as rio
from dask.utils import cached_cumsum
-from scipy.interpolate import interpn

from geoutils._typing import NDArrayBool, NDArrayNum
from geoutils.raster.distributed_computing.delayed_utils import (
@@ -37,6 +36,11 @@
_get_indices_block_per_subsample,
_get_interp_indices_per_block,
_get_subsample_size_from_user_input,
+    _interp_points_block,
+    _nb_valids,
+    _reproject_per_block,
+    _subsample_block,
+    _subsample_indices_block,
)

# 1/ SUBSAMPLING
@@ -52,41 +56,23 @@
@dask.delayed # type: ignore
def _delayed_nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum:
"""Count number of valid values per block."""
-    if arr_chunk.dtype == "bool":
-        return np.array([np.count_nonzero(arr_chunk)]).reshape((1, 1))
-    return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1))
+    return _nb_valids(arr_chunk)


@dask.delayed # type: ignore
def _delayed_subsample_block(
arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum
) -> NDArrayNum | NDArrayBool:
"""Subsample the valid values at the corresponding 1D valid indices per block."""

-    if arr_chunk.dtype == "bool":
-        return arr_chunk[arr_chunk][subsample_indices]
-    return arr_chunk[np.isfinite(arr_chunk)][subsample_indices]
+    return _subsample_block(arr_chunk, subsample_indices)


@dask.delayed # type: ignore
def _delayed_subsample_indices_block(
arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum, block_id: dict[str, Any]
) -> NDArrayNum:
"""Return 2D indices from the subsampled 1D valid indices per block."""

-    if arr_chunk.dtype == "bool":
-        ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape)
-    else:
-        # Unravel indices of valid data to the shape of the block
-        ix, iy = np.unravel_index(
-            np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape
-        )
-
-    # Convert to full-array indexes by adding the row and column starting indexes for this block
-    ix += block_id["xstart"]
-    iy += block_id["ystart"]
-
-    return np.hstack((ix, iy))
+    return _subsample_indices_block(arr_chunk, subsample_indices, block_id)


def delayed_subsample(
@@ -218,21 +204,7 @@ def _delayed_interp_points_block(
"""
Interpolate block in 2D out-of-memory for a regular or equal grid.
"""

-    # Extract information out of block_id dictionary
-    xs, ys, xres, yres = (block_id["xstart"], block_id["ystart"], block_id["xres"], block_id["yres"])
-
-    # Reconstruct the coordinates from xi/yi/xres/yres (as it has to be a regular grid)
-    x_coords = np.arange(xs, xs + xres * arr_chunk.shape[0], xres)
-    y_coords = np.arange(ys, ys + yres * arr_chunk.shape[1], yres)
-
-    # TODO: Use scipy.map_coordinates for an equal grid as in Raster.interp_points?
-
-    # Interpolate to points
-    interp_chunk = interpn(points=(x_coords, y_coords), values=arr_chunk, xi=(interp_coords[0, :], interp_coords[1, :]))
-
-    # And return the interpolated array
-    return interp_chunk
+    return _interp_points_block(arr_chunk, block_id, interp_coords)


def delayed_interp_points(
@@ -333,46 +305,7 @@ def _delayed_reproject_per_block(
"""
Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks).
"""

-    # If no source chunk intersects, we return a chunk of destination nodata values
-    if len(src_arrs) == 0:
-        # We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype
-        dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32"))
-        dst_arr[:] = kwargs["dst_nodata"]
-        return dst_arr
-
-    # First, we build an empty array with the combined shape, only with nodata values
-    comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype)
-    comb_src_arr[:] = kwargs["src_nodata"]
-
-    # Then fill it with the source chunks values
-    for i, arr in enumerate(src_arrs):
-        bid = block_ids[i]
-        comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr
-
-    # Now, we can simply call Rasterio!
-
-    # We build the combined transform from tuple
-    src_transform = rio.transform.Affine(*combined_meta["src_transform"])
-    dst_transform = rio.transform.Affine(*combined_meta["dst_transform"])
-
-    # Reproject
-    dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype)
-
-    _ = rio.warp.reproject(
-        comb_src_arr,
-        dst_arr,
-        src_transform=src_transform,
-        src_crs=kwargs["src_crs"],
-        dst_transform=dst_transform,
-        dst_crs=kwargs["dst_crs"],
-        resampling=kwargs["resampling"],
-        src_nodata=kwargs["src_nodata"],
-        dst_nodata=kwargs["dst_nodata"],
-        num_threads=1,  # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading
-    )
-
-    return dst_arr
+    return _reproject_per_block(*src_arrs, block_ids=block_ids, combined_meta=combined_meta, **kwargs)


def delayed_reproject(
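The wrappers in delayed_dask.py now only delegate: the array logic lives in plain helpers in delayed_utils.py, and the Dask layer adds nothing beyond laziness. A minimal sketch of this wrapping pattern (the wrapper name and example array below are illustrative, not part of the commit; _nb_valids is the helper added in delayed_utils.py further down):

import dask
import numpy as np

from geoutils.raster.distributed_computing.delayed_utils import _nb_valids


@dask.delayed  # type: ignore
def delayed_nb_valids_sketch(arr_chunk):
    # Thin wrapper: all the NumPy logic stays in the framework-agnostic helper.
    return _nb_valids(arr_chunk)


chunk = np.array([[1.0, np.nan], [2.0, 3.0]])
# Calling the wrapper builds a lazy task; compute() runs the plain helper on the chunk.
print(delayed_nb_valids_sketch(chunk).compute())  # [[3]]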
103 changes: 102 additions & 1 deletion geoutils/raster/distributed_computing/delayed_utils.py
@@ -29,8 +29,9 @@
import numpy as np
import pandas as pd
import rasterio as rio
+from scipy.interpolate import interpn

-from geoutils._typing import NDArrayNum
+from geoutils._typing import NDArrayBool, NDArrayNum
from geoutils.projtools import _get_bounds_projected, _get_footprint_projected

# 1/ SUBSAMPLING
@@ -103,6 +104,38 @@ def _get_indices_block_per_subsample(
return relative_index_per_block


def _nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum:
"""Count number of valid values per block."""
if arr_chunk.dtype == "bool":
return np.array([np.count_nonzero(arr_chunk)]).reshape((1, 1))
return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1))


def _subsample_block(arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum) -> NDArrayNum | NDArrayBool:
"""Subsample the valid values at the corresponding 1D valid indices per block."""
if arr_chunk.dtype == "bool":
return arr_chunk[arr_chunk][subsample_indices]
return arr_chunk[np.isfinite(arr_chunk)][subsample_indices]


def _subsample_indices_block(
arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum, block_id: dict[str, Any]
) -> NDArrayNum:
"""Return 2D indices from the subsampled 1D valid indices per block."""
if arr_chunk.dtype == "bool":
ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape)
else:
# Unravel indices of valid data to the shape of the block
ix, iy = np.unravel_index(
np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape
)

# Convert to full-array indexes by adding the row and column starting indexes for this block
ix += block_id["xstart"]
iy += block_id["ystart"]
return np.hstack((ix, iy))


# 2/ POINT INTERPOLATION ON REGULAR OR EQUAL GRID


@@ -137,6 +170,26 @@ def _get_interp_indices_per_block(
return ind_per_block


def _interp_points_block(arr_chunk: NDArrayNum, block_id: dict[str, Any], interp_coords: NDArrayNum) -> NDArrayNum:
"""
Interpolate block in 2D using multiprocessing for a regular or equal grid.
"""
# Extract information out of block_id dictionary
xs, ys, xres, yres = (block_id["xstart"], block_id["ystart"], block_id["xres"], block_id["yres"])

# Reconstruct the coordinates from xi/yi/xres/yres (as it has to be a regular grid)
x_coords = np.arange(xs, xs + xres * arr_chunk.shape[0], xres)
y_coords = np.arange(ys, ys + yres * arr_chunk.shape[1], yres)

# TODO: Use scipy.map_coordinates for an equal grid as in Raster.interp_points?

# Interpolate to points
interp_chunk = interpn(points=(x_coords, y_coords), values=arr_chunk, xi=(interp_coords[0, :], interp_coords[1, :]))

# And return the interpolated array
return interp_chunk


# 3/ REPROJECT
# The following GeoGrid and GeoTiling classes assist in managing georeferenced grids and performing reprojection
GeoGridType = TypeVar("GeoGridType", bound="GeoGrid")
@@ -356,3 +409,51 @@ def _combined_blocks_shape_transform(
combined_meta = {"src_shape": combined_shape, "src_transform": tuple(combined_transform)}

return combined_meta, relative_block_indexes


def _reproject_per_block(
*src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any
) -> NDArrayNum:
"""
Reproject per destination block, rebuilt from intersecting source blocks, using multiprocessing.
"""

# If no source chunk intersects, we return a chunk of destination nodata values
if len(src_arrs) == 0:
# We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype
dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32"))
dst_arr[:] = kwargs["dst_nodata"]
return dst_arr

# First, we build an empty array with the combined shape, only with nodata values
comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype)
comb_src_arr[:] = kwargs["src_nodata"]

# Then fill it with the source chunks values
for i, arr in enumerate(src_arrs):
bid = block_ids[i]
comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr

# Now, we can simply call Rasterio!

# We build the combined transform from tuple
src_transform = rio.transform.Affine(*combined_meta["src_transform"])
dst_transform = rio.transform.Affine(*combined_meta["dst_transform"])

# Reproject
dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype)

_ = rio.warp.reproject(
comb_src_arr,
dst_arr,
src_transform=src_transform,
src_crs=kwargs["src_crs"],
dst_transform=dst_transform,
dst_crs=kwargs["dst_crs"],
resampling=kwargs["resampling"],
src_nodata=kwargs["src_nodata"],
dst_nodata=kwargs["dst_nodata"],
num_threads=1, # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading
)

return dst_arr
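Because the extracted helpers above are plain module-level functions with no Dask dependency, they can also be mapped over in-memory chunks by other executors. A minimal sketch, assuming a standard-library process pool and synthetic chunks (both illustrative, not part of this commit), reusing _nb_valids from delayed_utils.py:

from multiprocessing import Pool

import numpy as np

from geoutils.raster.distributed_computing.delayed_utils import _nb_valids

# Illustrative chunks: random data with a few NaNs standing in for invalid pixels
chunks = [np.random.default_rng(i).random((100, 100)) for i in range(4)]
chunks[0][0, :10] = np.nan

if __name__ == "__main__":
    # Each worker counts the valid (finite) values in its chunk, the same computation
    # the Dask-delayed wrapper performs, but run through a multiprocessing pool instead.
    with Pool(processes=2) as pool:
        counts = pool.map(_nb_valids, chunks)
    print(sum(int(c[0, 0]) for c in counts))  # 4 * 10000 - 10 = 39990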
