Skip to content

Commit

Permalink
Add cupy to preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Nov 20, 2024
1 parent 642d8dd commit c61613b
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 34 deletions.
16 changes: 11 additions & 5 deletions darts-acquisition/src/darts_acquisition/arcticdem/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def load_arcticdem_tile(
resolution: RESOLUTIONS,
chunk_size: int = 6000,
buffer: int = 0,
persist: bool = True,
) -> xr.Dataset:
"""Get the corresponding ArcticDEM tile for the given geobox.
Expand All @@ -308,6 +309,8 @@ def load_arcticdem_tile(
chunk_size (int, optional): The chunk size for the datacube. Only relevant for the initial creation.
Has no effect otherwise. Defaults to 6000.
buffer (int, optional): The buffer around the geobox in pixels. Defaults to 0.
persist (bool, optional): If the data should be persisted in memory.
If not, this will return a Dask backed Dataset. Defaults to True.
Returns:
xr.Dataset: The ArcticDEM tile, with a buffer applied.
Expand Down Expand Up @@ -364,10 +367,13 @@ def load_arcticdem_tile(
arcticdem_aoi["datamask"] = arcticdem_aoi.datamask.astype("uint8")

# The following code would load the data from disk
tick_sload = time.perf_counter()
arcticdem_aoi = arcticdem_aoi.compute()
tick_eload = time.perf_counter()
logger.debug(f"ArcticDEM AOI loaded from disk in {tick_eload - tick_sload:.2f} seconds")
if persist:
tick_sload = time.perf_counter()
arcticdem_aoi = arcticdem_aoi.compute()
tick_eload = time.perf_counter()
logger.debug(f"ArcticDEM AOI loaded from disk in {tick_eload - tick_sload:.2f} seconds")

logger.info(f"ArcticDEM tile loaded in {time.perf_counter() - tick_fstart:.2f} seconds")
logger.info(
f"ArcticDEM tile {'loaded' if persist else 'lazy-opened'} in {time.perf_counter() - tick_fstart:.2f} seconds"
)
return arcticdem_aoi
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def calculate_topographic_position_index(
f"{inner_radius}-{outer_radius} ({inner_radius_m}-{outer_radius_m}) cells."
)

tpi = arcticdem_ds.dem - convolution.convolution_2d(arcticdem_ds.dem.values, kernel) / kernel.sum()
tpi = arcticdem_ds.dem - convolution.convolution_2d(arcticdem_ds.dem, kernel) / kernel.sum()
tpi.attrs = {
"long_name": "Topographic Position Index",
"units": "m",
Expand Down
26 changes: 24 additions & 2 deletions darts-preprocessing/src/darts_preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,24 @@
import logging
import time

import rasterio
import odc.geo.xr # noqa: F401
import xarray as xr
from xrspatial.utils import has_cuda_and_cupy

from darts_preprocessing.engineering.arcticdem import calculate_slope, calculate_topographic_position_index
from darts_preprocessing.engineering.indices import calculate_ndvi

logger = logging.getLogger(__name__.replace("darts_", "darts."))


if has_cuda_and_cupy():
import cupy_xarray # noqa: F401

logger.info("GPU-accelerated xrspatial functions are available.")
else:
logger.info("GPU-accelerated xrspatial functions are not available.")


def preprocess_legacy(
ds_optical: xr.Dataset,
ds_arcticdem: xr.Dataset,
Expand Down Expand Up @@ -49,6 +58,7 @@ def preprocess_legacy_fast(
ds_data_masks: xr.Dataset,
tpi_outer_radius: int = 30,
tpi_inner_radius: int = 25,
use_gpu: bool = True,
) -> xr.Dataset:
"""Preprocess optical data with legacy (DARTS v1) preprocessing steps, but with new data concepts.
Expand All @@ -70,6 +80,7 @@ def preprocess_legacy_fast(
in number of cells. Defaults to 30.
tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
in number of cells. Defaults to 25.
use_gpu (bool, optional): Whether to use GPU-accelerated functions. Defaults to True.
Returns:
xr.Dataset: The preprocessed dataset.
Expand All @@ -86,9 +97,20 @@ def preprocess_legacy_fast(

# Calculate TPI and slope from ArcticDEM
# We need to calculate them before reprojecting, hence we cant merge the data yet
# Move to GPU if available
if use_gpu and has_cuda_and_cupy():
logger.debug("Moving arcticdem to GPU.")
# Check if dem is dask, if not persist it, since tpi and slope can't be calculated from cupy-dask arrays
if ds_arcticdem.chunks is not None:
ds_arcticdem = ds_arcticdem.persist()
ds_arcticdem = ds_arcticdem.cupy.as_cupy()

ds_arcticdem = calculate_topographic_position_index(ds_arcticdem, tpi_outer_radius, tpi_inner_radius)
ds_arcticdem = calculate_slope(ds_arcticdem)
ds_arcticdem = ds_arcticdem.rio.reproject_match(ds_optical, resampling=rasterio.enums.Resampling.cubic)
# Move back to CPU
if use_gpu and has_cuda_and_cupy():
ds_arcticdem = ds_arcticdem.cupy.as_numpy()
ds_arcticdem = ds_arcticdem.odc.reproject(ds_optical.odc.geobox, resampling="cubic")

ds_merged["dem"] = ds_arcticdem.dem
ds_merged["relative_elevation"] = ds_arcticdem.tpi
Expand Down
4 changes: 2 additions & 2 deletions darts/src/darts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from darts.utils.config import ConfigParser
from darts.utils.logging import add_logging_handlers, setup_logging

root = Path(__name__).resolve()
root_file = Path(__file__).resolve()
logger = logging.getLogger(__name__)
console = Console()

Expand Down Expand Up @@ -76,7 +76,7 @@ def launcher( # noqa: D103
):
command, bound = app.parse_args(tokens)
add_logging_handlers(command.__name__, console, log_dir)
logger.debug(f"Running on Python version {sys.version} from {root}")
logger.debug(f"Running on Python version {sys.version} from {__name__} ({root_file})")
return command(*bound.args, **bound.kwargs)


Expand Down
53 changes: 35 additions & 18 deletions notebooks/test-arcticdem-datacube.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"import logging\n",
"from pathlib import Path\n",
"\n",
"import cupy_xarray # noqa: F401\n",
"import hvplot.xarray # noqa: F401\n",
"import rioxarray # noqa: F401\n",
"from darts_acquisition.planet import load_planet_scene\n",
Expand All @@ -26,7 +27,7 @@
" datefmt=\"[%X]\",\n",
" handlers=[RichHandler(rich_tracebacks=True)],\n",
")\n",
"traceback.install(show_locals=False)\n",
"traceback.install(show_locals=True)\n",
"client = Client()\n",
"configure_rio(cloud_defaults=True, aws={\"aws_unsigned\": True}, client=client)\n",
"client"
Expand All @@ -52,7 +53,7 @@
"outputs": [],
"source": [
"# load planet scene\n",
"ds_planet = load_planet_scene(fpath).isel(x=slice(0, 2000), y=slice(6000, 8000))\n",
"ds_planet = load_planet_scene(fpath) # .isel(x=slice(0, 2000), y=slice(6000, 8000))\n",
"ds_planet"
]
},
Expand All @@ -71,7 +72,7 @@
"source": [
"from darts_acquisition.arcticdem import load_arcticdem_tile\n",
"\n",
"ds = load_arcticdem_tile(ds_planet, arcticdem_dir, resolution=2, buffer=0)\n",
"ds = load_arcticdem_tile(ds_planet.odc.geobox, arcticdem_dir, resolution=2, buffer=0, persist=True)\n",
"ds"
]
},
Expand All @@ -92,16 +93,35 @@
"outputs": [],
"source": [
"crs = ds_planet.rio.crs.to_string()\n",
"dem_plot = ds.dem.rio.reproject_match(ds_planet).hvplot.image(aggregator=\"max\", rasterize=True, cmap=\"terrain\")\n",
"red_plot = ds_planet.red.hvplot.image(x=\"x\", y=\"y\", aggregator=\"mean\", rasterize=True, cmap=\"reds\")\n",
"dem_plot = ds.dem.rio.reproject_match(ds_planet).hvplot.image(\n",
" aggregator=\"max\", rasterize=True, cmap=\"terrain\", data_aspect=1, crs=crs, projection=crs\n",
")\n",
"red_plot = ds_planet.red.hvplot.image(\n",
" x=\"x\", y=\"y\", aggregator=\"mean\", rasterize=True, cmap=\"reds\", data_aspect=1, crs=crs, projection=crs\n",
")\n",
"dem_plot + red_plot"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Calculate Relative Elevation and Slope"
"## Calculate Relative Elevation and Slope\n",
"\n",
"### Numpy\n",
"\n",
"Topographic Position Index calculated in 652.77 seconds.\n",
"Slope calculated in 3.67 seconds.\n",
"\n",
"### Dask (4 worker)\n",
"\n",
"Topographic Position Index calculated in 135.34 seconds.\n",
"Slope calculated in 4.33 seconds.\n",
"\n",
"### Cupy\n",
"\n",
"Topographic Position Index calculated in 12.69 seconds. \n",
"Slope calculated in 0.16 seconds.\n"
]
},
{
Expand All @@ -111,21 +131,18 @@
"outputs": [],
"source": [
"from darts_preprocessing.engineering.arcticdem import calculate_slope, calculate_topographic_position_index\n",
"from xrspatial.utils import has_cuda_and_cupy\n",
"\n",
"use_cupy = True\n",
"if use_cupy and has_cuda_and_cupy():\n",
" ds = ds.cupy.as_cupy()\n",
"ds = calculate_topographic_position_index(ds)\n",
"ds = calculate_slope(ds)\n",
"if ds.cupy.is_cupy:\n",
" ds = ds.cupy.as_numpy()\n",
"ds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds = ds.persist()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -134,13 +151,13 @@
"source": [
"crs = ds.rio.crs.to_string()\n",
"dem_plot = ds.dem.hvplot.image(\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, cmap=\"terrain\", crs=crs, projection=crs, title=\"DEM\"\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, cmap=\"terrain\", data_aspect=1, crs=crs, projection=crs, title=\"DEM\"\n",
")\n",
"tpi_plot = ds.tpi.hvplot.image(\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, cmap=\"terrain\", crs=crs, projection=crs, title=\"TPI\"\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, cmap=\"terrain\", data_aspect=1, crs=crs, projection=crs, title=\"TPI\"\n",
")\n",
"slope_plot = ds.slope.hvplot.image(\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, cmap=\"terrain\", crs=crs, projection=crs, title=\"Slope\"\n",
" x=\"x\", y=\"y\", aggregator=\"max\", rasterize=True, data_aspect=1, crs=crs, projection=crs, title=\"Slope\"\n",
")\n",
"dem_plot + tpi_plot + slope_plot"
]
Expand Down
8 changes: 4 additions & 4 deletions notebooks/test-e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
" )\n",
" for z in tile.data_vars\n",
" ]\n",
" return hv.Layout(var_plots).cols(ncols)\n"
" return hv.Layout(var_plots).cols(ncols)"
]
},
{
Expand All @@ -82,13 +82,13 @@
"outputs": [],
"source": [
"cache_file = DATA_ROOT / \"intermediate\" / f\"planet_{fpath.stem}.nc\"\n",
"force = True\n",
"force = False\n",
"slc = {\"x\": slice(0, 1000), \"y\": slice(7000, 8000)}\n",
"if cache_file.exists() and not force:\n",
" tile = xr.open_dataset(cache_file, engine=\"h5netcdf\", mask_and_scale=False).isel(slc)\n",
" tile = xr.open_dataset(cache_file, engine=\"h5netcdf\", mask_and_scale=False).set_coords(\"spatial_ref\")\n",
"else:\n",
" optical = load_planet_scene(fpath).isel(slc)\n",
" arcticdem = load_arcticdem_tile(optical, arcticdem_dir, resolution=2)\n",
" arcticdem = load_arcticdem_tile(optical.odc.geobox, arcticdem_dir, resolution=2)\n",
" tcvis = load_tcvis(optical, None if force else cache_dir)\n",
" data_masks = load_planet_masks(fpath).isel(slc)\n",
" tile = preprocess_legacy_fast(optical, arcticdem, tcvis, data_masks)\n",
Expand Down
9 changes: 7 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"folium>=0.18.0",
"bokeh>=3.5.2",
"jupyter-bokeh>=4.0.5",
"setuptools>=75.5.0",
]
readme = "README.md"
requires-python = ">= 3.11"
Expand All @@ -56,14 +57,18 @@ cpu = ["torch==2.2.0+cpu", "torchvision==0.17.0+cpu"]
cuda11 = [
"torch==2.2.0+cu118",
"torchvision==0.17.0+cu118",
"cupy-cuda11x>=13.3.0",
"cucim-cu11>=24.8.0",
"cupy>=13.3.0",
"cupy-xarray>=0.1.4",
"cuda-python>=12.6.2.post1",
]
cuda12 = [
"torch==2.2.0+cu121",
"torchvision==0.17.0+cu121",
"cupy-cuda12x>=13.3.0",
"cucim-cu12==24.8.*",
"cupy>=13.3.0",
"cupy-xarray>=0.1.4",
"cuda-python>=12.6.2.post1",
]
gdal393 = ["gdal==3.9.3"]
gdal39 = ["gdal==3.9.2"]
Expand Down

0 comments on commit c61613b

Please sign in to comment.