Skip to content

Commit

Permalink
Update attrs and utilize datamask
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 27, 2024
1 parent 245a0ab commit b3e7ca5
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ def prepare_export(tile: xr.Dataset) -> xr.Dataset:
"""
# Binarize the segmentation
tile["binarized_segmentation"] = (tile["probabilities"] > 0.5).astype("uint8")
# Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
# Also, where there was no valid input data, turn it into 0
binarized = xr.where(~tile["probabilities"].isnull(), (tile["probabilities"] > 0.5), 0).astype("uint8") # noqa: PD003
tile["binarized_segmentation"] = xr.where(tile["valid_data_mask"], binarized, 0)

# Convert the probabilities to uint8
# Same but this time with 255 as no-data
intprobs = (tile["probabilities"] * 100).astype("uint8")
tile["probabilities"] = xr.where(~tile["probabilities"].isnull(), intprobs, 255) # noqa: PD003
intprobs = xr.where(~tile["probabilities"].isnull(), intprobs, 255) # noqa: PD003
tile["probabilities"] = xr.where(tile["valid_data_mask"], intprobs, 255)

return tile
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,18 @@ def load_arcticdem(fpath: Path, reference_dataset: xr.Dataset) -> xr.Dataset:
elevation_vrt = fpath / "elevation.vrt"

slope = load_vrt(slope_vrt, reference_dataset)
slope: xr.Dataset = slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"}).to_dataset(name="slope")
slope: xr.Dataset = (
slope.assign_attrs({"data_source": "arcticdem", "long_name": "Slope"})
.astype("float32")
.to_dataset(name="slope")
)

relative_elevation = load_vrt(elevation_vrt, reference_dataset)
relative_elevation: xr.Dataset = relative_elevation.assign_attrs(
{"data_source": "arcticdem", "long_name": "Relative Elevation"}
).to_dataset(name="relative_elevation")
relative_elevation: xr.Dataset = (
relative_elevation.assign_attrs({"data_source": "arcticdem", "long_name": "Relative Elevation", "units": "m"})
.astype("int16")
.to_dataset(name="relative_elevation")
)

articdem_ds = xr.merge([relative_elevation, slope])
logger.debug(f"Loaded ArcticDEM data in {time.time() - start_time} seconds.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def load_planet_scene(fpath: str | Path) -> xr.Dataset:
# Create a list to hold datasets
datasets = [
planet_da.sel(band=index)
.assign_attrs({"data_source": "planet", "long_name": f"PLANET {name.capitalize()}"})
.assign_attrs({"data_source": "planet", "long_name": f"PLANET {name.capitalize()}", "units": "Reflectance"})
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
for index, name in bands.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def load_s2_scene(fpath: str | Path) -> xr.Dataset:
# Create a list to hold datasets
datasets = [
s2_da.sel(band=index)
.assign_attrs({"data_source": "s2", "long_name": f"Sentinel 2 {name.capitalize()}"})
.assign_attrs({"data_source": "s2", "long_name": f"Sentinel 2 {name.capitalize()}", "units": "Reflectance"})
.astype("uint16")
.to_dataset(name=name)
.drop_vars("band")
for index, name in bands.items()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path

import ee
import numpy as np
import pyproj
import rasterio
import xarray as xr
Expand Down Expand Up @@ -76,8 +75,8 @@ def load_tcvis(reference_dataset: xr.Dataset, cache_dir: Path | None = None) ->
)
for band in ds.data_vars:
ds[band].attrs = {
"data_source": "landsat-trends",
"long_name": f"TC {band.split('_')[1].capitalize()}",
"data_source": "ee:ingmarnitze/TCTrend_SR_2000-2019_TCVIS",
"long_name": f"Tasseled Cap {band.split('_')[1].capitalize()}",
}

ds.rio.write_crs(ds.attrs["crs"], inplace=True)
Expand All @@ -90,7 +89,7 @@ def load_tcvis(reference_dataset: xr.Dataset, cache_dir: Path | None = None) ->
logger.debug(f"Reshaped dataset in {time.time() - search_time} seconds")

for band in ds.data_vars:
ds[band] = ds[band].astype(np.uint8)
ds[band] = ds[band].astype("uint8")

# Save to cache
if cache_dir is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
logger = logging.getLogger(__name__)


def calculate_ndvi(
planet_scene_dataset: xr.Dataset, nir_band: str = "nir", red_band: str = "red", scale_to_uint16: bool = True
) -> xr.Dataset:
def calculate_ndvi(planet_scene_dataset: xr.Dataset, nir_band: str = "nir", red_band: str = "red") -> xr.Dataset:
"""Calculate NDVI from an xarray Dataset containing spectral bands.
Example:
Expand All @@ -25,9 +23,6 @@ def calculate_ndvi(
correspond to the variable name for the NIR band in the 'band' dimension. Defaults to "nir".
red_band (str, optional): The name of the Red band in the Dataset (default is "red"). This name should
correspond to the variable name for the Red band in the 'band' dimension. Defaults to "red".
scale_to_uint16 (bool, optional): If True, scales the NDVI values to a range of 0 to 65535 (Uint16).
This is useful for storing NDVI values in a more compact format while preserving detail.
Defaults to False.
Returns:
xr.Dataset: A new Dataset containing the calculated NDVI values. The resulting Dataset will have
Expand All @@ -48,9 +43,8 @@ def calculate_ndvi(
r = planet_scene_dataset[red_band].astype("float32")
ndvi = (nir - r) / (nir + r)

# scale to Uint16 if required
if scale_to_uint16:
ndvi = ((ndvi + 1) * 1e4).astype("uint16")
# scale to Uint16
ndvi = ((ndvi + 1) * 1e4).astype("uint16")

ndvi = ndvi.assign_attrs({"data_source": "planet", "long_name": "NDVI"}).to_dataset(name="ndvi")
logger.debug(f"NDVI calculated in {time.time() - start} seconds.")
Expand Down
8 changes: 6 additions & 2 deletions darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,9 @@ def segment_tile(
# Highly sophisticated DL-based predictor
# TODO: is there a better way to pass metadata?
tile["probabilities"] = tile["red"].copy(data=probabilities.cpu().numpy())
tile["probabilities"].attrs = {}
tile["probabilities"].attrs = {
"long_name": "Probabilities",
}
return tile

def segment_tile_batched(
Expand Down Expand Up @@ -177,7 +179,9 @@ def segment_tile_batched(
for tile, probs in zip(tiles, probabilities):
# TODO: is there a better way to pass metadata?
tile["probabilities"] = tile["red"].copy(data=probs.cpu().numpy())
tile["probabilities"].attrs = {}
tile["probabilities"].attrs = {
"long_name": "Probabilities",
}
return tiles

def __call__(self, input: xr.Dataset | list[xr.Dataset]) -> xr.Dataset | list[xr.Dataset]:
Expand Down
50 changes: 25 additions & 25 deletions docs/dev/arch.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,48 +94,48 @@ The following diagram visualizes the steps of the major `packages` of the pipeli
![DARTS nextgen pipeline steps](../assets/darts_nextgen_pipeline-steps.png)

Each Tile should be represented as a single `xr.Dataset` with each feature / band as `DataVariable`.
Each DataVariable should have their `source` documented in the `attrs`.
Each DataVariable should have their `data_source` documented in the `attrs`, aswell as `long_name` and `units` if any for plotting.

### Preprocessing Output

Coordinates: `x`, `y` and `spatial_ref` (from rioxarray)

| DataVariable | shape | dtype | attrs |
| -------------------- | ------ | ------- | -------- |
| `blue` | (x, y) | uint16 | - source |
| `green` | (x, y) | uint16 | - source |
| `red` | (x, y) | uint16 | - source |
| `nir` | (x, y) | uint16 | - source |
| `ndvi` | (x, y) | float32 | - source |
| `relative_elevation` | (x, y) | float32 | - source |
| `slope` | (x, y) | float32 | - source |
| `tc_brightness` | (x, y) | uint8 | - source |
| `tc_greenness` | (x, y) | uint8 | - source |
| `tc_wetness` | (x, y) | uint8 | - source |
| `valid_data_mask` | (x, y) | bool | - source |
| `quality_data_mask` | (x, y) | bool | - source |
| DataVariable | shape | dtype | no-data | attrs | note |
| -------------------- | ------ | ------- | ------- | ----------------------------- | ---------------------------------- |
| `blue` | (x, y) | uint16 | 0 | data_source, long_name, units | |
| `green` | (x, y) | uint16 | 0 | data_source, long_name, units | |
| `red` | (x, y) | uint16 | 0 | data_source, long_name, units | |
| `nir` | (x, y) | uint16 | 0 | data_source, long_name, units | |
| `ndvi` | (x, y) | uint16 | 0 | data_source, long_name | Values between 0-20.000 (+1, *1e4) |
| `relative_elevation` | (x, y) | int16 | 0 | data_source, long_name, units | |
| `slope` | (x, y) | float32 | nan | data_source, long_name | |
| `tc_brightness` | (x, y) | uint8 | - | data_source, long_name | |
| `tc_greenness` | (x, y) | uint8 | - | data_source, long_name | |
| `tc_wetness` | (x, y) | uint8 | - | data_source, long_name | |
| `valid_data_mask` | (x, y) | bool | - | data_source, long_name | |
| `quality_data_mask` | (x, y) | bool | - | data_source, long_name | |

### Segmentation / Ensemble Output

Coordinates: `x`, `y` and `spatial_ref` (from rioxarray)

| DataVariable | shape | dtype | attrs |
| --------------------------- | ------ | ------- | ----- |
| [Output from Preprocessing] | | | |
| `probabilities` | (x, y) | float32 | |
| `probabilities-model-X*` | (x, y) | bool | |
| DataVariable | shape | dtype | no-data | attrs |
| --------------------------- | ------ | ------- | ------- | --------- |
| [Output from Preprocessing] | | | | |
| `probabilities` | (x, y) | float32 | nan | long_name |
| `probabilities-model-X*` | (x, y) | float32 | nan | long_name |

\*: optional intermedia probabilities in an ensemble

### Postprocessing Output

Coordinates: `x`, `y` and `spatial_ref` (from rioxarray)

| DataVariable | shape | dtype | attrs | note |
| --------------------------- | ------ | ----- | ----- | -------------------------------- |
| [Output from Preprocessing] | | | | |
| `probabilities_percent` | (x, y) | uint8 | | Values between 0-100, nodata:255 |
| `binarized_segmentation` | (x, y) | uint8 | | |
| DataVariable | shape | dtype | no-data | attrs | note |
| --------------------------- | ------ | ----- | ------- | ---------------- | -------------------- |
| [Output from Preprocessing] | | | | | |
| `probabilities_percent` | (x, y) | uint8 | 255 | long_name, units | Values between 0-100 |
| `binarized_segmentation` | (x, y) | uint8 | - | long_name, units | |

### PyTorch Model checkpoints

Expand Down
78 changes: 53 additions & 25 deletions notebooks/test-e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@
"from pathlib import Path\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"from darts_postprocessing.prepare_export import prepare_export\n",
"from darts_preprocessing.preprocess_tobi import load_and_preprocess_planet_scene\n",
"from darts_preprocessing.preprocess import load_and_preprocess_planet_scene\n",
"from darts_segmentation.segment import SMPSegmenter\n",
"from lovely_tensors import monkey_patch\n",
"from rich import traceback\n",
"from rich.logging import RichHandler\n",
"\n",
"xr.set_options(display_expand_data=False)\n",
"from darts.utils.earthengine import init_ee\n",
"from darts.utils.logging import setup_logging\n",
"\n",
"# Set up logging\n",
"logging.basicConfig(level=logging.INFO, handlers=[RichHandler()])\n",
"logging.getLogger(\"darts_preprocessing\").setLevel(logging.DEBUG)\n",
"logging.getLogger(\"darts_segmentation\").setLevel(logging.DEBUG)\n",
"\n",
"monkey_patch()\n",
"traceback.install(show_locals=True)"
"setup_logging()\n",
"logging.basicConfig(\n",
" level=logging.INFO,\n",
" format=\"%(message)s\",\n",
" datefmt=\"[%X]\",\n",
" handlers=[RichHandler(rich_tracebacks=True)],\n",
")\n",
"traceback.install(show_locals=False)\n",
"init_ee(\"ee-tobias-hoelzer\")"
]
},
{
Expand All @@ -35,15 +36,12 @@
"metadata": {},
"outputs": [],
"source": [
"DATA_ROOT = Path(\"../data/input\")\n",
"\n",
"# fpath = DATA_ROOT / \"planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459\"\n",
"fpath = DATA_ROOT / \"planet/PSOrthoTile/4974017/5854937_4974017_2022-08-14_2475\"\n",
"scene_id = fpath.parent.name\n",
"DATA_ROOT = Path(\"../data\")\n",
"\n",
"# TODO: change to vrt\n",
"elevation_path = DATA_ROOT / \"ArcticDEM\" / \"relative_elevation\" / f\"{scene_id}_relative_elevation_100.tif\"\n",
"slope_path = DATA_ROOT / \"ArcticDEM\" / \"slope\" / f\"{scene_id}_slope.tif\"\n"
"# fpath = DATA_ROOT / \"input/planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459\"\n",
"fpath = DATA_ROOT / \"input/planet/PSOrthoTile/4974017/5854937_4974017_2022-08-14_2475\"\n",
"arcticdem_dir = DATA_ROOT / \"input/ArcticDEM\"\n",
"cache_dir = DATA_ROOT / \"download\""
]
},
{
Expand All @@ -52,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"tile = load_and_preprocess_planet_scene(fpath, elevation_path, slope_path)\n",
"tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir, cache_dir)\n",
"tile"
]
},
Expand All @@ -63,7 +61,7 @@
"outputs": [],
"source": [
"tile_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 5, figsize=(30, 10))\n",
"fig, axs = plt.subplots(2, 6, figsize=(30, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(tile_low_res.data_vars):\n",
" tile_low_res[v].plot(ax=axs[i], cmap=\"gray\")\n",
Expand All @@ -76,9 +74,36 @@
"metadata": {},
"outputs": [],
"source": [
"model = SMPSegmenter(\"../models/RTS_v6_notcvis.pt\")\n",
"model = SMPSegmenter(\"../models/RTS_v6_tcvis.pt\")\n",
"tile = model.segment_tile(tile, batch_size=4)\n",
"final = prepare_export(tile)"
"tile"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(final_low_res.data_vars):\n",
" if v == \"probabilities\":\n",
" final_low_res[v].plot(ax=axs[i], cmap=\"gray\", vmin=0, vmax=1)\n",
" else:\n",
" final_low_res[v].plot(ax=axs[i], cmap=\"gray\")\n",
" axs[i].set_title(v)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"final = prepare_export(tile)\n",
"final"
]
},
{
Expand All @@ -88,10 +113,13 @@
"outputs": [],
"source": [
"final_low_res = final.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 6, figsize=(36, 10))\n",
"fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(final_low_res.data_vars):\n",
" final_low_res[v].plot(ax=axs[i], cmap=\"gray\")\n",
" if v == \"probabilities\":\n",
" final_low_res[v].plot(ax=axs[i], cmap=\"gray\", vmin=0, vmax=100)\n",
" else:\n",
" final_low_res[v].plot(ax=axs[i], cmap=\"gray\")\n",
" axs[i].set_title(v)"
]
}
Expand Down

0 comments on commit b3e7ca5

Please sign in to comment.