-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fc58e08
commit ce63d87
Showing
7 changed files
with
263 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
82 changes: 82 additions & 0 deletions
82
darts-segmentation/src/darts_segmentation/prepare_training.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Functions to prepare the training data for the segmentation model training.""" | ||
|
||
from collections.abc import Generator | ||
|
||
import geopandas as gpd | ||
import torch | ||
import xarray as xr | ||
from geocube.api.core import make_geocube | ||
|
||
from darts_segmentation.utils import create_patches | ||
|
||
|
||
def create_training_patches( | ||
tile: xr.Dataset, | ||
labels: gpd.GeoDataFrame, | ||
bands: list[str], | ||
patch_size: int, | ||
overlap: int, | ||
include_allzero: bool, | ||
include_nan_edges: bool, | ||
) -> Generator[tuple[torch.tensor, torch.tensor]]: | ||
"""Create training patches from a tile and labels. | ||
Args: | ||
tile (xr.Dataset): The input tile, containing preprocessed, harmonized data. | ||
labels (gpd.GeoDataFrame): The labels to be used for training. | ||
bands (list[str]): The bands to be used for training. Must be present in the tile. | ||
patch_size (int): The size of the patches. | ||
overlap (int): The size of the overlap. | ||
include_allzero (bool): Whether to include patches where the labels are all zero. | ||
include_nan_edges (bool): Whether to include patches where the input data has nan values at the edges. | ||
Yields: | ||
Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors. | ||
The input has the format (C, H, W), the labels (H, W). | ||
""" | ||
# Rasterize the labels | ||
labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull() # noqa: PD003 | ||
|
||
# Filter out the nodata values | ||
labels_rasterized = xr.where(tile["valid_data_mask"], labels_rasterized, 0) | ||
|
||
# Replace invalid values with nan (used for nan check later on) | ||
tile = xr.where(tile["valid_data_mask"], tile, float("nan")) | ||
|
||
# Convert to dataaray and select the bands (bands are now in specified order) | ||
tile = tile.to_dataarray(dim="band").sel(band=bands) | ||
|
||
# Transpose to (C, H, W) | ||
tile = tile.transpose("band", "y", "x") | ||
labels_rasterized = labels_rasterized.transpose("y", "x") | ||
|
||
# Convert to tensor | ||
tensor_tile = torch.tensor(tile.values).float() | ||
tensor_labels = torch.tensor(labels_rasterized.values).float() | ||
|
||
assert tensor_tile.dim() == 3, f"Expects tensor_tile to has shape (C, H, W), got {tensor_tile.shape}" | ||
assert tensor_labels.dim() == 2, f"Expects tensor_labels to has shape (H, W), got {tensor_labels.shape}" | ||
|
||
# Create patches | ||
tensor_patches = create_patches(tensor_tile.unsqueeze(0), patch_size, overlap) | ||
tensor_patches = tensor_patches.reshape(-1, len(bands), patch_size, patch_size) | ||
tensor_labels = create_patches(tensor_labels.unsqueeze(0).unsqueeze(0), patch_size, overlap) | ||
tensor_labels = tensor_labels.reshape(-1, patch_size, patch_size) | ||
|
||
# Turn the patches into a list of tuples | ||
n_patches = tensor_patches.shape[0] | ||
for i in range(n_patches): | ||
x = tensor_patches[i] | ||
y = tensor_labels[i] | ||
|
||
if not include_allzero and y.sum() == 0: | ||
continue | ||
|
||
if not include_nan_edges and torch.isnan(x).any(): | ||
continue | ||
|
||
# Convert all nan values to 0 | ||
x[torch.isnan(x)] = 0 | ||
|
||
yield x, y |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
"""Training module for DARTS.""" | ||
|
||
import logging | ||
import multiprocessing as mp | ||
from math import ceil, sqrt | ||
from pathlib import Path | ||
from typing import Literal | ||
|
||
import toml | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def preprocess_s2_train_data( | ||
*, | ||
sentinel2_dir: Path, | ||
train_data_dir: Path, | ||
arcticdem_dir: Path, | ||
tcvis_dir: Path, | ||
bands: list[str], | ||
device: Literal["cuda", "cpu", "auto"] | int | None = None, | ||
ee_project: str | None = None, | ||
ee_use_highvolume: bool = True, | ||
tpi_outer_radius: int = 100, | ||
tpi_inner_radius: int = 0, | ||
patch_size: int = 1024, | ||
overlap: int = 16, | ||
include_allzero: bool = False, | ||
include_nan_edges: bool = True, | ||
): | ||
"""Preprocess Sentinel 2 data for training. | ||
Args: | ||
sentinel2_dir (Path): The directory containing the Sentinel 2 scenes. | ||
train_data_dir (Path): The "output" directory where the tensors are written to. | ||
arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files). | ||
Will be created and downloaded if it does not exist. | ||
tcvis_dir (Path): The directory containing the TCVis data. | ||
bands (list[str]): The bands to be used for training. Must be present in the preprocessing. | ||
device (Literal["cuda", "cpu"] | int, optional): The device to run the model on. | ||
If "cuda" take the first device (0), if int take the specified device. | ||
If "auto" try to automatically select a free GPU (<50% memory usage). | ||
Defaults to "cuda" if available, else "cpu". | ||
ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if | ||
project is defined within persistent API credentials obtained via `earthengine authenticate`. | ||
ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com). | ||
tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation | ||
in m. Defaults to 100m. | ||
tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation | ||
in m. Defaults to 0. | ||
patch_size (int, optional): The patch size to use for inference. Defaults to 1024. | ||
overlap (int, optional): The overlap to use for inference. Defaults to 16. | ||
include_allzero (bool, optional): Whether to include patches where the labels are all zero. Defaults to False. | ||
include_nan_edges (bool, optional): Whether to include patches where the input data has nan values at the edges. | ||
Defaults to True. | ||
""" | ||
# Import here to avoid long loading times when running other commands | ||
import geopandas as gpd | ||
import torch | ||
from darts_acquisition.arcticdem import load_arcticdem_tile | ||
from darts_acquisition.s2 import load_s2_masks, load_s2_scene | ||
from darts_acquisition.tcvis import load_tcvis | ||
from darts_preprocessing import preprocess_legacy_fast | ||
from darts_segmentation.prepare_training import create_training_patches | ||
from dask.distributed import Client, LocalCluster | ||
from odc.stac import configure_rio | ||
|
||
from darts.utils.cuda import debug_info, decide_device | ||
from darts.utils.earthengine import init_ee | ||
|
||
debug_info() | ||
device = decide_device(device) | ||
init_ee(ee_project, ee_use_highvolume) | ||
|
||
cluster = LocalCluster(n_workers=mp.cpu_count() - 1) | ||
logger.info(f"Created Dask cluster: {cluster}") | ||
client = Client(cluster) | ||
logger.info(f"Using Dask client: {client}") | ||
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client) | ||
logger.info("Configured Rasterio with Dask") | ||
|
||
outpath_x = train_data_dir / "x" | ||
outpath_y = train_data_dir / "y" | ||
|
||
outpath_x.mkdir(exist_ok=True, parents=True) | ||
outpath_y.mkdir(exist_ok=True, parents=True) | ||
|
||
# Find all Sentinel 2 scenes | ||
n_patches = 0 | ||
for fpath in sentinel2_dir.glob("*/"): | ||
try: | ||
optical = load_s2_scene(fpath) | ||
arcticdem = load_arcticdem_tile( | ||
optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2)) | ||
) | ||
tcvis = load_tcvis(optical.odc.geobox, tcvis_dir) | ||
data_masks = load_s2_masks(fpath, optical.odc.geobox) | ||
|
||
tile = preprocess_legacy_fast( | ||
optical, | ||
arcticdem, | ||
tcvis, | ||
data_masks, | ||
tpi_outer_radius, | ||
tpi_inner_radius, | ||
device, | ||
) | ||
|
||
labels = gpd.read_file(fpath / f"{optical.attrs['tile_id']}.shp") | ||
tile_id = optical.attrs["tile_id"] | ||
|
||
# Save the patches | ||
gen = create_training_patches(tile, labels, bands, patch_size, overlap, include_allzero, include_nan_edges) | ||
for patch_id, (x, y) in enumerate(gen): | ||
torch.save(x, outpath_x / f"{tile_id}_pid{patch_id}.pt") | ||
torch.save(y, outpath_y / f"{tile_id}_pid{patch_id}.pt") | ||
n_patches += 1 | ||
logger.info(f"Processed {tile_id} with {patch_id} patches.") | ||
|
||
except KeyboardInterrupt: | ||
logger.info("Interrupted by user.") | ||
break | ||
|
||
except Exception as e: | ||
logger.warning(f"Could not process folder '{fpath.resolve()}'.\nSkipping...") | ||
logger.exception(e) | ||
|
||
# Save a config file as toml | ||
config = { | ||
"darts": { | ||
"sentinel2_dir": sentinel2_dir, | ||
"train_data_dir": train_data_dir, | ||
"arcticdem_dir": arcticdem_dir, | ||
"tcvis_dir": tcvis_dir, | ||
"bands": bands, | ||
"device": device, | ||
"ee_project": ee_project, | ||
"ee_use_highvolume": ee_use_highvolume, | ||
"tpi_outer_radius": tpi_outer_radius, | ||
"tpi_inner_radius": tpi_inner_radius, | ||
"patch_size": patch_size, | ||
"overlap": overlap, | ||
"include_allzero": include_allzero, | ||
"include_nan_edges": include_nan_edges, | ||
"n_patches": n_patches, | ||
} | ||
} | ||
with open(train_data_dir / "config.toml", "w") as f: | ||
toml.dump(config, f) | ||
|
||
logger.info(f"Saved {n_patches} patches to {train_data_dir}") | ||
|
||
client.close() | ||
cluster.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters