diff --git a/darts-acquisition/src/darts_acquisition/s2.py b/darts-acquisition/src/darts_acquisition/s2.py
index 3b6825f..c8abb81 100644
--- a/darts-acquisition/src/darts_acquisition/s2.py
+++ b/darts-acquisition/src/darts_acquisition/s2.py
@@ -12,6 +12,30 @@
 logger = logging.getLogger(__name__.replace("darts_", "darts."))
 
 
+def parse_s2_tile_id(fpath: str | Path) -> tuple[str, str, str]:
+    """Parse the Sentinel 2 tile ID from a file path.
+
+    Args:
+        fpath (str | Path): The path to the directory containing the TIFF files.
+
+    Returns:
+        tuple[str, str, str]: A tuple containing the Planet crop ID, the Sentinel 2 tile ID and the combined tile ID.
+
+    Raises:
+        FileNotFoundError: If no matching TIFF file is found in the specified path.
+
+    """
+    fpath = fpath if isinstance(fpath, Path) else Path(fpath)
+    try:
+        s2_image = next(fpath.glob("*_SR*.tif"))
+    except StopIteration:
+        raise FileNotFoundError(f"No matching TIFF files found in {fpath.resolve()} (.glob('*_SR*.tif'))")
+    planet_crop_id = fpath.stem
+    s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
+    tile_id = f"{planet_crop_id}_{s2_tile_id}"
+    return planet_crop_id, s2_tile_id, tile_id
+
+
 def load_s2_scene(fpath: str | Path) -> xr.Dataset:
     """Load a Sentinel 2 satellite GeoTIFF file and return it as an xarray dataset.
 
@@ -50,11 +74,10 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
             {"data_source": "s2", "long_name": f"Sentinel 2 {var.capitalize()}", "units": "Reflectance"}
         )
 
-    planet_crop_id = fpath.stem
-    s2_tile_id = "_".join(s2_image.stem.split("_")[:3])
+    planet_crop_id, s2_tile_id, tile_id = parse_s2_tile_id(fpath)
     ds_s2.attrs["planet_crop_id"] = planet_crop_id
     ds_s2.attrs["s2_tile_id"] = s2_tile_id
-    ds_s2.attrs["tile_id"] = f"{planet_crop_id}_{s2_tile_id}"
+    ds_s2.attrs["tile_id"] = tile_id
 
     logger.debug(f"Loaded Sentinel 2 scene in {time.time() - start_time} seconds.")
     return ds_s2
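The extracted helper can now be exercised on its own. A minimal sketch, assuming a scene directory laid out as the loader expects (the folder and TIFF names below are hypothetical):

```python
from pathlib import Path

from darts_acquisition.s2 import parse_s2_tile_id

# Hypothetical scene directory: its name is the Planet crop ID, and it contains
# a Sentinel 2 surface-reflectance TIFF matching "*_SR*.tif",
# e.g. "S2A_MSIL2A_20230615_SR_clip.tif".
scene_dir = Path("data/sentinel2/4372514_0634")

planet_crop_id, s2_tile_id, tile_id = parse_s2_tile_id(scene_dir)
# planet_crop_id == "4372514_0634"          (the directory stem)
# s2_tile_id     == "S2A_MSIL2A_20230615"   (first three "_"-parts of the TIFF stem)
# tile_id        == "4372514_0634_S2A_MSIL2A_20230615"
```
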
""" # Import here to avoid long loading times when running other commands import geopandas as gpd + import pandas as pd import torch import xarray as xr from darts_acquisition.arcticdem import load_arcticdem_tile - from darts_acquisition.s2 import load_s2_masks, load_s2_scene + from darts_acquisition.s2 import load_s2_masks, load_s2_scene, parse_s2_tile_id from darts_acquisition.tcvis import load_tcvis from darts_preprocessing import preprocess_legacy_fast from darts_segmentation.training.prepare_training import create_training_patches @@ -111,25 +114,63 @@ def preprocess_s2_train_data( norm_factors = {k: v for k, v in norm_factors.items() if k in bands} train_data_dir.mkdir(exist_ok=True, parents=True) - output_dir_train = train_data_dir / "train" - output_dir_val = train_data_dir / "val" + output_dir_cross_val = train_data_dir / "cross-val" + output_dir_val_test = train_data_dir / "val-test" + output_dir_test = train_data_dir / "test" - # Find all Sentinel 2 scenes + # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region) n_patches = 0 s2_paths = sorted(sentinel2_dir.glob("*/")) logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}") - train_paths: list[Path] - val_paths: list[Path] - train_paths, val_paths = train_test_split(s2_paths, test_size=test_val_split, random_state=42) - logger.info(f"Split the data into {len(train_paths)} training and {len(val_paths)} validation samples.") - - fpathgen = chain(train_paths, val_paths) - modegen = chain(repeat(output_dir_train, len(train_paths)), repeat(output_dir_val, len(val_paths))) - for i, (fpath, output_dir) in enumerate(zip(fpathgen, modegen)): + + # 1. Split regions + test_paths: list[Path] = [] + training_paths: list[Path] = [] + if test_region: + test_region = [test_region] if isinstance(test_region, str) else test_region + for fpath in s2_paths: + _, s2_tile_id, _ = parse_s2_tile_id(fpath) + labels = gpd.read_file(fpath / f"{s2_tile_id}.shp") + # If any of the regions is in the test region, add to the test set + if labels["region"].isin(test_region).any(): + test_paths.append(fpath) + else: + training_paths.append(fpath) + else: + training_paths = s2_paths + + # 2. Split by random sampling + cross_val_paths: list[Path] + val_test_paths: list[Path] + if len(training_paths) > 0: + cross_val_paths, val_test_paths = train_test_split( + training_paths, test_size=test_val_split, random_state=42 + ) + else: + cross_val_paths, val_test_paths = [], [] + logger.warning("No left over training samples found. Skipping train-val split.") + + logger.info( + f"Split the data into {len(cross_val_paths)} cross-val (train + val), " + f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples." 
@@ -139,6 +180,8 @@
                     "spatial_ref"
                 )
             else:
+                optical = load_s2_scene(fpath)
+                logger.info(f"Found optical tile with size {optical.sizes}")
                 arctidem_res = 10
                 arcticdem_buffer = ceil(tpi_outer_radius / arctidem_res * sqrt(2))
                 arcticdem = load_arcticdem_tile(
@@ -163,7 +206,7 @@
                 logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
                 tile.to_netcdf(cache_file, engine="h5netcdf")
 
-            labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp")
+            labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")
 
             # Save the patches
             gen = create_training_patches(
@@ -188,6 +231,9 @@
                 torch.save(x, outdir_x / f"{tile_id}_pid{patch_id}.pt")
                 torch.save(y, outdir_y / f"{tile_id}_pid{patch_id}.pt")
                 n_patches += 1
+            if n_patches > 0 and len(labels) > 0:
+                labels["mode"] = mode
+                joint_labels.append(labels)
 
             logger.info(
                 f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'"
@@ -201,6 +247,10 @@
             logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
             logger.exception(e)
 
+    # Save the used labels
+    joint_labels = pd.concat(joint_labels)
+    joint_labels.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")
+
     # Save a config file as toml
     config = {
         "darts": {
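Since every saved label row now carries its split in a `mode` column, downstream tooling can recover the split assignment from the written GeoJSON. A small sketch (the path is illustrative):

```python
import geopandas as gpd

# Illustrative path: train_data_dir / "labels.geojson" from the patch above.
labels = gpd.read_file("train_data/labels.geojson")

# Each label row carries the split it was preprocessed into:
# "cross-val", "val-test" or "test".
print(labels["mode"].value_counts())

# E.g. restrict an evaluation to the held-out region:
test_labels = labels[labels["mode"] == "test"]
```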