From 3fe47ef0534ae217fbf6592e093de0cd1271dcc2 Mon Sep 17 00:00:00 2001 From: vschaffn Date: Tue, 18 Feb 2025 17:45:49 +0100 Subject: [PATCH] feat: separate functions from dask --- .../distributed_computing/delayed_dask.py | 87 ++------------- .../distributed_computing/delayed_utils.py | 103 +++++++++++++++++- 2 files changed, 112 insertions(+), 78 deletions(-) diff --git a/geoutils/raster/distributed_computing/delayed_dask.py b/geoutils/raster/distributed_computing/delayed_dask.py index c8c2e4e7..d5b4791c 100644 --- a/geoutils/raster/distributed_computing/delayed_dask.py +++ b/geoutils/raster/distributed_computing/delayed_dask.py @@ -26,7 +26,6 @@ import numpy as np import rasterio as rio from dask.utils import cached_cumsum -from scipy.interpolate import interpn from geoutils._typing import NDArrayBool, NDArrayNum from geoutils.raster.distributed_computing.delayed_utils import ( @@ -37,6 +36,11 @@ _get_indices_block_per_subsample, _get_interp_indices_per_block, _get_subsample_size_from_user_input, + _interp_points_block, + _nb_valids, + _reproject_per_block, + _subsample_block, + _subsample_indices_block, ) # 1/ SUBSAMPLING @@ -52,9 +56,7 @@ @dask.delayed # type: ignore def _delayed_nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum: """Count number of valid values per block.""" - if arr_chunk.dtype == "bool": - return np.array([np.count_nonzero(arr_chunk)]).reshape((1, 1)) - return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1)) + return _nb_valids(arr_chunk) @dask.delayed # type: ignore @@ -62,10 +64,7 @@ def _delayed_subsample_block( arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum ) -> NDArrayNum | NDArrayBool: """Subsample the valid values at the corresponding 1D valid indices per block.""" - - if arr_chunk.dtype == "bool": - return arr_chunk[arr_chunk][subsample_indices] - return arr_chunk[np.isfinite(arr_chunk)][subsample_indices] + return _subsample_block(arr_chunk, subsample_indices) @dask.delayed # type: ignore @@ -73,20 +72,7 @@ def _delayed_subsample_indices_block( arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum, block_id: dict[str, Any] ) -> NDArrayNum: """Return 2D indices from the subsampled 1D valid indices per block.""" - - if arr_chunk.dtype == "bool": - ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape) - else: - # Unravel indices of valid data to the shape of the block - ix, iy = np.unravel_index( - np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape - ) - - # Convert to full-array indexes by adding the row and column starting indexes for this block - ix += block_id["xstart"] - iy += block_id["ystart"] - - return np.hstack((ix, iy)) + return _subsample_indices_block(arr_chunk, subsample_indices, block_id) def delayed_subsample( @@ -218,21 +204,7 @@ def _delayed_interp_points_block( """ Interpolate block in 2D out-of-memory for a regular or equal grid. """ - - # Extract information out of block_id dictionary - xs, ys, xres, yres = (block_id["xstart"], block_id["ystart"], block_id["xres"], block_id["yres"]) - - # Reconstruct the coordinates from xi/yi/xres/yres (as it has to be a regular grid) - x_coords = np.arange(xs, xs + xres * arr_chunk.shape[0], xres) - y_coords = np.arange(ys, ys + yres * arr_chunk.shape[1], yres) - - # TODO: Use scipy.map_coordinates for an equal grid as in Raster.interp_points? - - # Interpolate to points - interp_chunk = interpn(points=(x_coords, y_coords), values=arr_chunk, xi=(interp_coords[0, :], interp_coords[1, :])) - - # And return the interpolated array - return interp_chunk + return _interp_points_block(arr_chunk, block_id, interp_coords) def delayed_interp_points( @@ -333,46 +305,7 @@ def _delayed_reproject_per_block( """ Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks). """ - - # If no source chunk intersects, we return a chunk of destination nodata values - if len(src_arrs) == 0: - # We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype - dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32")) - dst_arr[:] = kwargs["dst_nodata"] - return dst_arr - - # First, we build an empty array with the combined shape, only with nodata values - comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype) - comb_src_arr[:] = kwargs["src_nodata"] - - # Then fill it with the source chunks values - for i, arr in enumerate(src_arrs): - bid = block_ids[i] - comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr - - # Now, we can simply call Rasterio! - - # We build the combined transform from tuple - src_transform = rio.transform.Affine(*combined_meta["src_transform"]) - dst_transform = rio.transform.Affine(*combined_meta["dst_transform"]) - - # Reproject - dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype) - - _ = rio.warp.reproject( - comb_src_arr, - dst_arr, - src_transform=src_transform, - src_crs=kwargs["src_crs"], - dst_transform=dst_transform, - dst_crs=kwargs["dst_crs"], - resampling=kwargs["resampling"], - src_nodata=kwargs["src_nodata"], - dst_nodata=kwargs["dst_nodata"], - num_threads=1, # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading - ) - - return dst_arr + return _reproject_per_block(*src_arrs, block_ids=block_ids, combined_meta=combined_meta, **kwargs) def delayed_reproject( diff --git a/geoutils/raster/distributed_computing/delayed_utils.py b/geoutils/raster/distributed_computing/delayed_utils.py index 48f93936..77fdde66 100644 --- a/geoutils/raster/distributed_computing/delayed_utils.py +++ b/geoutils/raster/distributed_computing/delayed_utils.py @@ -29,8 +29,9 @@ import numpy as np import pandas as pd import rasterio as rio +from scipy.interpolate import interpn -from geoutils._typing import NDArrayNum +from geoutils._typing import NDArrayBool, NDArrayNum from geoutils.projtools import _get_bounds_projected, _get_footprint_projected # 1/ SUBSAMPLING @@ -103,6 +104,38 @@ def _get_indices_block_per_subsample( return relative_index_per_block +def _nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum: + """Count number of valid values per block.""" + if arr_chunk.dtype == "bool": + return np.array([np.count_nonzero(arr_chunk)]).reshape((1, 1)) + return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1)) + + +def _subsample_block(arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum) -> NDArrayNum | NDArrayBool: + """Subsample the valid values at the corresponding 1D valid indices per block.""" + if arr_chunk.dtype == "bool": + return arr_chunk[arr_chunk][subsample_indices] + return arr_chunk[np.isfinite(arr_chunk)][subsample_indices] + + +def _subsample_indices_block( + arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum, block_id: dict[str, Any] +) -> NDArrayNum: + """Return 2D indices from the subsampled 1D valid indices per block.""" + if arr_chunk.dtype == "bool": + ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape) + else: + # Unravel indices of valid data to the shape of the block + ix, iy = np.unravel_index( + np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape + ) + + # Convert to full-array indexes by adding the row and column starting indexes for this block + ix += block_id["xstart"] + iy += block_id["ystart"] + return np.hstack((ix, iy)) + + # 2/ POINT INTERPOLATION ON REGULAR OR EQUAL GRID @@ -137,6 +170,26 @@ def _get_interp_indices_per_block( return ind_per_block +def _interp_points_block(arr_chunk: NDArrayNum, block_id: dict[str, Any], interp_coords: NDArrayNum) -> NDArrayNum: + """ + Interpolate block in 2D using multiprocessing for a regular or equal grid. + """ + # Extract information out of block_id dictionary + xs, ys, xres, yres = (block_id["xstart"], block_id["ystart"], block_id["xres"], block_id["yres"]) + + # Reconstruct the coordinates from xi/yi/xres/yres (as it has to be a regular grid) + x_coords = np.arange(xs, xs + xres * arr_chunk.shape[0], xres) + y_coords = np.arange(ys, ys + yres * arr_chunk.shape[1], yres) + + # TODO: Use scipy.map_coordinates for an equal grid as in Raster.interp_points? + + # Interpolate to points + interp_chunk = interpn(points=(x_coords, y_coords), values=arr_chunk, xi=(interp_coords[0, :], interp_coords[1, :])) + + # And return the interpolated array + return interp_chunk + + # 3/ REPROJECT # The following GeoGrid and GeoTiling classes assist in managing georeferenced grids and performing reprojection GeoGridType = TypeVar("GeoGridType", bound="GeoGrid") @@ -356,3 +409,51 @@ def _combined_blocks_shape_transform( combined_meta = {"src_shape": combined_shape, "src_transform": tuple(combined_transform)} return combined_meta, relative_block_indexes + + +def _reproject_per_block( + *src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any +) -> NDArrayNum: + """ + Reproject per destination block, rebuilt from intersecting source blocks, using multiprocessing. + """ + + # If no source chunk intersects, we return a chunk of destination nodata values + if len(src_arrs) == 0: + # We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype + dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32")) + dst_arr[:] = kwargs["dst_nodata"] + return dst_arr + + # First, we build an empty array with the combined shape, only with nodata values + comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype) + comb_src_arr[:] = kwargs["src_nodata"] + + # Then fill it with the source chunks values + for i, arr in enumerate(src_arrs): + bid = block_ids[i] + comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr + + # Now, we can simply call Rasterio! + + # We build the combined transform from tuple + src_transform = rio.transform.Affine(*combined_meta["src_transform"]) + dst_transform = rio.transform.Affine(*combined_meta["dst_transform"]) + + # Reproject + dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype) + + _ = rio.warp.reproject( + comb_src_arr, + dst_arr, + src_transform=src_transform, + src_crs=kwargs["src_crs"], + dst_transform=dst_transform, + dst_crs=kwargs["dst_crs"], + resampling=kwargs["resampling"], + src_nodata=kwargs["src_nodata"], + dst_nodata=kwargs["dst_nodata"], + num_threads=1, # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading + ) + + return dst_arr