Add proper nan handling
relativityhd committed Oct 27, 2024
1 parent b3e7ca5 commit 7021b28
Showing 9 changed files with 85 additions and 15 deletions.
darts-postprocessing/src/darts_postprocessing/prepare_export.py (13 changes: 10 additions & 3 deletions)
@@ -16,13 +16,20 @@ def prepare_export(tile: xr.Dataset) -> xr.Dataset:
# Binarize the segmentation
# Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
# Also, where there was no valid input data, turn it into 0
binarized = xr.where(~tile["probabilities"].isnull(), (tile["probabilities"] > 0.5), 0).astype("uint8") # noqa: PD003
binarized = (tile["probabilities"].fillna(0) > 0.5).astype("uint8")
tile["binarized_segmentation"] = xr.where(tile["valid_data_mask"], binarized, 0)
tile["binarized_segmentation"].attrs = {
"long_name": "Binarized Segmentation",
}

# Convert the probabilities to uint8
# Same but this time with 255 as no-data
intprobs = (tile["probabilities"] * 100).astype("uint8")
intprobs = xr.where(~tile["probabilities"].isnull(), intprobs, 255) # noqa: PD003
intprobs = (tile["probabilities"] * 100).fillna(255).astype("uint8")
tile["probabilities"] = xr.where(tile["valid_data_mask"], intprobs, 255)
tile["probabilities"].attrs = {
"long_name": "Probabilities",
"units": "%",
}
tile["probabilities"] = tile["probabilities"].rio.write_nodata(255)

return tile
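As a minimal, self-contained sketch of the no-data handling above (the array values are invented for illustration):

```python
import numpy as np
import xarray as xr

probs = xr.DataArray([[0.2, np.nan], [0.7, 0.9]], dims=("y", "x"))
valid = xr.DataArray([[True, True], [True, False]], dims=("y", "x"))

# NaN -> 0 before thresholding, so pixels without a prediction become 0
binarized = (probs.fillna(0) > 0.5).astype("uint8")
# Pixels without valid input data are also forced to 0
binarized = xr.where(valid, binarized, 0)

# Probabilities as percent, with 255 as the no-data value
intprobs = (probs * 100).fillna(255).astype("uint8")
intprobs = xr.where(valid, intprobs, 255)

print(binarized.values)  # [[0 0] [1 0]]
print(intprobs.values)   # [[ 20 255] [ 70 255]]
```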
@@ -68,13 +68,16 @@ def load_arcticdem(fpath: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
slope = load_vrt(slope_vrt, reference_dataset)
slope: xr.Dataset = (
slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"})
.rio.write_nodata(float("nan"))
.astype("float32")
.to_dataset(name="slope")
)

relative_elevation = load_vrt(elevation_vrt, reference_dataset)
relative_elevation: xr.Dataset = (
relative_elevation.assign_attrs({"data_source": "arcticdem", "long_name": "Relative Elevation", "units": "m"})
.fillna(0)
.rio.write_nodata(0)
.astype("int16")
.to_dataset(name="relative_elevation")
)
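The same fill-then-cast pattern applies wherever a float band with NaNs is converted to an integer dtype; a hedged sketch with invented values:

```python
import numpy as np
import rioxarray  # noqa: F401  # registers the .rio accessor
import xarray as xr

dem = xr.DataArray([[1.5, np.nan], [3.0, 4.0]], dims=("y", "x"))
# NaN cannot be represented in int16, so fill with 0 and record 0 as nodata before casting
dem = dem.fillna(0).rio.write_nodata(0).astype("int16")
print(dem.values)      # [[1 0] [3 4]]
print(dem.rio.nodata)  # 0
```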
@@ -43,6 +43,8 @@ def load_planet_scene(fpath: str | Path) -> xr.Dataset:
datasets = [
planet_da.sel(band=index)
.assign_attrs({"data_source": "planet", "long_name": f"PLANET {name.capitalize()}", "units": "Reflectance"})
.fillna(0)
.rio.write_nodata(0)
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
@@ -43,6 +43,8 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
datasets = [
s2_da.sel(band=index)
.assign_attrs({"data_source": "s2", "long_name": f"Sentinel 2 {name.capitalize()}", "units": "Reflectance"})
.fillna(0)
.rio.write_nodata(0)
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
darts-preprocessing/src/darts_preprocessing/data_sources/tcvis.py (14 changes: 12 additions & 2 deletions)
@@ -6,6 +6,7 @@
from pathlib import Path

import ee
import numpy as np
import pyproj
import rasterio
import xarray as xr
@@ -82,14 +83,23 @@ def load_tcvis(reference_dataset: xr.Dataset, cache_dir: Path | None = None) ->
ds.rio.write_crs(ds.attrs["crs"], inplace=True)
ds.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)
search_time = time.time()
logger.debug(f"Found a dataset with shape {ds.sizes} in {search_time - start_time} seconds")
logger.debug(f"Found a dataset with shape {ds.sizes} in {search_time - start_time} seconds.")

# Save original min-max values for each band for clipping later
clip_values = {band: (ds[band].min().values.item(), ds[band].max().values.item()) for band in ds.data_vars} # noqa: PD011

# Interpolate missing values (there are very few, so we actually can interpolate them)
for band in ds.data_vars:
ds[band] = ds[band].rio.write_nodata(np.nan).rio.interpolate_na()

logger.debug(f"Reproject dataset to match reference dataset {reference_dataset.sizes}")
ds = ds.rio.reproject_match(reference_dataset, resampling=rasterio.enums.Resampling.cubic)
logger.debug(f"Reshaped dataset in {time.time() - search_time} seconds")

# Convert to uint8
for band in ds.data_vars:
ds[band] = ds[band].astype("uint8")
band_min, band_max = clip_values[band]
ds[band] = ds[band].clip(band_min, band_max, keep_attrs=True).astype("uint8").rio.write_nodata(None)

# Save to cache
if cache_dir is not None:
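The saved clip values guard against overshoot from the cubic resampling; a small sketch of the idea (band values and range invented):

```python
import numpy as np
import xarray as xr

# Pretend this band came back from cubic reprojection with slight overshoots
band = xr.DataArray(np.array([[-3.0, 10.0], [120.0, 260.0]]), dims=("y", "x"))
band_min, band_max = 0.0, 255.0  # min/max saved before reprojection

# Clip back into the original range, then cast to uint8 as above
band = band.clip(band_min, band_max, keep_attrs=True).astype("uint8")
print(band.values)  # [[  0  10] [120 255]]
```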
@@ -43,8 +43,12 @@ def calculate_ndvi(planet_scene_dataset: xr.Dataset, nir_band: str = "nir", red_
r = planet_scene_dataset[red_band].astype("float32")
ndvi = (nir - r) / (nir + r)

# scale to Uint16
ndvi = ((ndvi + 1) * 1e4).astype("uint16")
# Scale to 0 - 20000 (for later conversion to uint16)
ndvi = (ndvi.clip(-1, 1) + 1) * 1e4
# Set NaN to 0
ndvi = ndvi.fillna(0).rio.write_nodata(0)
# Convert to uint16
ndvi = ndvi.astype("uint16")

ndvi = ndvi.assign_attrs({"data_source": "planet", "long_name": "NDVI"}).to_dataset(name="ndvi")
logger.debug(f"NDVI calculated in {time.time() - start} seconds.")
darts-segmentation/src/darts_segmentation/segment.py (38 changes: 35 additions & 3 deletions)
@@ -1,15 +1,19 @@
"""Functionality for segmenting tiles."""

import logging
from pathlib import Path
from typing import Any, TypedDict

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import xarray as xr
from lovely_numpy import lovely

from darts_segmentation.utils import predict_in_patches

logger = logging.getLogger(__name__)


class SMPSegmenterConfig(TypedDict):
"""Configuration for the segmentor."""
@@ -80,8 +84,13 @@ def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
for feature_name in self.config["input_combination"]:
norm = self.config["norm_factors"][feature_name]
band_data = tile[feature_name]
band_info_before = lovely(band_data.values)
# Normalize the band data
band_data = band_data * norm
band_info_after = lovely(band_data.values)
logger.debug(
f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
)
bands.append(torch.from_numpy(band_data.values))

return torch.stack(bands, dim=0)
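A hedged sketch of what this band stacking produces, with an invented two-band tile and invented norm factors (not the project's real configuration):

```python
import numpy as np
import torch
import xarray as xr

tile = xr.Dataset(
    {
        "red": (("y", "x"), np.array([[1000, 2000]], dtype="uint16")),
        "ndvi": (("y", "x"), np.array([[15000, 5000]], dtype="uint16")),
    }
)
norm_factors = {"red": 1 / 3000, "ndvi": 1 / 20000}  # hypothetical values

bands = []
for name in ("red", "ndvi"):
    band_data = tile[name] * norm_factors[name]  # scale roughly into [0, 1]
    bands.append(torch.from_numpy(band_data.values))

tensor = torch.stack(bands, dim=0)  # shape: (bands, y, x)
print(tensor.shape)  # torch.Size([2, 1, 2])
```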
@@ -100,8 +109,13 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
norm = self.config["norm_factors"][feature_name]
for tile in tiles:
band_data = tile[feature_name]
band_info_before = lovely(band_data.values)
# Normalize the band data
band_data = band_data * norm
band_info_after = lovely(band_data.values)
logger.debug(
f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
)
bands.append(torch.from_numpy(band_data.values))
# TODO: Test this
return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)
@@ -139,6 +153,7 @@ def segment_tile(
tile["probabilities"].attrs = {
"long_name": "Probabilities",
}
tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
return tile

def segment_tile_batched(
Expand Down Expand Up @@ -182,13 +197,26 @@ def segment_tile_batched(
tile["probabilities"].attrs = {
"long_name": "Probabilities",
}
tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
return tiles

def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr.Dataset]:
def __call__(
self,
input: xr.Dataset | list[xr.Dataset],
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
) -> xr.Dataset | list[xr.Dataset]:
"""Run inference on a single tile or a list of tiles.
Args:
input: A single tile or a list of tiles.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these again will be inferred in batches. Defaults to 8.
reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
Returns:
A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
@@ -199,8 +227,12 @@ def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr
"""
if isinstance(input, xr.Dataset):
return self.segment_tile(input)
return self.segment_tile(
input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)
elif isinstance(input, list):
return self.segment_tile_batched(input)
return self.segment_tile_batched(
input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)
else:
raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")
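Hypothetical usage of the widened `__call__` signature (the checkpoint path, the constructor argument, and the `tile` variables are placeholders and assumptions, not taken from this commit):

```python
from darts_segmentation.segment import SMPSegmenter

segmenter = SMPSegmenter("models/rts_model.pt")  # hypothetical checkpoint path

# Single tile: routed to segment_tile with the patching parameters passed through
segmented = segmenter(tile, patch_size=1024, overlap=16, batch_size=8, reflection=0)

# List of tiles: routed to segment_tile_batched with the same parameters
segmented_batch = segmenter([tile_a, tile_b], patch_size=512, overlap=32)
```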
docs/dev/arch.md (3 changes: 2 additions & 1 deletion)
@@ -95,6 +95,7 @@ The following diagram visualizes the steps of the major `packages` of the pipeli

Each Tile should be represented as a single `xr.Dataset` with each feature / band as `DataVariable`.
Each DataVariable should have its `data_source` documented in the `attrs`, as well as `long_name` and `units` (if any) for plotting.
A `_FillValue` should also be set for no-data with `.rio.write_nodata("no-data-value")`.
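A minimal sketch of that convention (illustrative only, assuming rioxarray's default non-encoded nodata handling):

```python
import numpy as np
import rioxarray  # noqa: F401  # registers the .rio accessor
import xarray as xr

band = xr.DataArray(np.zeros((2, 2), dtype="uint16"), dims=("y", "x"))
band = band.rio.write_nodata(0)      # records 0 as the no-data value
print(band.rio.nodata)               # 0
print(band.attrs.get("_FillValue"))  # 0, exported as _FillValue
```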

### Preprocessing Output

@@ -135,7 +136,7 @@ Coordinates: `x`, `y` and `spatial_ref` (from rioxarray)
| --------------------------- | ------ | ----- | ------- | ---------------- | -------------------- |
| [Output from Preprocessing] | | | | | |
| `probabilities_percent` | (x, y) | uint8 | 255 | long_name, units | Values between 0-100 |
| `binarized_segmentation` | (x, y) | uint8 | - | long_name, units | |
| `binarized_segmentation` | (x, y) | uint8 | - | long_name | |

### PyTorch Model checkpoints

notebooks/test-e2e.ipynb (17 changes: 13 additions & 4 deletions)
@@ -10,6 +10,7 @@
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"from darts_postprocessing.prepare_export import prepare_export\n",
"from darts_preprocessing.preprocess import load_and_preprocess_planet_scene\n",
"from darts_segmentation.segment import SMPSegmenter\n",
@@ -50,7 +51,15 @@
"metadata": {},
"outputs": [],
"source": [
"tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)\n",
"cache_file = DATA_ROOT / \"intermediate\" / f\"planet_{fpath.stem}.nc\"\n",
"force = False\n",
"if cache_file.exists() and not force:\n",
" tile = xr.open_dataset(cache_file, engine=\"h5netcdf\", mask_and_scale=False)\n",
"else:\n",
" tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)\n",
" cache_file.parent.mkdir(exist_ok=True, parents=True)\n",
" tile.to_netcdf(cache_file, engine=\"h5netcdf\")\n",
"\n",
"tile"
]
},
@@ -102,8 +111,8 @@
"metadata": {},
"outputs": [],
"source": [
"final = prepare_export(tile)\n",
"final"
"tile = prepare_export(tile)\n",
"tile"
]
},
{
@@ -112,7 +121,7 @@
"metadata": {},
"outputs": [],
"source": [
"final_low_res = final.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"final_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(final_low_res.data_vars):\n",
