From 9a2b7c83992791d3de1c66950158dceb54231d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tobias=20H=C3=B6lzer?= Date: Sat, 21 Dec 2024 16:26:10 +0100 Subject: [PATCH] Make preprocess easier to read --- darts/src/darts/legacy_training/preprocess.py | 131 +++++++++++------- 1 file changed, 78 insertions(+), 53 deletions(-) diff --git a/darts/src/darts/legacy_training/preprocess.py b/darts/src/darts/legacy_training/preprocess.py index 7e4dd25..3128d60 100644 --- a/darts/src/darts/legacy_training/preprocess.py +++ b/darts/src/darts/legacy_training/preprocess.py @@ -12,6 +12,79 @@ logger = logging.getLogger(__name__) +def split_dataset_paths( + s2_paths: list[Path], train_data_dir: Path, test_val_split: float, test_regions: list[str] | None +): + """Split the dataset into a cross-val, a val-test and a test dataset. + + Returns a generator with: input-path, output-path and split/mode. + The test set is split first by the given regions and is meant to be used to evaluate the regional value shift. + Then the val-test set is split at random at the given size to evaluate the variance value shift. + + Args: + s2_paths (list[Path]): All paths found with tiffs. + train_data_dir (Path): Output path. + test_val_split (float): val-test ratio. + test_regions (list[str] | None): test regions. + + Returns: + [zip[tuple[Path, Path, str]]]: A generator with input-path, output-path and split/mode. + + """ + # Import here to avoid long loading times when running other commands + import geopandas as gpd + from darts_acquisition.s2 import parse_s2_tile_id + from sklearn.model_selection import train_test_split + + train_data_dir.mkdir(exist_ok=True, parents=True) + output_dir_cross_val = train_data_dir / "cross-val" + output_dir_val_test = train_data_dir / "val-test" + output_dir_test = train_data_dir / "test" + + # 1. 
Split regions + test_paths: list[Path] = [] + training_paths: list[Path] = [] + if test_regions: + for fpath in s2_paths: + _, s2_tile_id, _ = parse_s2_tile_id(fpath) + labels = gpd.read_file(fpath / f"{s2_tile_id}.shp") + # If any of the regions is in the test region, add to the test set + if labels["region"].isin(test_regions).any(): + test_paths.append(fpath) + else: + training_paths.append(fpath) + else: + training_paths = s2_paths + + # 2. Split by random sampling + cross_val_paths: list[Path] + val_test_paths: list[Path] + if len(training_paths) > 0: + cross_val_paths, val_test_paths = train_test_split(training_paths, test_size=test_val_split, random_state=42) + else: + cross_val_paths, val_test_paths = [], [] + logger.warning("No left over training samples found. Skipping train-val split.") + + logger.info( + f"Split the data into {len(cross_val_paths)} cross-val (train + val), " + f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples." + ) + + fpathgen = chain(cross_val_paths, val_test_paths, test_paths) + outpathgen = chain( + repeat(output_dir_cross_val, len(cross_val_paths)), + repeat(output_dir_val_test, len(val_test_paths)), + repeat(output_dir_test, len(test_paths)), + ) + modegen = chain( + repeat("cross-val", len(cross_val_paths)), + repeat("val-test", len(val_test_paths)), + repeat("test", len(test_paths)), + ) + + return zip(fpathgen, outpathgen, modegen) + + def preprocess_s2_train_data( *, bands: list[str], @@ -32,7 +105,7 @@ def preprocess_s2_train_data( exclude_nan: bool = True, mask_erosion_size: int = 10, test_val_split: float = 0.05, - test_region: list[str] | str | None = None, + test_regions: list[str] | None = None, ): """Preprocess Sentinel 2 data for training. @@ -65,7 +138,7 @@ def preprocess_s2_train_data( mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10. 
test_val_split (float, optional): The split ratio for the test and validation set. Defaults to 0.05. - test_region (list[str] | str, optional): The region to use for the test set. Defaults to None. + test_regions (list[str] | None, optional): The regions to use for the test set. Defaults to None. """ # Import here to avoid long loading times when running other commands @@ -81,7 +154,6 @@ def preprocess_s2_train_data( from dask.distributed import Client, LocalCluster from lovely_tensors import monkey_patch from odc.stac import configure_rio - from sklearn.model_selection import train_test_split from darts.utils.cuda import debug_info, decide_device from darts.utils.earthengine import init_ee @@ -114,61 +186,14 @@ def preprocess_s2_train_data( norm_factors = {k: v for k, v in norm_factors.items() if k in bands} train_data_dir.mkdir(exist_ok=True, parents=True) - output_dir_cross_val = train_data_dir / "cross-val" - output_dir_val_test = train_data_dir / "val-test" - output_dir_test = train_data_dir / "test" # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region) n_patches = 0 + joint_lables = [] s2_paths = sorted(sentinel2_dir.glob("*/")) logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}") - - # 1. Split regions - test_paths: list[Path] = [] - training_paths: list[Path] = [] - if test_region: - test_region = [test_region] if isinstance(test_region, str) else test_region - for fpath in s2_paths: - _, s2_tile_id, _ = parse_s2_tile_id(fpath) - labels = gpd.read_file(fpath / f"{s2_tile_id}.shp") - # If any of the regions is in the test region, add to the test set - if labels["region"].isin(test_region).any(): - test_paths.append(fpath) - else: - training_paths.append(fpath) - else: - training_paths = s2_paths - - # 2. 
Split by random sampling - cross_val_paths: list[Path] - val_test_paths: list[Path] - if len(training_paths) > 0: - cross_val_paths, val_test_paths = train_test_split( - training_paths, test_size=test_val_split, random_state=42 - ) - else: - cross_val_paths, val_test_paths = [], [] - logger.warning("No left over training samples found. Skipping train-val split.") - - logger.info( - f"Split the data into {len(cross_val_paths)} cross-val (train + val), " - f"{len(val_test_paths)} val-test (variance) and {len(test_paths)} test (region) samples." - ) - - fpathgen = chain(cross_val_paths, val_test_paths, test_paths) - outpathgen = chain( - repeat(output_dir_cross_val, len(cross_val_paths)), - repeat(output_dir_val_test, len(val_test_paths)), - repeat(output_dir_test, len(test_paths)), - ) - modegen = chain( - repeat("cross-val", len(cross_val_paths)), - repeat("val-test", len(val_test_paths)), - repeat("test", len(test_paths)), - ) - - joint_lables = [] - for i, (fpath, output_dir, mode) in enumerate(zip(fpathgen, outpathgen, modegen)): + path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions) + for i, (fpath, output_dir, mode) in enumerate(path_gen): try: _, s2_tile_id, tile_id = parse_s2_tile_id(fpath)