Add a training preprocessing
relativityhd committed Nov 29, 2024
1 parent fc58e08 commit ce63d87
Showing 7 changed files with 263 additions and 2 deletions.
17 changes: 17 additions & 0 deletions config.toml.example
@@ -24,3 +24,20 @@ patch-size = 1024
 overlap = 256
 batch-size = 2
 reflection = 32
+
+[darts.training]
+bands = [
+    'blue',
+    'green',
+    'red',
+    'nir',
+    'ndvi',
+    'tc_brightness',
+    'tc_greenness',
+    'tc_wetness',
+    'relative_elevation',
+    'slope',
+]
+train-data-dir = "data/training"
+include-allzero = false
+include-nan-edges = true
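The new `[darts.training]` table mirrors the keyword arguments of the `preprocess_s2_train_data` command added below. A minimal sketch of how such a table could map onto the function's parameters, assuming the parser simply turns kebab-case keys into snake_case (the real mapping lives in `darts.utils.config.ConfigParser`, which this commit does not touch):

```python
import toml

# Sketch only: map the kebab-case keys of [darts.training] onto
# snake_case keyword arguments. The actual key handling is done by
# darts.utils.config.ConfigParser and is assumed here, not shown in this diff.
config = toml.load("config.toml")
training = config["darts"]["training"]
kwargs = {key.replace("-", "_"): value for key, value in training.items()}
print(kwargs["bands"][:3])       # ['blue', 'green', 'red']
print(kwargs["train_data_dir"])  # data/training
```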
4 changes: 3 additions & 1 deletion darts-acquisition/src/darts_acquisition/s2.py
@@ -56,7 +56,9 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
     ]
 
     ds_s2 = xr.merge(datasets)
-    ds_s2.attrs["tile_id"] = fpath.stem
+    planet_crop_id = fpath.stem
+    s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
+    ds_s2.attrs["tile_id"] = f"{planet_crop_id}_{s2_tile_id}"
     logger.debug(f"Loaded Sentinel 2 scene in {time.time() - start_time} seconds.")
     return ds_s2

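A toy example of the composite `tile_id` this change produces (the folder and image names below are made up for illustration; real PLANET crop folders and Sentinel 2 image names will differ):

```python
from pathlib import Path

# Hypothetical names, for illustration only.
fpath = Path("input/sentinel2/4372514_planet_crop")           # PLANET crop folder
s2_image = Path("S2B_MSIL2A_20210818T223529_N0301_R058.tif")  # Sentinel 2 image inside it

planet_crop_id = fpath.stem
s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
print(f"{planet_crop_id}_{s2_tile_id}")
# -> 4372514_planet_crop_S2B_MSIL2A_20210818T223529
```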
82 changes: 82 additions & 0 deletions darts-segmentation/src/darts_segmentation/prepare_training.py
@@ -0,0 +1,82 @@
"""Functions to prepare the training data for the segmentation model training."""

from collections.abc import Generator

import geopandas as gpd
import torch
import xarray as xr
from geocube.api.core import make_geocube

from darts_segmentation.utils import create_patches


def create_training_patches(
tile: xr.Dataset,
labels: gpd.GeoDataFrame,
bands: list[str],
patch_size: int,
overlap: int,
include_allzero: bool,
include_nan_edges: bool,
) -> Generator[tuple[torch.tensor, torch.tensor]]:
"""Create training patches from a tile and labels.
Args:
tile (xr.Dataset): The input tile, containing preprocessed, harmonized data.
labels (gpd.GeoDataFrame): The labels to be used for training.
bands (list[str]): The bands to be used for training. Must be present in the tile.
patch_size (int): The size of the patches.
overlap (int): The size of the overlap.
include_allzero (bool): Whether to include patches where the labels are all zero.
include_nan_edges (bool): Whether to include patches where the input data has nan values at the edges.
Yields:
Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors.
The input has the format (C, H, W), the labels (H, W).
"""
# Rasterize the labels
labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003

# Filter out the nodata values
labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0)

# Replace invalid values with nan (used for nan check later on)
tile = xr.where(tile["valid_data_mask"], tile, float("nan"))

# Convert to dataaray and select the bands (bands are now in specified order)
tile = tile.to_dataarray(dim="band").sel(band=bands)

# Transpose to (C, H, W)
tile = tile.transpose("band", "y", "x")
labels_rasterized = labels_rasterized.transpose("y", "x")

# Convert to tensor
tensor_tile = torch.tensor(tile.values).float()
tensor_labels = torch.tensor(labels_rasterized.values).float()

assert tensor_tile.dim() == 3, f"Expects tensor_tile to has shape (C, H, W), got {tensor_tile.shape}"
assert tensor_labels.dim() == 2, f"Expects tensor_labels to has shape (H, W), got {tensor_labels.shape}"

# Create patches
tensor_patches = create_patches(tensor_tile.unsqueeze(0), patch_size, overlap)
tensor_patches = tensor_patches.reshape(-1, len(bands), patch_size, patch_size)
tensor_labels = create_patches(tensor_labels.unsqueeze(0).unsqueeze(0), patch_size, overlap)
tensor_labels = tensor_labels.reshape(-1, patch_size, patch_size)

# Turn the patches into a list of tuples
n_patches = tensor_patches.shape[0]
for i in range(n_patches):
x = tensor_patches[i]
y = tensor_labels[i]

if not include_allzero and y.sum() == 0:
continue

if not include_nan_edges and torch.isnan(x).any():
continue

# Convert all nan values to 0
x[torch.isnan(x)] = 0

yield x, y
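The generator is meant to be drained lazily, one (x, y) pair at a time. A minimal consumption sketch, mirroring the usage in darts/src/darts/training.py below (the `dump_patches` helper is hypothetical, and the band list and sizes are arbitrary; `tile` and `labels` are assumed to come from the preprocessing pipeline):

```python
from pathlib import Path

import geopandas as gpd
import torch
import xarray as xr

from darts_segmentation.prepare_training import create_training_patches


def dump_patches(tile: xr.Dataset, labels: gpd.GeoDataFrame, out_dir: Path) -> int:
    """Sketch: drain the generator and persist each (x, y) pair to disk."""
    out_dir.mkdir(parents=True, exist_ok=True)
    n = 0
    for i, (x, y) in enumerate(
        create_training_patches(
            tile,
            labels,
            bands=["red", "green", "blue", "ndvi"],  # arbitrary choice for this sketch
            patch_size=1024,
            overlap=16,
            include_allzero=False,
            include_nan_edges=True,
        )
    ):
        torch.save(x, out_dir / f"x_{i}.pt")  # input patch, shape (C, H, W)
        torch.save(y, out_dir / f"y_{i}.pt")  # label patch, shape (H, W)
        n = i + 1
    return n
```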
4 changes: 4 additions & 0 deletions darts/src/darts/cli.py
@@ -15,6 +15,7 @@
     run_native_sentinel2_pipeline,
     run_native_sentinel2_pipeline_fast,
 )
+from darts.training import preprocess_s2_train_data
 from darts.utils.config import ConfigParser
 from darts.utils.logging import add_logging_handlers, setup_logging

@@ -33,6 +34,7 @@
 
 pipeline_group = cyclopts.Group.create_ordered("Pipeline Commands")
 data_group = cyclopts.Group.create_ordered("Data Commands")
+train_group = cyclopts.Group.create_ordered("Training Commands")
 
 
 @app.command
@@ -67,6 +69,8 @@ def env_info():
 app.command(group=pipeline_group)(run_native_sentinel2_pipeline)
 app.command(group=pipeline_group)(run_native_sentinel2_pipeline_fast)
 
+app.command(group=train_group)(preprocess_s2_train_data)
+
 
 # Custom wrapper for the create_arcticdem_vrt function, which dodges the loading of all the heavy modules
 @app.command(group=data_group)
1 change: 0 additions & 1 deletion darts/src/darts/native.py
@@ -512,7 +512,6 @@ def run_native_sentinel2_pipeline_fast(
     Args:
         sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
         scenes_dir (Path): The directory containing the PlanetScope scenes.
-        output_data_dir (Path): The "output" directory.
         arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
             Will be created and downloaded if it does not exist.
155 changes: 155 additions & 0 deletions darts/src/darts/training.py
@@ -0,0 +1,155 @@
"""Training module for DARTS."""

import logging
import multiprocessing as mp
from math import ceil, sqrt
from pathlib import Path
from typing import Literal

import toml

logger = logging.getLogger(__name__)


def preprocess_s2_train_data(
*,
sentinel2_dir: Path,
train_data_dir: Path,
arcticdem_dir: Path,
tcvis_dir: Path,
bands: list[str],
device: Literal["cuda", "cpu", "auto"] | int | None = None,
ee_project: str | None = None,
ee_use_highvolume: bool = True,
tpi_outer_radius: int = 100,
tpi_inner_radius: int = 0,
patch_size: int = 1024,
overlap: int = 16,
include_allzero: bool = False,
include_nan_edges: bool = True,
):
"""Preprocess Sentinel 2 data for training.
Args:
sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
train_data_dir (Path): The "output" directory where the tensors are written to.
arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
Will be created and downloaded if it does not exist.
tcvis_dir (Path): The directory containing the TCVis data.
bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
If "cuda" take the first device (0), if int take the specified device.
If "auto" try to automatically select a free GPU (<50% memory usage).
Defaults to "cuda" if available, else "cpu".
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
project is defined within persistent API credentials obtained via `earthengine authenticate`.
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
in m. Defaults to 100m.
tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
in m. Defaults to 0.
patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
overlap (int, optional): The overlap to use for inference. Defaults to 16.
include_allzero (bool, optional): Whether to include patches where the labels are all zero. Defaults to False.
include_nan_edges (bool, optional): Whether to include patches where the input data has nan values at the edges.
Defaults to True.
"""
# Import here to avoid long loading times when running other commands
import geopandas as gpd
import torch
from darts_acquisition.arcticdem import load_arcticdem_tile
from darts_acquisition.s2 import load_s2_masks, load_s2_scene
from darts_acquisition.tcvis import load_tcvis
from darts_preprocessing import preprocess_legacy_fast
from darts_segmentation.prepare_training import create_training_patches
from dask.distributed import Client, LocalCluster
from odc.stac import configure_rio

from darts.utils.cuda import debug_info, decide_device
from darts.utils.earthengine import init_ee

debug_info()
device = decide_device(device)
init_ee(ee_project, ee_use_highvolume)

cluster = LocalCluster(n_workers=mp.cpu_count() - 1)
logger.info(f"Created Dask cluster: {cluster}")
client = Client(cluster)
logger.info(f"Using Dask client: {client}")
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
logger.info("Configured Rasterio with Dask")

outpath_x = train_data_dir / "x"
outpath_y = train_data_dir / "y"

outpath_x.mkdir(exist_ok=True, parents=True)
outpath_y.mkdir(exist_ok=True, parents=True)

# Find all Sentinel 2 scenes
n_patches = 0
for fpath in sentinel2_dir.glob("*/"):
try:
optical = load_s2_scene(fpath)
arcticdem = load_arcticdem_tile(
optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2))
)
tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
data_masks = load_s2_masks(fpath, optical.odc.geobox)

tile = preprocess_legacy_fast(
optical,
arcticdem,
tcvis,
data_masks,
tpi_outer_radius,
tpi_inner_radius,
device,
)

labels = gpd.read_file(fpath / f"{optical.attrs['tile_id']}.shp")
tile_id = optical.attrs["tile_id"]

# Save the patches
gen = create_training_patches(tile, labels, bands, patch_size, overlap, include_allzero, include_nan_edges)
for patch_id, (x, y) in enumerate(gen):
torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt")
torch.save(y, outpath_y / f"{tile_id}_pid{patch_id}.pt")
n_patches += 1
logger.info(f"Processed {tile_id} with {patch_id} patches.")

except KeyboardInterrupt:
logger.info("Interrupted by user.")
break

except Exception as e:
logger.warning(f"Could not process folder '{fpath.resolve()}'.\nSkipping...")
logger.exception(e)

# Save a config file as toml
config = {
"darts": {
"sentinel2_dir": sentinel2_dir,
"train_data_dir": train_data_dir,
"arcticdem_dir": arcticdem_dir,
"tcvis_dir": tcvis_dir,
"bands": bands,
"device": device,
"ee_project": ee_project,
"ee_use_highvolume": ee_use_highvolume,
"tpi_outer_radius": tpi_outer_radius,
"tpi_inner_radius": tpi_inner_radius,
"patch_size": patch_size,
"overlap": overlap,
"include_allzero": include_allzero,
"include_nan_edges": include_nan_edges,
"n_patches": n_patches,
}
}
with open(train_data_dir / "config.toml", "w") as f:
toml.dump(config, f)

logger.info(f"Saved {n_patches} patches to {train_data_dir}")

client.close()
cluster.close()
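The x/ and y/ folders written above pair naturally with a torch Dataset for the later training step. A minimal sketch of reading the tensors back; the `DartsPatchDataset` name is hypothetical and not part of this commit:

```python
from pathlib import Path

import torch
from torch.utils.data import DataLoader, Dataset


class DartsPatchDataset(Dataset):
    """Hypothetical reader for the x/ and y/ folders written by preprocess_s2_train_data."""

    def __init__(self, train_data_dir: Path):
        self.x_files = sorted((train_data_dir / "x").glob("*.pt"))
        self.y_files = sorted((train_data_dir / "y").glob("*.pt"))
        assert len(self.x_files) == len(self.y_files), "x/ and y/ must pair up"

    def __len__(self) -> int:
        return len(self.x_files)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        x = torch.load(self.x_files[idx])  # input patch, (C, H, W) float tensor
        y = torch.load(self.y_files[idx])  # label patch, (H, W) float tensor
        return x, y


dataset = DartsPatchDataset(Path("data/training"))
loader = DataLoader(dataset, batch_size=2, shuffle=True)
```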
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     "xpystac>=0.1.3",
     "odc-geo>=0.4.8",
     "odc-stac[botocore]>=0.3.10",
+    "toml>=0.10.2",
     "zarr[jupyter]>=2.18.3",
     # Training and Inference
     "segmentation-models-pytorch>=0.3.4",
@@ -32,6 +33,7 @@
     "scipy>=1.14.1",
     "xarray-spatial>=0.4.0",
     "dask>=2024.11.0",
+    "geocube>=0.7.0",
     # Visualization
     "cartopy>=0.24.1",
     "hvplot>=0.11.1",