diff --git a/darts/src/darts/training.py b/darts/src/darts/training.py index 05bff52..3ccc597 100644 --- a/darts/src/darts/training.py +++ b/darts/src/darts/training.py @@ -29,11 +29,12 @@ def preprocess_s2_train_data( *, + bands: list[str], sentinel2_dir: Path, train_data_dir: Path, arcticdem_dir: Path, tcvis_dir: Path, - bands: list[str], + preprocess_cache: Path | None = None, device: Literal["cuda", "cpu", "auto"] | int | None = None, ee_project: str | None = None, ee_use_highvolume: bool = True, @@ -48,12 +49,13 @@ def preprocess_s2_train_data( """Preprocess Sentinel 2 data for training. Args: + bands (list[str]): The bands to be used for training. Must be present in the preprocessing. sentinel2_dir (Path): The directory containing the Sentinel 2 scenes. train_data_dir (Path): The "output" directory where the tensors are written to. arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. tcvis_dir (Path): The directory containing the TCVis data. - bands (list[str]): The bands to be used for training. Must be present in the preprocessing. + preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None. device (Literal["cuda", "cpu"] | int, optional): The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). @@ -76,6 +78,7 @@ def preprocess_s2_train_data( # Import here to avoid long loading times when running other commands import geopandas as gpd import torch + import xarray as xr from darts_acquisition.arcticdem import load_arcticdem_tile from darts_acquisition.s2 import load_s2_masks, load_s2_scene from darts_acquisition.tcvis import load_tcvis @@ -129,26 +132,39 @@ def preprocess_s2_train_data( for i, fpath in enumerate(s2_paths): try: optical = load_s2_scene(fpath) - arcticdem = load_arcticdem_tile( - optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2)) - ) - tcvis = load_tcvis(optical.odc.geobox, tcvis_dir) - data_masks = load_s2_masks(fpath, optical.odc.geobox) - - tile = preprocess_legacy_fast( - optical, - arcticdem, - tcvis, - data_masks, - tpi_outer_radius, - tpi_inner_radius, - device, - ) + tile_id = optical.attrs["tile_id"] + + # Check for a cached preprocessed file + if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists(): + cache_file = preprocess_cache / f"{tile_id}.nc" + logger.info(f"Loading preprocessed data from {cache_file.resolve()}") + tile = xr.open_dataset(preprocess_cache / f"{tile_id}.nc", engine="h5netcdf").set_coords("spatial_ref") + else: + arcticdem = load_arcticdem_tile( + optical.odc.geobox, arcticdem_dir, resolution=10, buffer=ceil(tpi_outer_radius / 10 * sqrt(2)) + ) + tcvis = load_tcvis(optical.odc.geobox, tcvis_dir) + data_masks = load_s2_masks(fpath, optical.odc.geobox) + + tile: xr.Dataset = preprocess_legacy_fast( + optical, + arcticdem, + tcvis, + data_masks, + tpi_outer_radius, + tpi_inner_radius, + device, + ) + # Only cache if we have a cache directory + if preprocess_cache: + preprocess_cache.mkdir(exist_ok=True, parents=True) + cache_file = preprocess_cache / f"{tile_id}.nc" + logger.info(f"Caching preprocessed data to {cache_file.resolve()}") + tile.to_netcdf(cache_file, engine="h5netcdf") labels = gpd.read_file(fpath / f"{optical.attrs['s2_tile_id']}.shp") # Save the patches - tile_id = optical.attrs["tile_id"] gen = create_training_patches( tile, labels,