Skip to content

Commit

Permalink
postprocessing/prepare export: binarize submodel output if existing
Browse files Browse the repository at this point in the history
  • Loading branch information
iona5 committed Nov 4, 2024
1 parent 11367c8 commit 2274883
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions darts-postprocessing/src/darts_postprocessing/prepare_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,23 +13,32 @@ def prepare_export(tile: xr.Dataset) -> xr.Dataset:
xr.Dataset: Output tile.
"""
# Binarize the segmentation
# Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
# Also, where there was no valid input data, turn it into 0
binarized = (tile["probabilities"].fillna(0) > 0.5).astype("uint8")
tile["binarized_segmentation"] = xr.where(tile["valid_data_mask"], binarized, 0)
tile["binarized_segmentation"].attrs = {
"long_name": "Binarized Segmentation",
}

# Convert the probabilities to uint8
# Same but this time with 255 as no-data
intprobs = (tile["probabilities"] * 100).fillna(255).astype("uint8")
tile["probabilities"] = xr.where(tile["valid_data_mask"], intprobs, 255)
tile["probabilities"].attrs = {
"long_name": "Probabilities",
"units": "%",
}
tile["probabilities"] = tile["probabilities"].rio.write_nodata(255)

def _prep_layer(tile, layername, binarized_layer_name):
# Binarize the segmentation
# Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
# Also, where there was no valid input data, turn it into 0
binarized = (tile[layername].fillna(0) > 0.5).astype("uint8")
tile[binarized_layer_name] = xr.where(tile["valid_data_mask"], binarized, 0)
tile[binarized_layer_name].attrs = {
"long_name": "Binarized Segmentation",
}

# Convert the probabilities to uint8
# Same but this time with 255 as no-data
intprobs = (tile[layername] * 100).fillna(255).astype("uint8")
tile[layername] = xr.where(tile["valid_data_mask"], intprobs, 255)
tile[layername].attrs = {
"long_name": "Probabilities",
"units": "%",
}
tile[layername] = tile[layername].rio.write_nodata(255)
return tile

tile = _prep_layer(tile, "probabilities", "binarized_segmentation")
if "probabilities-tcvis" in tile:
tile = _prep_layer(tile, "probabilities-tcvis", "binarized_segmentation-tcvis")
if "probabilities-notcvis" in tile:
tile = _prep_layer(tile, "probabilities-notcvis", "binarized_segmentation-notcvis")

return tile

0 comments on commit 2274883

Please sign in to comment.