diff --git a/darts-postprocessing/src/darts_postprocessing/prepare_export.py b/darts-postprocessing/src/darts_postprocessing/prepare_export.py
index 6614543..fc4fd74 100644
--- a/darts-postprocessing/src/darts_postprocessing/prepare_export.py
+++ b/darts-postprocessing/src/darts_postprocessing/prepare_export.py
@@ -16,13 +16,20 @@ def prepare_export(tile: xr.Dataset) -> xr.Dataset:
     # Binarize the segmentation
     # Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
     # Also, where there was no valid input data, turn it into 0
-    binarized = xr.where(~tile["probabilities"].isnull(), (tile["probabilities"] > 0.5), 0).astype("uint8")  # noqa: PD003
+    binarized = (tile["probabilities"].fillna(0) > 0.5).astype("uint8")
     tile["binarized_segmentation"] = xr.where(tile["valid_data_mask"], binarized, 0)
+    tile["binarized_segmentation"].attrs = {
+        "long_name": "Binarized Segmentation",
+    }
 
     # Convert the probabilities to uint8
     # Same but this time with 255 as no-data
-    intprobs = (tile["probabilities"] * 100).astype("uint8")
-    intprobs = xr.where(~tile["probabilities"].isnull(), intprobs, 255)  # noqa: PD003
+    intprobs = (tile["probabilities"] * 100).fillna(255).astype("uint8")
     tile["probabilities"] = xr.where(tile["valid_data_mask"], intprobs, 255)
+    tile["probabilities"].attrs = {
+        "long_name": "Probabilities",
+        "units": "%",
+    }
+    tile["probabilities"] = tile["probabilities"].rio.write_nodata(255)
 
     return tile
diff --git a/darts-preprocessing/src/darts_preprocessing/data_sources/arcticdem.py b/darts-preprocessing/src/darts_preprocessing/data_sources/arcticdem.py
index 7b0d8cc..51a171c 100644
--- a/darts-preprocessing/src/darts_preprocessing/data_sources/arcticdem.py
+++ b/darts-preprocessing/src/darts_preprocessing/data_sources/arcticdem.py
@@ -68,6 +68,7 @@ def load_arcticdem(fpath: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
     slope = load_vrt(slope_vrt, reference_dataset)
     slope: xr.Dataset = (
         slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"})
+        .rio.write_nodata(float("nan"))
        .astype("float32")
        .to_dataset(name="slope")
     )
@@ -75,6 +76,8 @@ def load_arcticdem(fpath: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
     relative_elevation = load_vrt(elevation_vrt, reference_dataset)
     relative_elevation: xr.Dataset = (
         relative_elevation.assign_attrs({"data_source": "arcticdem", "long_name": "Relative Elevation", "units": "m"})
+        .fillna(0)
+        .rio.write_nodata(0)
         .astype("int16")
         .to_dataset(name="relative_elevation")
     )
diff --git a/darts-preprocessing/src/darts_preprocessing/data_sources/planet.py b/darts-preprocessing/src/darts_preprocessing/data_sources/planet.py
index b32cb50..cf436e2 100644
--- a/darts-preprocessing/src/darts_preprocessing/data_sources/planet.py
+++ b/darts-preprocessing/src/darts_preprocessing/data_sources/planet.py
@@ -43,6 +43,8 @@ def load_planet_scene(fpath: str | Path) -> xr.Dataset:
     datasets = [
         planet_da.sel(band=index)
         .assign_attrs({"data_source": "planet", "long_name": f"PLANET {name.capitalize()}", "units": "Reflectance"})
+        .fillna(0)
+        .rio.write_nodata(0)
         .astype("uint16")
         .to_dataset(name=name)
         .drop_vars("band")
diff --git a/darts-preprocessing/src/darts_preprocessing/data_sources/s2.py b/darts-preprocessing/src/darts_preprocessing/data_sources/s2.py
index 8ac2bd6..0407b05 100644
--- a/darts-preprocessing/src/darts_preprocessing/data_sources/s2.py
+++ b/darts-preprocessing/src/darts_preprocessing/data_sources/s2.py
@@ -43,6 +43,8 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
     datasets = [
         s2_da.sel(band=index)
         .assign_attrs({"data_source": "s2", "long_name": f"Sentinel 2 {name.capitalize()}", "units": "Reflectance"})
+        .fillna(0)
+        .rio.write_nodata(0)
         .astype("uint16")
         .to_dataset(name=name)
         .drop_vars("band")
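The loaders and the export step above converge on the same pattern: fill NaN with the designated no-data value, record that value via `rio.write_nodata`, and only then cast to the integer dtype. A minimal sketch of the idea on toy data (not the actual loader code; assumes only `xarray` and `rioxarray`):

```python
import numpy as np
import rioxarray  # noqa: F401 -- registers the .rio accessor
import xarray as xr

da = xr.DataArray(
    np.array([[0.5, np.nan], [0.25, 1.0]], dtype="float32"),
    dims=("y", "x"),
)

# Casting NaN straight to an integer dtype is undefined, so the no-data value
# is baked in first: fill NaN with 0, record 0 as _FillValue, then cast.
clean = da.fillna(0).rio.write_nodata(0).astype("uint16")
print(clean.rio.nodata)  # prints the recorded no-data value: 0
```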
diff --git a/darts-preprocessing/src/darts_preprocessing/data_sources/tcvis.py b/darts-preprocessing/src/darts_preprocessing/data_sources/tcvis.py
index f90e846..fc36227 100644
--- a/darts-preprocessing/src/darts_preprocessing/data_sources/tcvis.py
+++ b/darts-preprocessing/src/darts_preprocessing/data_sources/tcvis.py
@@ -6,6 +6,7 @@ from pathlib import Path
 
 import ee
+import numpy as np
 import pyproj
 import rasterio
 import xarray as xr
@@ -82,14 +83,23 @@ def load_tcvis(reference_dataset: xr.Dataset, cache_dir: Path | None = None) ->
         ds.rio.write_crs(ds.attrs["crs"], inplace=True)
         ds.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)
         search_time = time.time()
-        logger.debug(f"Found a dataset with shape {ds.sizes} in {search_time - start_time} seconds")
+        logger.debug(f"Found a dataset with shape {ds.sizes} in {search_time - start_time} seconds.")
+
+    # Save original min-max values for each band for clipping later
+    clip_values = {band: (ds[band].min().values.item(), ds[band].max().values.item()) for band in ds.data_vars}  # noqa: PD011
+
+    # Interpolate missing values (there are very few, so we actually can interpolate them)
+    for band in ds.data_vars:
+        ds[band] = ds[band].rio.write_nodata(np.nan).rio.interpolate_na()
 
     logger.debug(f"Reproject dataset to match reference dataset {reference_dataset.sizes}")
     ds = ds.rio.reproject_match(reference_dataset, resampling=rasterio.enums.Resampling.cubic)
     logger.debug(f"Reshaped dataset in {time.time() - search_time} seconds")
 
+    # Convert to uint8
     for band in ds.data_vars:
-        ds[band] = ds[band].astype("uint8")
+        band_min, band_max = clip_values[band]
+        ds[band] = ds[band].clip(band_min, band_max, keep_attrs=True).astype("uint8").rio.write_nodata(None)
 
     # Save to cache
     if cache_dir is not None:
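The clip-before-cast in the TCVIS loader guards against an easy-to-miss artefact: cubic resampling can overshoot the original value range, and NumPy's float-to-uint8 cast is undefined for out-of-range values (it typically wraps around rather than saturating). A toy illustration with made-up numbers:

```python
import numpy as np

# Made-up band values after cubic-resampling overshoot.
resampled = np.array([-3.2, 17.0, 260.7])

# Clipping to the saved band range first makes the uint8 cast well defined.
print(resampled.clip(0.0, 255.0).astype("uint8"))  # [  0  17 255]
```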
diff --git a/darts-preprocessing/src/darts_preprocessing/engineering/indices.py b/darts-preprocessing/src/darts_preprocessing/engineering/indices.py
index 0e137b1..dae1001 100644
--- a/darts-preprocessing/src/darts_preprocessing/engineering/indices.py
+++ b/darts-preprocessing/src/darts_preprocessing/engineering/indices.py
@@ -43,8 +43,12 @@ def calculate_ndvi(planet_scene_dataset: xr.Dataset, nir_band: str = "nir", red_
     r = planet_scene_dataset[red_band].astype("float32")
     ndvi = (nir - r) / (nir + r)
 
-    # scale to Uint16
-    ndvi = ((ndvi + 1) * 1e4).astype("uint16")
+    # Scale to 0 - 20000 (for later conversion to uint16)
+    ndvi = (ndvi.clip(-1, 1) + 1) * 1e4
+    # Map NaN to 0
+    ndvi = ndvi.fillna(0).rio.write_nodata(0)
+    # Convert to uint16
+    ndvi = ndvi.astype("uint16")
     ndvi = ndvi.assign_attrs({"data_source": "planet", "long_name": "NDVI"}).to_dataset(name="ndvi")
 
     logger.debug(f"NDVI calculated in {time.time() - start} seconds.")
diff --git a/darts-segmentation/src/darts_segmentation/segment.py b/darts-segmentation/src/darts_segmentation/segment.py
index a7c62f2..40b9ef3 100644
--- a/darts-segmentation/src/darts_segmentation/segment.py
+++ b/darts-segmentation/src/darts_segmentation/segment.py
@@ -1,5 +1,6 @@
 """Functionality for segmenting tiles."""
 
+import logging
 from pathlib import Path
 from typing import Any, TypedDict
 
@@ -7,9 +8,12 @@
 import torch
 import torch.nn as nn
 import xarray as xr
+from lovely_numpy import lovely
 
 from darts_segmentation.utils import predict_in_patches
 
+logger = logging.getLogger(__name__)
+
 
 class SMPSegmenterConfig(TypedDict):
     """Configuration for the segmentor."""
@@ -80,8 +84,13 @@ def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
         for feature_name in self.config["input_combination"]:
             norm = self.config["norm_factors"][feature_name]
             band_data = tile[feature_name]
+            band_info_before = lovely(band_data.values)
             # Normalize the band data
             band_data = band_data * norm
+            band_info_after = lovely(band_data.values)
+            logger.debug(
+                f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
+            )
             bands.append(torch.from_numpy(band_data.values))
 
         return torch.stack(bands, dim=0)
@@ -100,8 +109,13 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
             norm = self.config["norm_factors"][feature_name]
             for tile in tiles:
                 band_data = tile[feature_name]
+                band_info_before = lovely(band_data.values)
                 # Normalize the band data
                 band_data = band_data * norm
+                band_info_after = lovely(band_data.values)
+                logger.debug(
+                    f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
+                )
                 bands.append(torch.from_numpy(band_data.values))
         # TODO: Test this
         return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)
@@ -139,6 +153,7 @@ def segment_tile(
         tile["probabilities"].attrs = {
             "long_name": "Probabilities",
         }
+        tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
         return tile
 
     def segment_tile_batched(
@@ -182,13 +197,26 @@ def segment_tile_batched(
             tile["probabilities"].attrs = {
                 "long_name": "Probabilities",
             }
+            tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
 
         return tiles
 
-    def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr.Dataset]:
+    def __call__(
+        self,
+        input: xr.Dataset | list[xr.Dataset],
+        patch_size: int = 1024,
+        overlap: int = 16,
+        batch_size: int = 8,
+        reflection: int = 0,
+    ) -> xr.Dataset | list[xr.Dataset]:
         """Run inference on a single tile or a list of tiles.
 
         Args:
             input: A single tile or a list of tiles.
+            patch_size (int): The size of the patches. Defaults to 1024.
+            overlap (int): The size of the overlap. Defaults to 16.
+            batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
+                Tensor will be sliced into patches and these will be inferred in batches. Defaults to 8.
+            reflection (int): Reflection padding which will be applied to the edges of the tensor. Defaults to 0.
 
         Returns:
             A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
@@ -199,8 +227,12 @@ def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr
         """
         if isinstance(input, xr.Dataset):
-            return self.segment_tile(input)
+            return self.segment_tile(
+                input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
+            )
         elif isinstance(input, list):
-            return self.segment_tile_batched(input)
+            return self.segment_tile_batched(
+                input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
+            )
         else:
             raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")
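With the new keyword arguments, the patching behaviour is tunable from the call site instead of being hard-coded. A hypothetical usage sketch — the checkpoint path and constructor call are placeholders not shown in this diff; only the keyword names come from the new `__call__` signature:

```python
from darts_segmentation.segment import SMPSegmenter

segmenter = SMPSegmenter("models/unet.ckpt")  # hypothetical checkpoint path
# `tile` is assumed to be a preprocessed xr.Dataset.
tile = segmenter(tile, patch_size=1024, overlap=16, batch_size=8, reflection=0)
```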
diff --git a/docs/dev/arch.md b/docs/dev/arch.md
index cce50ca..93c42b9 100644
--- a/docs/dev/arch.md
+++ b/docs/dev/arch.md
@@ -95,6 +95,7 @@ The following diagram visualizes the steps of the major `packages` of the pipeli
 
 Each Tile should be represented as a single `xr.Dataset` with each feature / band as `DataVariable`.
 Each DataVariable should have their `data_source` documented in the `attrs`, as well as `long_name` and `units` if any for plotting.
+A `_FillValue` should also be set for no-data with `.rio.write_nodata("no-data-value")`.
 
 ### Preprocessing Output
@@ -135,7 +136,7 @@ Coordinates: `x`, `y` and `spatial_ref` (from rioxarray)
 
 | --------------------------- | ------ | ----- | ------- | ---------------- | -------------------- |
 | [Output from Preprocessing] |        |       |         |                  |                      |
 | `probabilities_percent`     | (x, y) | uint8 | 255     | long_name, units | Values between 0-100 |
-| `binarized_segmentation`    | (x, y) | uint8 | -       | long_name, units |                      |
+| `binarized_segmentation`    | (x, y) | uint8 | -       | long_name        |                      |
 
 ### PyTorch Model checkpoints
diff --git a/notebooks/test-e2e.ipynb b/notebooks/test-e2e.ipynb
index 39c766c..a7040ad 100644
--- a/notebooks/test-e2e.ipynb
+++ b/notebooks/test-e2e.ipynb
@@ -10,6 +10,7 @@
     "from pathlib import Path\n",
     "\n",
     "import matplotlib.pyplot as plt\n",
+    "import xarray as xr\n",
     "from darts_postprocessing.prepare_export import prepare_export\n",
     "from darts_preprocessing.preprocess import load_and_preprocess_planet_scene\n",
     "from darts_segmentation.segment import SMPSegmenter\n",
@@ -50,7 +51,15 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)\n",
+    "cache_file = DATA_ROOT / \"intermediate\" / f\"planet_{fpath.stem}.nc\"\n",
+    "force = False\n",
+    "if cache_file.exists() and not force:\n",
+    "    tile = xr.open_dataset(cache_file, engine=\"h5netcdf\", mask_and_scale=False)\n",
+    "else:\n",
+    "    tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)\n",
+    "    cache_file.parent.mkdir(exist_ok=True, parents=True)\n",
+    "    tile.to_netcdf(cache_file, engine=\"h5netcdf\")\n",
+    "\n",
     "tile"
    ]
   },
@@ -102,8 +111,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "final = prepare_export(tile)\n",
-    "final"
+    "tile = prepare_export(tile)\n",
+    "tile"
    ]
   },
   {
@@ -112,7 +121,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "final_low_res = final.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
+    "final_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
     "fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
     "axs = axs.flatten()\n",
     "for i, v in enumerate(final_low_res.data_vars):\n",
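A closing note on the notebook's cache round trip: `mask_and_scale=False` matters here. With xarray's default (`True`), pixels equal to `_FillValue` are decoded to NaN and the uint8/uint16 bands are upcast to float, which would undo the integer no-data encoding written by `rio.write_nodata`. Sketch with a hypothetical cache path:

```python
import xarray as xr

# Hypothetical path; mirrors the caching cell above.
tile = xr.open_dataset("intermediate/planet_tile.nc", engine="h5netcdf", mask_and_scale=False)
```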