Skip to content

Commit

Permalink
Make ensemble work
Browse files Browse the repository at this point in the history
  • Loading branch information
relativityhd committed Oct 27, 2024
1 parent a0ffac6 commit ea7934e
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 29 deletions.
127 changes: 115 additions & 12 deletions darts-ensemble/src/darts_ensemble/ensemble_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,38 +13,141 @@ def __init__(
self,
rts_v6_tcvis_model_path: str | Path,
rts_v6_notcvis_model_path: str | Path,
binarize_threshold: float = 0.5,
):
"""Initialize the ensemble.
Args:
rts_v6_tcvis_model_path (str | Path): Path to the model trained with TCVIS data.
rts_v6_notcvis_model_path (str | Path): Path to the model trained without TCVIS data.
binarize_threshold (float, optional): Threshold to binarize the ensemble output. Defaults to 0.5.
"""
self.rts_v6_tcvis_model = SMPSegmenter(rts_v6_tcvis_model_path)
self.rts_v6_notcvis_model = SMPSegmenter(rts_v6_notcvis_model_path)
self.threshold = binarize_threshold

def __call__(self, tile: xr.Dataset, keep_inputs: bool = False) -> xr.Dataset:
"""Run the ensemble on the given tile.
def segment_tile(
self,
tile: xr.Dataset,
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
keep_inputs: bool = False,
) -> xr.Dataset:
"""Run inference on a tile.
Args:
tile (xr.Dataset): Input tile from preprocessing.
tile: The input tile, containing preprocessed, harmonized data.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.
Returns:
xr.Dataset: Output tile with the ensemble applied.
Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].
"""
tcvis_tile = self.rts_v6_tcvis_model.segment_tile(tile)
notcvis_tile = self.rts_v6_notcvis_model.segment_tile(tile)
tcvis_probabilities = self.rts_v6_tcvis_model.segment_tile(
tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)["probabilities"].copy()
notcvis_propabilities = self.rts_v6_notcvis_model.segment_tile(
tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)["probabilities"].copy()

tile["probabilities"] = (tcvis_tile["probabilities"] + notcvis_tile["probabilities"]) / 2
tile["probabilities"] = (tcvis_probabilities + notcvis_propabilities) / 2

if keep_inputs:
tile["probabilities-tcvis"] = tcvis_tile["probabilities"]
tile["probabilities-notcvis"] = notcvis_tile["probabilities"]
tile["probabilities-tcvis"] = tcvis_probabilities
tile["probabilities-notcvis"] = notcvis_propabilities

return tile

def segment_tile_batched(
self,
tiles: list[xr.Dataset],
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
keep_inputs: bool = False,
) -> list[xr.Dataset]:
"""Run inference on a list of tiles.
Args:
tiles: The input tiles, containing preprocessed, harmonized data.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.
Returns:
A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].
"""
for tile in tiles: # Note that tile is still a reference -> tiles will be changed!
tcvis_probabilities = self.rts_v6_tcvis_model.segment_tile(
tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)["probabilities"].copy()
notcvis_propabilities = self.rts_v6_notcvis_model.segment_tile(
tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
)["probabilities"].copy()

tile["probabilities"] = (tcvis_probabilities + notcvis_propabilities) / 2

if keep_inputs:
tile["probabilities-tcvis"] = tcvis_probabilities
tile["probabilities-notcvis"] = notcvis_propabilities

return tiles

def __call__(
self,
input: xr.Dataset | list[xr.Dataset],
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
keep_inputs: bool = False,
) -> xr.Dataset:
"""Run the ensemble on the given tile.
Args:
input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
tile (xr.Dataset): Input tile from preprocessing.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.
Returns:
xr.Dataset: Output tile with the ensemble applied.
Raises:
ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset
"""
if isinstance(input, xr.Dataset):
return self.segment_tile(
input,
patch_size=patch_size,
overlap=overlap,
batch_size=batch_size,
reflection=reflection,
keep_inputs=keep_inputs,
)
elif isinstance(input, list):
return self.segment_tile_batched(
input,
patch_size=patch_size,
overlap=overlap,
batch_size=batch_size,
reflection=reflection,
keep_inputs=keep_inputs,
)
else:
raise ValueError("Input must be an xr.Dataset or a list of xr.Dataset.")
13 changes: 1 addition & 12 deletions darts-segmentation/src/darts_segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
import torch.nn as nn
import xarray as xr
from lovely_numpy import lovely

from darts_segmentation.utils import predict_in_patches

Expand Down Expand Up @@ -84,13 +83,8 @@ def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
for feature_name in self.config["input_combination"]:
norm = self.config["norm_factors"][feature_name]
band_data = tile[feature_name]
band_info_before = lovely(band_data.values)
# Normalize the band data
band_data = band_data * norm
band_info_after = lovely(band_data.values)
logger.debug(
f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
)
bands.append(torch.from_numpy(band_data.values))

return torch.stack(bands, dim=0)
Expand All @@ -109,13 +103,8 @@ def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
norm = self.config["norm_factors"][feature_name]
for tile in tiles:
band_data = tile[feature_name]
band_info_before = lovely(band_data.values)
# Normalize the band data
band_data = band_data * norm
band_info_after = lovely(band_data.values)
logger.debug(
f"Normalised '{feature_name}' with {norm=}.\nBefore: {band_info_before}.\nAfter: {band_info_after}"
)
bands.append(torch.from_numpy(band_data.values))
# TODO: Test this
return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)
Expand Down Expand Up @@ -211,7 +200,7 @@ def __call__(
"""Run inference on a single tile or a list of tiles.
Args:
input: A single tile or a list of tiles.
input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
patch_size (int): The size of the patches. Defaults to 1024.
overlap (int): The size of the overlap. Defaults to 16.
batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
Expand Down
13 changes: 8 additions & 5 deletions notebooks/test-e2e.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
"\n",
"import matplotlib.pyplot as plt\n",
"import xarray as xr\n",
"from darts_ensemble.ensemble_v1 import EnsembleV1\n",
"from darts_postprocessing.prepare_export import prepare_export\n",
"from darts_preprocessing.preprocess import load_and_preprocess_planet_scene\n",
"from darts_segmentation.segment import SMPSegmenter\n",
"from rich import traceback\n",
"from rich.logging import RichHandler\n",
"\n",
Expand Down Expand Up @@ -83,8 +83,11 @@
"metadata": {},
"outputs": [],
"source": [
"model = SMPSegmenter(\"../models/RTS_v6_tcvis.pt\")\n",
"tile = model.segment_tile(tile, batch_size=4)\n",
"ensemble = EnsembleV1(\n",
" \"../models/RTS_v6_tcvis.pt\",\n",
" \"../models/RTS_v6_notcvis.pt\",\n",
")\n",
"tile = ensemble(tile, batch_size=4, keep_inputs=True)\n",
"tile"
]
},
Expand All @@ -95,7 +98,7 @@
"outputs": [],
"source": [
"final_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
"fig, axs = plt.subplots(2, 8, figsize=(36, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(final_low_res.data_vars):\n",
" if v == \"probabilities\":\n",
Expand All @@ -122,7 +125,7 @@
"outputs": [],
"source": [
"final_low_res = tile.coarsen(x=16, y=16, boundary=\"trim\").mean()\n",
"fig, axs = plt.subplots(2, 7, figsize=(36, 10))\n",
"fig, axs = plt.subplots(2, 8, figsize=(36, 10))\n",
"axs = axs.flatten()\n",
"for i, v in enumerate(final_low_res.data_vars):\n",
" if v == \"probabilities\":\n",
Expand Down

0 comments on commit ea7934e

Please sign in to comment.