Commit: Add region split
relativityhd committed Dec 21, 2024
1 parent b1c0c1d commit 5130c9b
Showing 2 changed files with 92 additions and 19 deletions.
29 changes: 26 additions & 3 deletions darts-acquisition/src/darts_acquisition/s2.py
@@ -12,6 +12,30 @@
logger = logging.getLogger(__name__.replace("darts_", "darts."))


def parse_s2_tile_id(fpath: str | Path) -> tuple[str, str, str]:
    """Parse the Sentinel 2 tile ID from a file path.

    Args:
        fpath (str | Path): The path to the directory containing the TIFF files.

    Returns:
        tuple[str, str, str]: A tuple containing the Planet crop ID, the Sentinel 2 tile ID and the combined tile ID.

    Raises:
        FileNotFoundError: If no matching TIFF file is found in the specified path.

    """
    fpath = fpath if isinstance(fpath, Path) else Path(fpath)
    try:
        s2_image = next(fpath.glob("*_SR*.tif"))
    except StopIteration:
        raise FileNotFoundError(f"No matching TIFF files found in {fpath.resolve()} (.glob('*_SR*.tif'))")
    planet_crop_id = fpath.stem
    s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
    tile_id = f"{planet_crop_id}_{s2_tile_id}"
    return planet_crop_id, s2_tile_id, tile_id


def load_s2_scene(fpath: str | Path) -> xr.Dataset:
"""Load a Sentinel 2 satellite GeoTIFF file and return it as an xarray datset.
@@ -50,11 +74,10 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
{"data_source": "s2", "long_name": f"Sentinel 2 {var.capitalize()}", "units": "Reflectance"}
)

planet_crop_id = fpath.stem
s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
planet_crop_id, s2_tile_id, tile_id = parse_s2_tile_id(fpath)
ds_s2.attrs["planet_crop_id"] = planet_crop_id
ds_s2.attrs["s2_tile_id"] = s2_tile_id
ds_s2.attrs["tile_id"] = f"{planet_crop_id}_{s2_tile_id}"
ds_s2.attrs["tile_id"] = tile_id
logger.debug(f"Loaded Sentinel 2 scene in {time.time() - start_time} seconds.")
return ds_s2

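For context, here is a minimal usage sketch of the new helper. The directory and file names are made up for illustration; the only requirement is that the scene directory contains a GeoTIFF matching the `*_SR*.tif` pattern the function globs for.

```python
from pathlib import Path

from darts_acquisition.s2 import parse_s2_tile_id

# Hypothetical scene directory named after the Planet crop, containing e.g.
# "20220708T194909_20220708T194912_T09WXS_SR_clip.tif"
scene_dir = Path("data/sentinel2/mycrop123")

planet_crop_id, s2_tile_id, tile_id = parse_s2_tile_id(scene_dir)
# planet_crop_id == "mycrop123"                               (the directory stem)
# s2_tile_id     == "20220708T194909_20220708T194912_T09WXS"  (first three "_"-parts of the TIFF stem)
# tile_id        == "mycrop123_20220708T194909_20220708T194912_T09WXS"
```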
82 changes: 66 additions & 16 deletions darts/src/darts/legacy_training/preprocess.py
@@ -32,6 +32,7 @@ def preprocess_s2_train_data(
exclude_nan: bool = True,
mask_erosion_size: int = 10,
test_val_split: float = 0.05,
test_region: list[str] | str | None = None,
):
"""Preprocess Sentinel 2 data for training.
@@ -64,14 +65,16 @@
mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
Defaults to 10.
test_val_split (float, optional): The split ratio for the test and validation set. Defaults to 0.05.
test_region (list[str] | str | None, optional): The region(s) to reserve for the test set. Defaults to None.
"""
# Import here to avoid long loading times when running other commands
import geopandas as gpd
import pandas as pd
import torch
import xarray as xr
from darts_acquisition.arcticdem import load_arcticdem_tile
from darts_acquisition.s2 import load_s2_masks, load_s2_scene
from darts_acquisition.s2 import load_s2_masks, load_s2_scene, parse_s2_tile_id
from darts_acquisition.tcvis import load_tcvis
from darts_preprocessing import preprocess_legacy_fast
from darts_segmentation.training.prepare_training import create_training_patches
@@ -111,25 +114,63 @@
    norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

    train_data_dir.mkdir(exist_ok=True, parents=True)
    output_dir_train = train_data_dir / "train"
    output_dir_val = train_data_dir / "val"
    output_dir_cross_val = train_data_dir / "cross-val"
    output_dir_val_test = train_data_dir / "val-test"
    output_dir_test = train_data_dir / "test"

    # Find all Sentinel 2 scenes
    # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
    n_patches = 0
    s2_paths = sorted(sentinel2_dir.glob("*/"))
    logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")
    train_paths: list[Path]
    val_paths: list[Path]
    train_paths, val_paths = train_test_split(s2_paths, test_size=test_val_split, random_state=42)
    logger.info(f"Split the data into {len(train_paths)} training and {len(val_paths)} validation samples.")

    fpathgen = chain(train_paths, val_paths)
    modegen = chain(repeat(output_dir_train, len(train_paths)), repeat(output_dir_val, len(val_paths)))
    for i, (fpath, output_dir) in enumerate(zip(fpathgen, modegen)):

    # 1. Split regions
    test_paths: list[Path] = []
    training_paths: list[Path] = []
    if test_region:
        test_region = [test_region] if isinstance(test_region, str) else test_region
        for fpath in s2_paths:
            _, s2_tile_id, _ = parse_s2_tile_id(fpath)
            labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
            # If any of the scene's label regions is in the test regions, add the scene to the test set
            if labels["region"].isin(test_region).any():
                test_paths.append(fpath)
            else:
                training_paths.append(fpath)
    else:
        training_paths = s2_paths

    # 2. Split by random sampling
    cross_val_paths: list[Path]
    val_test_paths: list[Path]
    if len(training_paths) > 0:
        cross_val_paths, val_test_paths = train_test_split(
            training_paths, test_size=test_val_split, random_state=42
        )
    else:
        cross_val_paths, val_test_paths = [], []
        logger.warning("No training samples left after the region split. Skipping train-val split.")

    logger.info(
        f"Split the data into {len(cross_val_paths)} cross-val (train + val), "
        f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples."
    )

    fpathgen = chain(cross_val_paths, val_test_paths, test_paths)
    outpathgen = chain(
        repeat(output_dir_cross_val, len(cross_val_paths)),
        repeat(output_dir_val_test, len(val_test_paths)),
        repeat(output_dir_test, len(test_paths)),
    )
    modegen = chain(
        repeat("cross-val", len(cross_val_paths)),
        repeat("val-test", len(val_test_paths)),
        repeat("test", len(test_paths)),
    )

    joint_labels = []
    for i, (fpath, output_dir, mode) in enumerate(zip(fpathgen, outpathgen, modegen)):
        try:
            optical = load_s2_scene(fpath)
            logger.info(f"Found optical tile with size {optical.sizes}")
            tile_id = optical.attrs["tile_id"]
            _, s2_tile_id, tile_id = parse_s2_tile_id(fpath)

            # Check for a cached preprocessed file
            if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
@@ -139,6 +180,8 @@
"spatial_ref"
)
else:
optical = load_s2_scene(fpath)
logger.info(f"Found optical tile with size {optical.sizes}")
arctidem_res = 10
arcticdem_buffer = ceil(tpi_outer_radius / arctidem_res * sqrt(2))
arcticdem = load_arcticdem_tile(
@@ -163,7 +206,7 @@
logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
tile.to_netcdf(cache_file, engine="h5netcdf")

labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp")
labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")

# Save the patches
gen = create_training_patches(
@@ -188,6 +231,9 @@
torch.save(x, outdir_x / f"{tile_id}_pid{patch_id}.pt")
torch.save(y, outdir_y / f"{tile_id}_pid{patch_id}.pt")
n_patches += 1
if n_patches > 0 and len(labels) > 0:
labels["mode"] = mode
joint_labels.append(labels)

logger.info(
f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'"
@@ -201,6 +247,10 @@
logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
logger.exception(e)

# Save the used labels
joint_labels = pd.concat(joint_labels)
joint_labels.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")

# Save a config file as toml
config = {
"darts": {
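To illustrate the new split logic in isolation, here is a rough sketch of how the `test_region` argument partitions the scenes. The paths and the region name are hypothetical, and the `sklearn` origin of `train_test_split` is an assumption; the diff only shows the call, not the import.

```python
from pathlib import Path

import geopandas as gpd
from sklearn.model_selection import train_test_split  # assumed source of train_test_split

from darts_acquisition.s2 import parse_s2_tile_id

sentinel2_dir = Path("data/sentinel2")  # hypothetical input directory
test_region = ["herschel"]              # hypothetical region name from the label shapefiles
test_val_split = 0.05

s2_paths = sorted(sentinel2_dir.glob("*/"))

# 1. Region split: a scene becomes a test sample if any of its labels lies in a test region.
test_paths, training_paths = [], []
for fpath in s2_paths:
    _, s2_tile_id, _ = parse_s2_tile_id(fpath)
    labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
    if labels["region"].isin(test_region).any():
        test_paths.append(fpath)
    else:
        training_paths.append(fpath)

# 2. Random split of the remaining scenes into cross-val and val-test.
# (The actual function guards against an empty training_paths list before this call.)
cross_val_paths, val_test_paths = train_test_split(training_paths, test_size=test_val_split, random_state=42)
```

With `test_region=None` every scene goes through the random split, so the previous behaviour is preserved apart from the renamed output directories (`cross-val` and `val-test` instead of `train` and `val`) and the new, normally empty, `test` directory.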

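The preprocessing now also writes the concatenated labels of all processed scenes to `labels.geojson` in the training data directory, with a `mode` column recording which split each scene landed in. A quick sanity check of the result could look like this sketch; the path is a placeholder for whatever `train_data_dir` was used.

```python
import geopandas as gpd

labels = gpd.read_file("train_data/labels.geojson")  # placeholder for train_data_dir / "labels.geojson"
print(labels["mode"].value_counts())  # expected modes: "cross-val", "val-test", "test"
```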