Skip to content

Commit

Permalink
Load ArcticDEM via VRT
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 23, 2024
1 parent e77b7ea commit 4a098fd
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 23 deletions.
67 changes: 67 additions & 0 deletions darts-acquisition/src/darts_acquisition/arcticdem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""ArcticDEM related data loading."""

import logging
import os
import time
from pathlib import Path

logger = logging.getLogger(__name__)


def create_arcticdem_vrt(dem_data_dir: Path, vrt_target_dir: Path):
"""Create a VRT file from ArcticDEM data.
Args:
dem_data_dir (Path): The directory containing the ArcticDEM data (.tif).
vrt_target_dir (Path): The output directory.
Raises:
OSError: If the target directory is not writable.
"""
start_time = time.time()
logger.debug(f"Creating ArcticDEM VRT file at {vrt_target_dir} based on {dem_data_dir}.")

try:
from osgeo import gdal

logger.debug(f"Found gdal bindings {gdal.__version__}.")
except ModuleNotFoundError as e:
logger.exception(
"The python GDAL bindings where not found. Please install those which are appropriate for your platform."
)
raise e

# decide on the exception behavior of GDAL to supress a warning if we dont
# don't know if this is necessary in all GDAL versions
try:
gdal.UseExceptions()
logger.debug("Enabled gdal exceptions")
except AttributeError():
pass

# subdirs = {"elevation": "tiles_rel_el", "slope": "tiles_slope"}
subdirs = {"elevation": "relative_elevation", "slope": "slope"}

# check first if BOTH files are writable
non_writable_files = []
for name in subdirs.keys():
output_file_path = vrt_target_dir / f"{name}.vrt"
if not os.access(output_file_path, os.W_OK) and output_file_path.exists():
non_writable_files.append(output_file_path)
if len(non_writable_files) > 0:
raise OSError(f"cannot write to {', '.join([f.name for f in non_writable_files])}")

for name, subdir in subdirs.items():
output_file_path = vrt_target_dir / f"{name}.vrt"
# check the file first if we can write to it

ds_path = dem_data_dir / subdir
filelist = [str(f.absolute().resolve()) for f in ds_path.glob("*.tif")]
logger.debug(f"Found {len(filelist)} files for {name} at {ds_path}.")
logger.debug(f"Writing VRT to '{output_file_path}'")
src_nodata = "nan" if name == "slope" else 0
opt = gdal.BuildVRTOptions(srcNodata=src_nodata, VRTNodata=0)
gdal.BuildVRT(str(output_file_path.absolute()), filelist, options=opt)

logger.debug(f"Creation of VRT took {time.time() - start_time:.2f}s")
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,75 @@
import time
from pathlib import Path

import rasterio
import rasterio.mask
import rioxarray # noqa: F401
import xarray as xr

logger = logging.getLogger(__name__)


def load_arcticdem(elevation_path: Path, slope_path: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
def load_vrt(vrt_path: Path, reference_dataset: xr.Dataset) -> xr.DataArray:
"""Load a VRT file and reproject it to match the reference dataset.
Args:
vrt_path (Path): Path to the vrt file.
reference_dataset (xr.Dataset): The reference dataset.
Raises:
FileNotFoundError: If the VRT file is not found.
Returns:
xr.DataArray: The VRT data reprojected to match the reference dataarray.
"""
if not vrt_path.exists():
raise FileNotFoundError(f"Could not find the VRT file at {vrt_path}")

start_time = time.time()

with rasterio.open(vrt_path) as src:
with rasterio.vrt.WarpedVRT(
src, crs=reference_dataset.rio.crs, resampling=rasterio.enums.Resampling.cubic
) as vrt:
bounds = reference_dataset.rio.bounds()
windows = vrt.window(*bounds)
shape = (1, len(reference_dataset.y), len(reference_dataset.x))
data = vrt.read(window=windows, out_shape=shape)[0] # This is the most time consuming part of the function
da = xr.DataArray(data, dims=["y", "x"], coords={"y": reference_dataset.y, "x": reference_dataset.x})
da.rio.write_crs(reference_dataset.rio.crs, inplace=True)
da.rio.write_transform(reference_dataset.rio.transform(), inplace=True)

logger.debug(f"Loaded VRT data from {vrt_path} in {time.time() - start_time} seconds.")
return da


def load_arcticdem(fpath: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
"""Load ArcticDEM data and reproject it to match the reference dataset.
Args:
elevation_path (Path): The path to the ArcticDEM elevation data.
slope_path (Path): The path to the ArcticDEM slope data.
fpath (Path): The path to the ArcticDEM data.
reference_dataset (xr.Dataset): The reference dataset to reproject, resampled and cropped the ArcticDEM data to.
Returns:
xr.Dataset: The ArcticDEM data reprojected, resampled and cropped to match the reference dataset.
"""
start_time = time.time()
logger.debug(f"Loading ArcticDEM data from {elevation_path} and {slope_path}")
relative_elevation = xr.open_dataarray(elevation_path).isel(band=0).drop_vars("band")
relative_elevation: xr.DataArray = relative_elevation.rio.reproject_match(reference_dataset)
logger.debug(f"Loading ArcticDEM data from {fpath}")

slope_vrt = fpath / "slope.vrt"
elevation_vrt = fpath / "elevation.vrt"

slope = load_vrt(slope_vrt, reference_dataset)
slope: xr.Dataset = slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"}).to_dataset(name="slope")

relative_elevation = load_vrt(elevation_vrt, reference_dataset)
relative_elevation: xr.Dataset = relative_elevation.assign_attrs(
{"data_source": "arcticdem", "long_name": "Relative Elevation"}
).to_dataset(name="relative_elevation")

slope = xr.open_dataarray(slope_path).isel(band=0).drop_vars("band")
slope: xr.DataArray = slope.rio.reproject_match(reference_dataset)
slope: xr.Dataset = slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"}).to_dataset(name="slope")

articdem_ds = xr.merge([relative_elevation, slope])
logger.debug(f"Loaded ArcticDEM data in {time.time() - start_time} seconds.")
return articdem_ds
13 changes: 5 additions & 8 deletions darts-preprocessing/src/darts_preprocessing/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
from darts_preprocessing.engineering.indices import calculate_ndvi


def load_and_preprocess_planet_scene(planet_scene_path: Path, elevation_path: Path, slope_path: Path) -> xr.Dataset:
def load_and_preprocess_planet_scene(planet_scene_path: Path, arcticdem_dir: Path) -> xr.Dataset:
"""Load and preprocess a Planet Scene (PSOrthoTile or PSScene) into an xr.Dataset.
Args:
planet_scene_path (Path): path to the Planet Scene
elevation_path (Path): path to the elevation data
slope_path (Path): path to the slope data
arcticdem_dir (Path): path to the ArcticDEM directory
Returns:
xr.Dataset: preprocessed Planet Scene
Expand Down Expand Up @@ -48,10 +47,8 @@ def load_and_preprocess_planet_scene(planet_scene_path: Path, elevation_path: Pa
from darts_preprocessing.preprocess import load_and_preprocess_planet_scene
fpath = Path("data/input/planet/planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459")
scene_id = fpath.parent.name
elevation_path = input_data_dir / "ArcticDEM" / "relative_elevation" / f"{scene_id}_relative_elevation_100.tif"
slope_path = input_data_dir / "ArcticDEM" / "slope" / f"{scene_id}_slope.tif"
tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path)
arcticdem_dir = input_data_dir / "ArcticDEM" / "relative_elevation" / f"{scene_id}_relative_elevation_100.tif"
tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir)
```
""" # noqa: E501
Expand All @@ -61,7 +58,7 @@ def load_and_preprocess_planet_scene(planet_scene_path: Path, elevation_path: Pa
# calculate xr.dataset ndvi
ds_ndvi = calculate_ndvi(ds_planet)

ds_articdem = load_arcticdem(elevation_path, slope_path, ds_planet)
ds_articdem = load_arcticdem(arcticdem_dir, ds_planet)

# # get xr.dataset for tcvis
# ds_tcvis = load_auxiliary(planet_scene_path, tcvis_path)
Expand Down
9 changes: 7 additions & 2 deletions darts/src/darts/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Annotated

import cyclopts
from darts_acquisition.arcticdem import create_arcticdem_vrt
from rich.console import Console

from darts import __version__
Expand All @@ -22,8 +23,11 @@
config=config_parser, # config=cyclopts.config.Toml("config.toml", root_keys=["darts"], search_parents=True)
)

pipeline_group = cyclopts.Group.create_ordered("Pipeline Commands")
data_group = cyclopts.Group.create_ordered("Data Commands")

@app.command

# @app.command
def hello(name: str, n: int = 1):
"""Say hello to someone.
Expand All @@ -42,7 +46,8 @@ def hello(name: str, n: int = 1):
logger.info(f"Hello {name}")


app.command()(run_native_orthotile_pipeline)
app.command(group=pipeline_group)(run_native_orthotile_pipeline)
app.command(group=data_group)(create_arcticdem_vrt)


# Intercept the logging behavior to add a file handler
Expand Down
6 changes: 3 additions & 3 deletions darts/src/darts/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ def run_native_orthotile_pipeline(
from darts_preprocessing import load_and_preprocess_planet_scene
from darts_segmentation import SMPSegmenter

arcticdem_dir = input_data_dir / "ArcticDEM"

# Find all PlanetScope scenes
for fpath in (input_data_dir / "planet" / "PSOrthoTile").glob("*/*/"):
scene_id = fpath.parent.name
elevation_path = input_data_dir / "ArcticDEM" / "relative_elevation" / f"{scene_id}_relative_elevation_100.tif"
slope_path = input_data_dir / "ArcticDEM" / "slope" / f"{scene_id}_slope.tif"
outpath = output_data_dir / scene_id

tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path)
tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir)

model = SMPSegmenter(model_dir / "RTS_v6_notcvis.pt")
tile = model.segment_tile(
Expand Down

0 comments on commit 4a098fd

Please sign in to comment.