Skip to content

Commit

Permalink
make stats and histogram optional stac-utils#466
Browse files Browse the repository at this point in the history
  • Loading branch information
thomas-maschler committed Sep 29, 2023
1 parent 53b99ad commit 083da25
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 11 deletions.
35 changes: 24 additions & 11 deletions src/stactools/core/add_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
BINS = 256


def add_raster_to_item(item: Item) -> Item:
def add_raster_to_item(
item: Item, statistics: bool = True, histogram: bool = True
) -> Item:
"""Adds the raster extension to an item.
Args:
Expand All @@ -34,27 +36,38 @@ def add_raster_to_item(item: Item) -> Item:
if asset.roles and "data" in asset.roles:
raster = RasterExtension.ext(asset)
href = make_absolute_href(asset.href, item.get_self_href())
bands = _read_bands(href)
bands = _read_bands(href, statistics, histogram)
if bands:
raster.apply(bands)
return item


def _read_bands(
    href: str, statistics: bool = True, histogram: bool = True
) -> List[RasterBand]:
    """Builds one raster band description per band of the dataset at ``href``.

    Pixel data is only read when statistics or a histogram are requested.
    Both are computed from the valid pixels only, i.e. pixels that are
    neither masked by the dataset nor NaN.

    Args:
        href (str): Location of the raster, opened with ``rasterio.open``.
        statistics (bool): Whether to attach minimum/maximum statistics to
            each band. Defaults to True.
        histogram (bool): Whether to attach a histogram to each band.
            Defaults to True.

    Returns:
        List[RasterBand]: The bands, in the order of ``dataset.indexes``.
    """
    bands = []
    with rasterio.open(href) as dataset:
        for i, index in enumerate(dataset.indexes):
            band = RasterBand.create()
            band.nodata = dataset.nodatavals[i]
            band.spatial_resolution = dataset.transform[0]
            band.data_type = DataType(dataset.dtypes[i])

            if statistics or histogram:
                data = dataset.read(index, masked=True)
                # numpy.histogram ignores the mask of a masked array, so work
                # on the unmasked values explicitly and drop NaNs that were
                # not declared as nodata.
                valid = data.compressed()
                valid = valid[~numpy.isnan(valid)]
                if valid.size:
                    minimum = float(valid.min())
                    maximum = float(valid.max())
                else:
                    minimum = maximum = float("nan")

                if statistics:
                    band.statistics = Statistics.create(
                        minimum=minimum, maximum=maximum
                    )
                if histogram:
                    if valid.size:
                        hist_data, _ = numpy.histogram(
                            valid, range=(minimum, maximum), bins=BINS
                        )
                        band.histogram = Histogram.create(
                            BINS, minimum, maximum, hist_data.tolist()
                        )
                    else:
                        # The entire band is masked or NaN: there is nothing
                        # to bin, so record an empty histogram.
                        band.histogram = Histogram.create(0, minimum, maximum, [])
            bands.append(band)
    return bands
182 changes: 182 additions & 0 deletions tests/core/test_add_raster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import tempfile
from typing import Callable, List, Optional

import numpy as np
import pystac
import pytest
import rasterio
from rasterio.crs import CRS
from rasterio.transform import Affine
from stactools.core import create
from stactools.core.add_raster import add_raster_to_item


def random_data(count: int) -> np.ndarray:
    """Returns a ``(count, 10, 10)`` array of random floats scaled by 10.

    Args:
        count (int): Size of the first axis of the returned array.

    Returns:
        np.ndarray: Samples from ``np.random.rand`` multiplied by 10.
    """
    shape = (count, 10, 10)
    samples = np.random.rand(*shape)
    return samples * 10


def nan_data(count: int) -> np.ndarray:
    """Returns a ``(count, 10, 10)`` float array in which every cell is NaN.

    Args:
        count (int): Size of the first axis of the returned array.

    Returns:
        np.ndarray: An array filled entirely with ``np.nan``.
    """
    return np.full((count, 10, 10), np.nan)


def data_with_nan(count: int) -> np.ndarray:
    """Returns random ``(count, 10, 10)`` data containing a single NaN.

    Args:
        count (int): Size of the first axis of the returned array.

    Returns:
        np.ndarray: Samples from ``np.random.rand`` multiplied by 10, with
            the cell at index ``[0, 1, 1]`` replaced by ``np.nan``.
    """
    values = np.random.rand(count, 10, 10) * 10
    values[0, 1, 1] = np.nan
    return values


def zero_data(count: int) -> np.ndarray:
    """Returns a ``(count, 10, 10)`` array in which every cell is zero.

    Args:
        count (int): Size of the first axis of the returned array.

    Returns:
        np.ndarray: An all-zero array.
    """
    shape = (count, 10, 10)
    return np.zeros(shape)


def test_add_raster(tmp_asset_path) -> None:
    """Checks the raster metadata added to an item with default options."""
    item = create.item(tmp_asset_path)
    add_raster_to_item(item)

    # Values expected for the single band of the asset at ``tmp_asset_path``.
    expectations = {
        "expected_count": 1,
        "expected_nodata": None,
        "expected_spatial_resolution": 60.0,
        "expected_dtype": np.dtype("uint8"),
        "expected_min": [74.0],
        "expected_max": [255.0],
    }
    _assert_asset(item.assets["data"], **expectations)


@pytest.mark.parametrize(
    "count,nodata,dtype,datafunc,hist_count",
    [
        (1, 0, np.dtype("int8"), random_data, 256),
        (1, None, np.dtype("float64"), random_data, 256),
        (1, np.nan, np.dtype("float64"), random_data, 256),
        (2, 0, np.dtype("int8"), random_data, 256),
        (2, None, np.dtype("float64"), random_data, 256),
        (2, np.nan, np.dtype("float64"), random_data, 256),
        (1, 0, np.dtype("uint8"), zero_data, 0),
        (1, None, np.dtype("uint8"), zero_data, 256),
        (1, None, np.dtype("float64"), nan_data, 0),
        (1, np.nan, np.dtype("float64"), nan_data, 0),
        (1, None, np.dtype("float64"), data_with_nan, 256),
        (1, np.nan, np.dtype("float64"), data_with_nan, 256),
    ],
)
def test_add_raster_with_nodata(
    count: int,
    nodata: Optional[float],
    dtype: np.dtype,
    datafunc: Callable[[int], np.ndarray],
    hist_count: int,
) -> None:
    """Checks raster metadata for generated rasters with various nodata setups.

    A temporary GeoTIFF is written from ``datafunc(count)``, read back to
    obtain the expected per-band minimum and maximum, and then compared with
    what ``add_raster_to_item`` records on the item's ``data`` asset.

    Args:
        count (int): Number of bands to write.
        nodata (Optional[float]): Nodata value of the dataset, if any.
        dtype (np.dtype): Data type of the dataset.
        datafunc (Callable[[int], np.ndarray]): Produces the pixel data for
            ``count`` bands.
        hist_count (int): Expected number of histogram buckets.
    """
    with tempfile.NamedTemporaryFile(suffix=".tif") as tmpfile:
        with rasterio.open(
            tmpfile.name,
            mode="w",
            driver="GTiff",
            count=count,
            nodata=nodata,
            dtype=dtype,
            transform=Affine(0.1, 0.0, 1.0, 0.0, -0.1, 1.0),
            width=10,
            height=10,
            crs=CRS.from_epsg(4326),
        ) as dst:
            # ``astype`` returns a new array; use its result so the written
            # values really have the dataset's dtype.
            dst.write(datafunc(count).astype(dtype))

        with rasterio.open(tmpfile.name) as src:
            data = src.read(masked=True)

        # One expected value per band, computed from what was actually stored.
        minimum = [float(np.nanmin(band)) for band in data]
        maximum = [float(np.nanmax(band)) for band in data]

        item = create.item(tmpfile.name)
        add_raster_to_item(item)

        asset: pystac.Asset = item.assets["data"]
        _assert_asset(
            asset,
            expected_count=count,
            expected_nodata=nodata,
            expected_spatial_resolution=0.1,
            expected_dtype=dtype,
            expected_min=minimum,
            expected_max=maximum,
            expected_hist_count=hist_count,
        )


def test_add_raster_without_stats(tmp_asset_path) -> None:
    """With ``statistics=False`` the band has a histogram but no statistics."""
    item = create.item(tmp_asset_path)
    add_raster_to_item(item, statistics=False)

    data_asset: pystac.Asset = item.assets["data"]
    first_band = data_asset.extra_fields.get("raster:bands")[0]

    assert first_band.get("statistics") is None
    assert first_band.get("histogram")


def test_add_raster_without_histogram(tmp_asset_path) -> None:
    """With ``histogram=False`` the band has statistics but no histogram."""
    item = create.item(tmp_asset_path)
    add_raster_to_item(item, histogram=False)

    data_asset: pystac.Asset = item.assets["data"]
    first_band = data_asset.extra_fields.get("raster:bands")[0]

    assert first_band.get("statistics")
    assert first_band.get("histogram") is None


def _both_nan(left: Optional[float], right: Optional[float]) -> bool:
    """Returns True only when both values are present and both are NaN."""
    if left is None or right is None:
        # np.isnan(None) raises TypeError, which would hide the real mismatch.
        return False
    return bool(np.isnan(left) and np.isnan(right))


def _assert_asset(
    asset: pystac.Asset,
    expected_count: int,
    expected_nodata: Optional[float],
    expected_dtype: np.dtype,
    expected_spatial_resolution: float,
    expected_min: List[float],
    expected_max: List[float],
    expected_hist_count: int = 256,
) -> None:
    """Asserts that the ``raster:bands`` of an asset match the expectations.

    NaN never compares equal to itself, so every numeric comparison also
    accepts the case where both the actual and the expected value are NaN.

    Args:
        asset (pystac.Asset): Asset whose ``raster:bands`` field is checked.
        expected_count (int): Expected number of bands.
        expected_nodata (Optional[float]): Expected nodata value of each band.
        expected_dtype (np.dtype): Expected data type of each band.
        expected_spatial_resolution (float): Expected spatial resolution of
            each band.
        expected_min (List[float]): Expected minimum, one entry per band.
        expected_max (List[float]): Expected maximum, one entry per band.
        expected_hist_count (int): Expected histogram bucket count of each
            band. Defaults to 256.
    """
    bands = asset.extra_fields.get("raster:bands")
    assert bands
    assert len(bands) == expected_count

    for i, band in enumerate(bands):
        nodata = band.get("nodata")
        statistics = band["statistics"]
        histogram = band["histogram"]

        assert nodata == expected_nodata or _both_nan(nodata, expected_nodata)
        assert band["data_type"].value == expected_dtype.name
        assert band["spatial_resolution"] == expected_spatial_resolution

        assert statistics == {
            "minimum": expected_min[i],
            "maximum": expected_max[i],
        } or (
            _both_nan(statistics["maximum"], expected_max[i])
            and _both_nan(statistics["minimum"], expected_min[i])
        )

        # The histogram must span exactly the range the statistics report.
        assert histogram["count"] == expected_hist_count
        assert histogram["max"] == statistics["maximum"] or _both_nan(
            histogram["max"], statistics["maximum"]
        )
        assert histogram["min"] == statistics["minimum"] or _both_nan(
            histogram["min"], statistics["minimum"]
        )
        assert len(histogram["buckets"]) == histogram["count"]

0 comments on commit 083da25

Please sign in to comment.