From 8bc5b1a8baa6dab4935954c228a5b9da111caf55 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 8 Jan 2023 22:03:49 +0100 Subject: [PATCH 1/9] NwpIndex: Add NWP data indexer based on Caterva and ironArray --- .github/workflows/tests-conda.yml | 2 +- .github/workflows/tests-python.yml | 2 +- .gitignore | 1 + herbie/index/__init__.py | 6 + herbie/index/core.py | 202 +++++++++++++++++++++++++++++ herbie/index/loader.py | 59 +++++++++ herbie/index/monkey.py | 25 ++++ herbie/index/util.py | 31 +++++ setup.cfg | 5 + tests/test_index_era5.py | 170 ++++++++++++++++++++++++ 10 files changed, 501 insertions(+), 2 deletions(-) create mode 100644 herbie/index/__init__.py create mode 100644 herbie/index/core.py create mode 100644 herbie/index/loader.py create mode 100644 herbie/index/monkey.py create mode 100644 herbie/index/util.py create mode 100644 tests/test_index_era5.py diff --git a/.github/workflows/tests-conda.yml b/.github/workflows/tests-conda.yml index 743ff98d..9574a073 100644 --- a/.github/workflows/tests-conda.yml +++ b/.github/workflows/tests-conda.yml @@ -97,7 +97,7 @@ jobs: - name: INSTALL - Project run: | - pip install --editable=. + pip install --editable=.[indexing] - name: Run tests env: diff --git a/.github/workflows/tests-python.yml b/.github/workflows/tests-python.yml index 830d25ec..0718242f 100644 --- a/.github/workflows/tests-python.yml +++ b/.github/workflows/tests-python.yml @@ -67,7 +67,7 @@ jobs: - name: Install project run: | pip3 install --requirement=requirements-test.txt - pip3 install --editable=. + pip3 install --editable=.[indexing] - name: Run tests env: diff --git a/.gitignore b/.gitignore index ec33450d..306c7ebd 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ core.* *.idx *.grib2 +*.iarr !sample_data/hrrr/20201214/subset_20201214_hrrr.t00z.wrfsfcf12.grib2 .idea diff --git a/herbie/index/__init__.py b/herbie/index/__init__.py new file mode 100644 index 00000000..7c406231 --- /dev/null +++ b/herbie/index/__init__.py @@ -0,0 +1,6 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +from herbie.index.monkey import monkeypatch_iarray + +monkeypatch_iarray() diff --git a/herbie/index/core.py b/herbie/index/core.py new file mode 100644 index 00000000..0ccbbf46 --- /dev/null +++ b/herbie/index/core.py @@ -0,0 +1,202 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import dataclasses +import logging +import os.path +import typing as t +from pathlib import Path + +import iarray_community as ia +import numpy as np +import xarray as xr +from ndindex import Slice +from scipy.constants import convert_temperature + +from herbie.index.util import dataset_info, round_clipped + +logger = logging.getLogger(__name__) + + +class NwpIndex: + """ + Manage a multidimensional index of NWP data, using Caterva and ironArray. + + - https://caterva.readthedocs.io/ + - https://ironarray.io/docs/html/ + """ + + # Where the ironArray files (`.iarr`) will be stored. + # FIXME: Segfaults when path contains spaces. => Report to ironArray fame. + # `/Users/amo/Library/Application Support/herbie/index-iarray/precipitation_amount_1hour_Accumulation.iarr` + # BASEDIR = platformdirs.user_data_path("herbie").joinpath("index-iarray") + + # Alternatively, just use the working directory for now. + BASEDIR = Path(os.path.curdir) + + # Default ironArray configuration. + IA_CONFIG = dict( + codec=ia.Codec.LZ4, + clevel=9, + # How to choose the best numbers? 
+ # https://ironarray.io/docs/html/tutorials/03.Slicing_Datasets_and_Creating_Views.html#Optimization-Tips + chunks=(360, 360, 720), + blocks=(180, 180, 360), + # chunks=(360, 128, 1440), + # blocks=(8, 8, 720), + # TODO: Does it really work? + # nthreads=12, + ) + + def __init__(self, name, time_coordinate, resolution=None, data=None): + self.name = name + self.resolution = resolution + self.coordinate = Coordinate(time=time_coordinate) + if self.resolution: + self.coordinate.mkgrid(resolution=self.resolution) + self.data: ia.IArray = data + self.path = self.BASEDIR.joinpath(self.name).with_suffix(".iarr") + + def load(self): + self.data: ia.IArray = ia.open(str(self.path)) + logger.info(f"Loaded IArray from: {self.path}") + logger.debug(f"IArray info:\n{self.data.info}") + return self + + def save(self, dataset: xr.Dataset): + """ + Derived from ironArray's `fetch_data.py` example program [1,2], + and its documentation about "Configuring ironArray" [3]. + + [1] https://github.com/ironArray/iron-array-notebooks/blob/76fe0e9f93a75443e3aed73a9ffc36119d4aad6c/tutorials/fetch_data.py#L11-L18 + [2] https://github.com/ironArray/iron-array-notebooks/blob/76fe0e9f93a75443e3aed73a9ffc36119d4aad6c/tutorials/fetch_data.py#L37-L41 + [3] https://ironarray.io/docs/html/tutorials/02.Configuring_ironArray.html + """ + + # Use data from first data variable within dataset. + data_variable = list(dataset.data_vars.keys())[0] + logger.info(f"Discovered dataset variable: {data_variable}") + logger.info(f"Storing and indexing to: {self.path}") + logger.debug(f"Dataset info:\n{dataset_info(dataset)}") + + data = dataset[data_variable] + logger.info( + f"Data variable '{data_variable}' has shape={data.shape} and dtype={data.dtype}" + ) + with ia.config(**self.IA_CONFIG): + ia_data = ia.empty( + shape=data.shape, dtype=data.dtype, urlpath=str(self.path) + ) + logger.info("Populating IArray") + ia_data[:] = data.values + logger.info(f"IArray is ready") + logger.debug(f"IArray info:\n{ia_data.info}") + self.data = ia_data + + def round_location(self, value): + return round_clipped(value, self.resolution) + + def query(self, timestamp=None, lat=None, lon=None) -> "Result": + + # Query by point or range (bbox). + if lat is None: + idx_lat = np.where(self.coordinate.lat)[0] + lat_slice = Slice(start=idx_lat[0], stop=idx_lat[-1] + 1) + elif isinstance(lat, float): + idx_lat = np.where(self.coordinate.lat == self.round_location(lat))[0][0] + lat_slice = Slice(start=idx_lat, stop=idx_lat + 2) + elif isinstance(lat, t.Sequence): + idx_lat = np.where( + np.logical_and( + self.coordinate.lat >= self.round_location(lat[0]), + self.coordinate.lat <= self.round_location(lat[1]), + ) + )[0] + lat_slice = Slice(start=idx_lat[0], stop=idx_lat[-1] + 1) + else: + raise ValueError(f"Unable to process value for lat={lat}") + + if lon is None: + idx_lon = np.where(self.coordinate.lon)[0] + lon_slice = Slice(start=idx_lon[0], stop=idx_lon[-1] + 1) + elif isinstance(lon, float): + idx_lon = np.where(self.coordinate.lon == self.round_location(lon))[0][0] + lon_slice = Slice(start=idx_lon, stop=idx_lon + 2) + elif isinstance(lon, t.Sequence): + idx_lon = np.where( + np.logical_and( + self.coordinate.lon >= self.round_location(lon[0]), + self.coordinate.lon <= self.round_location(lon[1]), + ) + )[0] + lon_slice = Slice(start=idx_lon[0], stop=idx_lon[-1] + 1) + else: + raise ValueError(f"Unable to process value for lon={lon}") + + # Optionally query by timestamp, or not. 
+ if timestamp: + idx_time = np.where(self.coordinate.time == np.datetime64(timestamp))[0][0] + time_slice = Slice(idx_time, idx_time + 2) + filtered = self.data[time_slice, lat_slice, lon_slice] + timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] + else: + filtered = self.data[:, lat_slice, lon_slice] + timestamp_coord = self.coordinate.time[:] + + # Rebuild DataArray from result. + outdata = xr.DataArray( + filtered, + dims=("time", "lat", "lon"), + coords={ + "lat": self.coordinate.lat[lat_slice.start : lat_slice.stop], + "lon": self.coordinate.lon[lon_slice.start : lon_slice.stop], + "time": timestamp_coord, + }, + ) + + return Result(da=outdata) + + +@dataclasses.dataclass +class Coordinate: + """ + Manage data for all available coordinates. + + # TODO: How could this meta information be carried over from the source data? + """ + + time: t.Optional[np.ndarray] = None + lat: t.Optional[np.ndarray] = None + lon: t.Optional[np.ndarray] = None + + def mkgrid(self, resolution: float): + self.lat = np.arange(start=90.0, stop=-90.0, step=-resolution, dtype=np.float32) + self.lon = np.arange( + start=-180.0, stop=180.0, step=resolution, dtype=np.float32 + ) + + +@dataclasses.dataclass +class Result: + """ + Wrap query result, and provide convenience accessor methods and value converters. + """ + + da: xr.DataArray + + def select_first(self) -> xr.DataArray: + return self.da[0][0][0] + + def select_first_point(self): + return self.da.sel(lat=self.da["lat"][0], lon=self.da["lon"][0]) + + def select_first_timestamp(self): + return self.da.sel(time=self.da["time"][0]) + + def kelvin_to_celsius(self): + self.da.values = convert_temperature(self.da.values, "Kelvin", "Celsius") + return self + + def kelvin_to_fahrenheit(self): + self.da.values = convert_temperature(self.da.values, "Kelvin", "Fahrenheit") + return self diff --git a/herbie/index/loader.py b/herbie/index/loader.py new file mode 100644 index 00000000..533f6c5b --- /dev/null +++ b/herbie/index/loader.py @@ -0,0 +1,59 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import logging + +import fsspec +import numpy as np +import platformdirs +import s3fs +import xarray as xr + +logger = logging.getLogger(__name__) + + +CACHE_BASEDIR = platformdirs.user_cache_path("herbie").joinpath("index-download") + + +def open_era5_zarr(parameter, year, month, datestart, dateend) -> xr.Dataset: + """ + Load "ERA5 forecasts reanalysis" data from ECMWF. + The ERA5 HRES atmospheric data has a resolution of 31km, 0.28125 degrees [1]. + + The implementation is derived from ironArray's "Slicing Datasets and Creating + Views" documentation [2]. For processing data more efficiently, downloaded data + is cached locally, using fsspec's "filecache" filesystem [3]. + + [1] https://confluence.ecmwf.int/display/CKB/ERA5%3A+data+documentation#heading-Spatialgrid + [2] https://ironarray.io/docs/html/tutorials/03.Slicing_Datasets_and_Creating_Views.html + [3] https://filesystem-spec.readthedocs.io/en/latest/features.html#caching-files-locally + """ + location = f"era5-pds/zarr/{year}/{month:02d}/data/{parameter}.zarr/" + logger.info(f"Loading NWP data from {location}") + logger.info(f"Using local cache at {CACHE_BASEDIR}") + + # ERA5 is on AWS S3, it can be accessed anonymously. + fs = s3fs.S3FileSystem(anon=True) + + # Add local cache, using fsspec fame. + fs = fsspec.filesystem("filecache", cache_storage=str(CACHE_BASEDIR), fs=fs) + + # Access resource in Zarr format. 
+ # Possible engines: ['scipy', 'cfgrib', 'gini', 'store', 'zarr'] + s3map = s3fs.S3Map(location, s3=fs) + ds = xr.open_dataset(s3map, engine="zarr") + + # The name of the `time` coordinate differs between datasets. + time_field_candidates = ["time0", "time1"] + for candidate in time_field_candidates: + if candidate in ds.coords: + time_field = candidate + + # Select subset of data based on time range. + indexers = {time_field: slice(np.datetime64(datestart), np.datetime64(dateend))} + ds = ds.sel(indexers=indexers) + + # Rearrange coordinates data from longitude 0 to 360 degrees (long3) to -180 to 180 degrees (long1). + ds = ds.assign(lon=ds["lon"] - 180) + + return ds diff --git a/herbie/index/monkey.py b/herbie/index/monkey.py new file mode 100644 index 00000000..5363fec8 --- /dev/null +++ b/herbie/index/monkey.py @@ -0,0 +1,25 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +from iarray_community import IArray + +iarray_info_items_original = IArray.info_items + + +@property +def iarray_info_items(self): + """ + Just a minor patch for ironArray to extend info output. + + TODO: Submit patch to upstream repository. + https://github.com/ironArray/iarray-community + """ + items = iarray_info_items_original.fget(self) + items += [("codec", self.codec)] + items += [("clevel", self.clevel)] + items += [("size", self.size)] + return items + + +def monkeypatch_iarray(): + IArray.info_items = iarray_info_items diff --git a/herbie/index/util.py b/herbie/index/util.py new file mode 100644 index 00000000..562d5d06 --- /dev/null +++ b/herbie/index/util.py @@ -0,0 +1,31 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import io +import logging +import sys + + +def round_clipped(value, clipping): + """ + https://stackoverflow.com/a/7859208 + :param value: + :param clipping: + :return: + """ + return round(float(value) / clipping) * clipping + + +def setup_logging(level=logging.INFO): + log_format = "%(asctime)-15s [%(name)-20s] %(levelname)-7s: %(message)s" + logging.basicConfig(format=log_format, stream=sys.stderr, level=level) + + requests_log = logging.getLogger("botocore") + requests_log.setLevel(logging.INFO) + + +def dataset_info(ds) -> str: + buf = io.StringIO() + ds.info(buf) + buf.seek(0) + return buf.read() diff --git a/setup.cfg b/setup.cfg index c0930423..3de9f588 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,6 +67,11 @@ docs = sphinx-design sphinx-markdown-tables sphinxcontrib-mermaid +indexing = + iarray-community + s3fs + platformdirs + zarr #tests = # pytest diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py new file mode 100644 index 00000000..30c8d8cb --- /dev/null +++ b/tests/test_index_era5.py @@ -0,0 +1,170 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import datetime +from unittest import mock + +import numpy as np +import pytest +import xarray as xr +from xarray.testing import assert_equal + +from herbie.index.core import NwpIndex +from herbie.index.loader import open_era5_zarr + +TEMP2M = "air_temperature_at_2_metres" + +TIMERANGE = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 10:59"), + step=datetime.timedelta(hours=1), +) + + +@pytest.fixture +def era5_temp2m_index(): + nwp = NwpIndex(name=TEMP2M, time_coordinate=TIMERANGE, resolution=0.25) + if not nwp.path.exists(): + nwp.save(dataset=open_era5_zarr(TEMP2M, 1987, 10, TIMERANGE[0], TIMERANGE[-1])) + return nwp + + +def 
test_query_era5_monterey_fahrenheit_single_spot(era5_temp2m_index): + """ + Query indexed ERA5 NWP data for a specific point in space and time. + """ + + nwp = era5_temp2m_index.load() + + # Temperatures in Monterey, in Fahrenheit. + first = ( + nwp.query(timestamp="1987-10-01 08:00", lat=36.6083, lon=-121.8674) + .kelvin_to_fahrenheit() + .select_first() + ) + + # Verify values. + assert first.values == np.array(73.805008, dtype=np.float32) + + # Verify coordinate. + assert dict(first.coords) == dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00"), name="time"), + lat=xr.DataArray(data=np.float32(36.5), name="lat"), + lon=xr.DataArray(data=np.float32(-121.75), name="lon"), + ) + + +def test_query_era5_berlin_celsius_location_full_timerange(era5_temp2m_index): + """ + Query indexed ERA5 NWP data for a sequence of time steps. + """ + + nwp = era5_temp2m_index.load() + + # Temperatures in Berlin, in Celsius. + result = ( + nwp.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().select_first_point() + ) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array([6.600006, 6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=TIMERANGE), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) + + +def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): + """ + Query indexed ERA5 NWP data for a given area. + + http://bboxfinder.com/ + """ + + nwp = era5_temp2m_index.load() + + # Temperatures in Monterey area, in Fahrenheit. + result = ( + nwp.query( + timestamp="1987-10-01 08:00", + lat=(36.450837, 36.700907), + lon=(-122.166252, -121.655045), + ) + .kelvin_to_fahrenheit() + .select_first_timestamp() + ) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array( + [[73.58001, 71.89251, 70.88001], [74.93001, 75.717514, 73.80501]], + dtype=np.float32, + ), + dims=("lat", "lon"), + coords=dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00")), + lat=xr.DataArray( + data=np.array([36.75, 36.5], dtype=np.float32), dims=("lat",) + ), + lon=xr.DataArray( + data=np.array([-122.25, -122.0, -121.75], dtype=np.float32), + dims=("lon",), + ), + ), + ) + assert_equal(result, reference) + + +def test_query_era5_latitude_slice(era5_temp2m_index): + """ + Query indexed ERA5 NWP data for a given area along the same longitude coordinates. + """ + + nwp = era5_temp2m_index.load() + + # Temperatures for whole slice. + result = ( + nwp.query( + timestamp="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045) + ) + .kelvin_to_celsius() + .select_first_timestamp() + ) + + # Verify coordinates. + reference = xr.DataArray( + data=mock.ANY, + dims=("lat", "lon"), + coords=dict( + time=xr.DataArray(data=np.datetime64("1987-10-01 08:00")), + lat=xr.DataArray( + data=np.arange(start=90.0, stop=-90.0, step=-0.25, dtype=np.float32), + dims=("lat",), + ), + lon=xr.DataArray( + data=np.array([-122.25, -122.0, -121.75], dtype=np.float32), + dims=("lon",), + ), + ), + ) + assert_equal(result, reference) + + # Verify values of first and last record, and its coordinate. 
+ assert result[0].values.tolist() == [ + -21.587493896484375, + -21.587493896484375, + -21.587493896484375, + ] + assert result[0].coords["lat"] == 90 + + assert result[-1].values.tolist() == [ + -43.399993896484375, + -43.399993896484375, + -43.399993896484375, + ] + assert result[-1].coords["lat"] == -89.75 From 283c6d40cc5272487ae9864022455d7f4a3fc211 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 8 Jan 2023 23:53:29 +0100 Subject: [PATCH 2/9] NwpIndex: Naming things --- herbie/index/core.py | 22 +++++++++++++--------- tests/test_index_era5.py | 8 +++----- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 0ccbbf46..724e28f3 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -96,7 +96,7 @@ def save(self, dataset: xr.Dataset): def round_location(self, value): return round_clipped(value, self.resolution) - def query(self, timestamp=None, lat=None, lon=None) -> "Result": + def query(self, time=None, lat=None, lon=None) -> "Result": # Query by point or range (bbox). if lat is None: @@ -105,7 +105,7 @@ def query(self, timestamp=None, lat=None, lon=None) -> "Result": elif isinstance(lat, float): idx_lat = np.where(self.coordinate.lat == self.round_location(lat))[0][0] lat_slice = Slice(start=idx_lat, stop=idx_lat + 2) - elif isinstance(lat, t.Sequence): + elif isinstance(lat, (t.Sequence, np.ndarray)): idx_lat = np.where( np.logical_and( self.coordinate.lat >= self.round_location(lat[0]), @@ -114,7 +114,7 @@ def query(self, timestamp=None, lat=None, lon=None) -> "Result": )[0] lat_slice = Slice(start=idx_lat[0], stop=idx_lat[-1] + 1) else: - raise ValueError(f"Unable to process value for lat={lat}") + raise ValueError(f"Unable to process value for lat={lat}, type={type(lat)}") if lon is None: idx_lon = np.where(self.coordinate.lon)[0] @@ -122,7 +122,7 @@ def query(self, timestamp=None, lat=None, lon=None) -> "Result": elif isinstance(lon, float): idx_lon = np.where(self.coordinate.lon == self.round_location(lon))[0][0] lon_slice = Slice(start=idx_lon, stop=idx_lon + 2) - elif isinstance(lon, t.Sequence): + elif isinstance(lon, (t.Sequence, np.ndarray)): idx_lon = np.where( np.logical_and( self.coordinate.lon >= self.round_location(lon[0]), @@ -131,17 +131,21 @@ def query(self, timestamp=None, lat=None, lon=None) -> "Result": )[0] lon_slice = Slice(start=idx_lon[0], stop=idx_lon[-1] + 1) else: - raise ValueError(f"Unable to process value for lon={lon}") + raise ValueError(f"Unable to process value for lon={lon}, type={type(lon)}") # Optionally query by timestamp, or not. - if timestamp: - idx_time = np.where(self.coordinate.time == np.datetime64(timestamp))[0][0] + if time is None: + filtered = self.data[:, lat_slice, lon_slice] + timestamp_coord = self.coordinate.time[:] + elif isinstance(time, str): + idx_time = np.where(self.coordinate.time == np.datetime64(time))[0][0] time_slice = Slice(idx_time, idx_time + 2) filtered = self.data[time_slice, lat_slice, lon_slice] timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] else: - filtered = self.data[:, lat_slice, lon_slice] - timestamp_coord = self.coordinate.time[:] + raise ValueError( + f"Unable to process value for time={time}, type={type(time)}" + ) # Rebuild DataArray from result. 
outdata = xr.DataArray( diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index 30c8d8cb..5d2c66eb 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -38,7 +38,7 @@ def test_query_era5_monterey_fahrenheit_single_spot(era5_temp2m_index): # Temperatures in Monterey, in Fahrenheit. first = ( - nwp.query(timestamp="1987-10-01 08:00", lat=36.6083, lon=-121.8674) + nwp.query(time="1987-10-01 08:00", lat=36.6083, lon=-121.8674) .kelvin_to_fahrenheit() .select_first() ) @@ -91,7 +91,7 @@ def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): # Temperatures in Monterey area, in Fahrenheit. result = ( nwp.query( - timestamp="1987-10-01 08:00", + time="1987-10-01 08:00", lat=(36.450837, 36.700907), lon=(-122.166252, -121.655045), ) @@ -129,9 +129,7 @@ def test_query_era5_latitude_slice(era5_temp2m_index): # Temperatures for whole slice. result = ( - nwp.query( - timestamp="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045) - ) + nwp.query(time="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045)) .kelvin_to_celsius() .select_first_timestamp() ) From f8022296ce9ff334cc80c8db11ce3d4d1f4a62cd Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sun, 8 Jan 2023 23:53:45 +0100 Subject: [PATCH 3/9] NwpIndex: Accept time ranges for querying --- herbie/index/core.py | 10 +++++ tests/test_index_era5.py | 80 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/herbie/index/core.py b/herbie/index/core.py index 724e28f3..45165613 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -142,6 +142,16 @@ def query(self, time=None, lat=None, lon=None) -> "Result": time_slice = Slice(idx_time, idx_time + 2) filtered = self.data[time_slice, lat_slice, lon_slice] timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] + elif isinstance(time, (t.Sequence, np.ndarray)): + idx_time = np.where( + np.logical_and( + self.coordinate.time >= time[0], + self.coordinate.time <= time[1], + ) + )[0] + time_slice = Slice(start=idx_time[0], stop=idx_time[-1] + 1) + filtered = self.data[time_slice, lat_slice, lon_slice] + timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] else: raise ValueError( f"Unable to process value for time={time}, type={type(time)}" diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index 5d2c66eb..d6f99650 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -166,3 +166,83 @@ def test_query_era5_latitude_slice(era5_temp2m_index): -43.399993896484375, ] assert result[-1].coords["lat"] == -89.75 + + +def test_query_era5_time_slice_tuple(era5_temp2m_index): + """ + Query indexed ERA5 NWP data within given time range. + This variant uses a `tuple` for defining time range boundaries. + + While the input dataset contains three records, filtering by + time range should only yield two records. + """ + + # Load data. + nwp = era5_temp2m_index.load() + + # Temperatures for whole slice. + result = ( + nwp.query( + time=(np.datetime64("1987-10-01 08:00"), np.datetime64("1987-10-01 09:05")), + lat=52.51074, + lon=13.43506, + ) + .kelvin_to_celsius() + .select_first_point() + ) + + # Verify values and coordinates. 
+ timerange = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 09:01"), + step=datetime.timedelta(hours=1), + ) + reference = xr.DataArray( + data=np.array([6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=timerange), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) + + +def test_query_era5_time_slice_range(era5_temp2m_index): + """ + Query indexed ERA5 NWP data within given time range. + This variant uses a `np.array` for defining time range boundaries. + + While the input dataset contains three records, filtering by + time range should only yield two records. + """ + + # Load data. + nwp = era5_temp2m_index.load() + + # Define timerange used for querying. + timerange = np.arange( + start=np.datetime64("1987-10-01 08:00"), + stop=np.datetime64("1987-10-01 09:01"), + step=datetime.timedelta(hours=1), + ) + + # Temperatures for whole slice. + result = ( + nwp.query(time=timerange, lat=52.51074, lon=13.43506) + .kelvin_to_celsius() + .select_first_point() + ) + + # Verify values and coordinates. + reference = xr.DataArray( + data=np.array([6.600006, 6.600006], dtype=np.float32), + coords=dict( + time=xr.DataArray(data=timerange), + lat=xr.DataArray(data=np.float32(52.5)), + lon=xr.DataArray(data=np.float32(13.5)), + ), + ) + reference = reference.swap_dims(dim_0="time") + assert_equal(result, reference) From 9c944756aa3450293ff14dfc49d2eef17efde85d Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 9 Jan 2023 01:17:31 +0100 Subject: [PATCH 4/9] NwpIndex: Clean up `query` implementation --- herbie/index/core.py | 139 ++++++++++++++++++++++----------------- herbie/index/loader.py | 2 +- tests/test_index_era5.py | 2 +- 3 files changed, 79 insertions(+), 64 deletions(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 45165613..0d0b86e0 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -58,6 +58,9 @@ def __init__(self, name, time_coordinate, resolution=None, data=None): self.path = self.BASEDIR.joinpath(self.name).with_suffix(".iarr") def load(self): + """ + Load data from ironArray file. + """ self.data: ia.IArray = ia.open(str(self.path)) logger.info(f"Loaded IArray from: {self.path}") logger.debug(f"IArray info:\n{self.data.info}") @@ -65,6 +68,8 @@ def load(self): def save(self, dataset: xr.Dataset): """ + Save data to ironArray file, effectively indexing it on all dimensions. + Derived from ironArray's `fetch_data.py` example program [1,2], and its documentation about "Configuring ironArray" [3]. @@ -93,82 +98,92 @@ def save(self, dataset: xr.Dataset): logger.debug(f"IArray info:\n{ia_data.info}") self.data = ia_data - def round_location(self, value): - return round_clipped(value, self.resolution) - def query(self, time=None, lat=None, lon=None) -> "Result": + """ + Query ironArray by multiple dimensions. + """ - # Query by point or range (bbox). 
- if lat is None: - idx_lat = np.where(self.coordinate.lat)[0] - lat_slice = Slice(start=idx_lat[0], stop=idx_lat[-1] + 1) - elif isinstance(lat, float): - idx_lat = np.where(self.coordinate.lat == self.round_location(lat))[0][0] - lat_slice = Slice(start=idx_lat, stop=idx_lat + 2) - elif isinstance(lat, (t.Sequence, np.ndarray)): - idx_lat = np.where( - np.logical_and( - self.coordinate.lat >= self.round_location(lat[0]), - self.coordinate.lat <= self.round_location(lat[1]), - ) - )[0] - lat_slice = Slice(start=idx_lat[0], stop=idx_lat[-1] + 1) - else: - raise ValueError(f"Unable to process value for lat={lat}, type={type(lat)}") - - if lon is None: - idx_lon = np.where(self.coordinate.lon)[0] - lon_slice = Slice(start=idx_lon[0], stop=idx_lon[-1] + 1) - elif isinstance(lon, float): - idx_lon = np.where(self.coordinate.lon == self.round_location(lon))[0][0] - lon_slice = Slice(start=idx_lon, stop=idx_lon + 2) - elif isinstance(lon, (t.Sequence, np.ndarray)): - idx_lon = np.where( + # Compute slices for time or time range, and geolocation point or range (bbox). + time_slice = self.time_slice(coordinate="time", value=time) + lat_slice = self.geo_slice(coordinate="lat", value=lat) + lon_slice = self.geo_slice(coordinate="lon", value=lon) + + # Slice data. + data = self.data[time_slice, lat_slice, lon_slice] + + # Rebuild DataArray from result. + outdata = xr.DataArray( + data, + dims=("time", "lat", "lon"), + coords={ + "time": self.coordinate.time[time_slice.start : time_slice.stop], + "lat": self.coordinate.lat[lat_slice.start : lat_slice.stop], + "lon": self.coordinate.lon[lon_slice.start : lon_slice.stop], + }, + ) + + return Result(da=outdata) + + def geo_slice(self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray]): + """ + Compute slice for geolocation point or range (bbox). + """ + + coord = getattr(self.coordinate, coordinate) + + if value is None: + idx = np.where(coord)[0] + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) + elif isinstance(value, float): + idx = np.where(coord == self.round_location(value))[0][0] + effective_slice = Slice(start=idx, stop=idx + 2) + elif isinstance(value, (t.Sequence, np.ndarray)): + idx = np.where( np.logical_and( - self.coordinate.lon >= self.round_location(lon[0]), - self.coordinate.lon <= self.round_location(lon[1]), + coord >= self.round_location(value[0]), + coord <= self.round_location(value[1]), ) )[0] - lon_slice = Slice(start=idx_lon[0], stop=idx_lon[-1] + 1) + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) else: - raise ValueError(f"Unable to process value for lon={lon}, type={type(lon)}") - - # Optionally query by timestamp, or not. - if time is None: - filtered = self.data[:, lat_slice, lon_slice] - timestamp_coord = self.coordinate.time[:] - elif isinstance(time, str): - idx_time = np.where(self.coordinate.time == np.datetime64(time))[0][0] - time_slice = Slice(idx_time, idx_time + 2) - filtered = self.data[time_slice, lat_slice, lon_slice] - timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] - elif isinstance(time, (t.Sequence, np.ndarray)): - idx_time = np.where( + raise ValueError( + f"Unable to process value for {coordinate}={value}, type={type(value)}" + ) + + return effective_slice + + def time_slice( + self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray] + ): + """ + Compute slice for time or time range. 
+ """ + + coord = getattr(self.coordinate, coordinate) + + if value is None: + idx = np.where(self.coordinate.time)[0] + effective_slice = Slice(idx[0], idx[-1] + 1) + elif isinstance(value, str): + idx = np.where(coord == np.datetime64(value))[0][0] + effective_slice = Slice(idx, idx + 2) + elif isinstance(value, (t.Sequence, np.ndarray)): + idx = np.where( np.logical_and( - self.coordinate.time >= time[0], - self.coordinate.time <= time[1], + coord >= np.datetime64(value[0]), + coord <= np.datetime64(value[1]), ) )[0] - time_slice = Slice(start=idx_time[0], stop=idx_time[-1] + 1) - filtered = self.data[time_slice, lat_slice, lon_slice] - timestamp_coord = self.coordinate.time[time_slice.start : time_slice.stop] + effective_slice = Slice(start=idx[0], stop=idx[-1] + 1) else: raise ValueError( - f"Unable to process value for time={time}, type={type(time)}" + f"Unable to process value for {coordinate}={value}, type={type(value)}" ) - # Rebuild DataArray from result. - outdata = xr.DataArray( - filtered, - dims=("time", "lat", "lon"), - coords={ - "lat": self.coordinate.lat[lat_slice.start : lat_slice.stop], - "lon": self.coordinate.lon[lon_slice.start : lon_slice.stop], - "time": timestamp_coord, - }, - ) + return effective_slice - return Result(da=outdata) + def round_location(self, value): + return round_clipped(value, self.resolution) @dataclasses.dataclass diff --git a/herbie/index/loader.py b/herbie/index/loader.py index 533f6c5b..5f9b7e4e 100644 --- a/herbie/index/loader.py +++ b/herbie/index/loader.py @@ -17,7 +17,7 @@ def open_era5_zarr(parameter, year, month, datestart, dateend) -> xr.Dataset: """ - Load "ERA5 forecasts reanalysis" data from ECMWF. + Load "ERA5 forecasts reanalysis" data from ECMWF, using Zarr. The ERA5 HRES atmospheric data has a resolution of 31km, 0.28125 degrees [1]. The implementation is derived from ironArray's "Slicing Datasets and Creating diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index d6f99650..240333e7 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -56,7 +56,7 @@ def test_query_era5_monterey_fahrenheit_single_spot(era5_temp2m_index): def test_query_era5_berlin_celsius_location_full_timerange(era5_temp2m_index): """ - Query indexed ERA5 NWP data for a sequence of time steps. + Query indexed ERA5 NWP data for the whole time range. """ nwp = era5_temp2m_index.load() From 9e2cffd883918eaa50dac8e2d50a96583574032d Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 9 Jan 2023 02:15:25 +0100 Subject: [PATCH 5/9] NwpIndex: Add convenience accessor property `Result.data` It will auto-select the shape of the return value, based on the shape of the query parameters. 
--- herbie/index/core.py | 35 +++++++++++++++++++++++++++++++++-- herbie/index/util.py | 7 +++++++ tests/test_index_era5.py | 16 ++++++---------- 3 files changed, 46 insertions(+), 12 deletions(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 0d0b86e0..ab4096de 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -13,7 +13,7 @@ from ndindex import Slice from scipy.constants import convert_temperature -from herbie.index.util import dataset_info, round_clipped +from herbie.index.util import dataset_info, is_sequence, round_clipped logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ def query(self, time=None, lat=None, lon=None) -> "Result": }, ) - return Result(da=outdata) + return Result(qp=QueryParameter(time=time, lat=lat, lon=lon), da=outdata) def geo_slice(self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray]): """ @@ -186,6 +186,13 @@ def round_location(self, value): return round_clipped(value, self.resolution) +@dataclasses.dataclass +class QueryParameter: + time: t.Optional[str] = None + lat: t.Optional[float] = None + lon: t.Optional[float] = None + + @dataclasses.dataclass class Coordinate: """ @@ -211,6 +218,7 @@ class Result: Wrap query result, and provide convenience accessor methods and value converters. """ + qp: QueryParameter da: xr.DataArray def select_first(self) -> xr.DataArray: @@ -229,3 +237,26 @@ def kelvin_to_celsius(self): def kelvin_to_fahrenheit(self): self.da.values = convert_temperature(self.da.values, "Kelvin", "Fahrenheit") return self + + @property + def data(self) -> xr.DataArray: + """ + Auto-select shape of return value, based on the shape of the query parameters. + """ + all_defined = all( + v is not None for v in [self.qp.time, self.qp.lat, self.qp.lon] + ) + is_time_range = is_sequence(self.qp.time) + is_lat_range = is_sequence(self.qp.lat) + is_lon_range = is_sequence(self.qp.lon) + if all_defined and not any([is_time_range, is_lat_range, is_lon_range]): + return self.select_first() + elif not any([is_lat_range, is_lon_range]): + return self.select_first_point() + elif self.qp.time and not is_time_range: + return self.select_first_timestamp() + else: + raise ValueError( + f"Unable to auto-select shape of return value, " + f"query parameters have unknown shape: {self.qp}" + ) diff --git a/herbie/index/util.py b/herbie/index/util.py index 562d5d06..9ac064e6 100644 --- a/herbie/index/util.py +++ b/herbie/index/util.py @@ -4,6 +4,9 @@ import io import logging import sys +import typing as t + +import numpy as np def round_clipped(value, clipping): @@ -29,3 +32,7 @@ def dataset_info(ds) -> str: ds.info(buf) buf.seek(0) return buf.read() + + +def is_sequence(value): + return not isinstance(value, str) and isinstance(value, (t.Sequence, np.ndarray)) diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index 240333e7..9d0a9f2c 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -40,7 +40,7 @@ def test_query_era5_monterey_fahrenheit_single_spot(era5_temp2m_index): first = ( nwp.query(time="1987-10-01 08:00", lat=36.6083, lon=-121.8674) .kelvin_to_fahrenheit() - .select_first() + .data ) # Verify values. @@ -62,9 +62,7 @@ def test_query_era5_berlin_celsius_location_full_timerange(era5_temp2m_index): nwp = era5_temp2m_index.load() # Temperatures in Berlin, in Celsius. - result = ( - nwp.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().select_first_point() - ) + result = nwp.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().data # Verify values and coordinates. 
reference = xr.DataArray( @@ -96,7 +94,7 @@ def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): lon=(-122.166252, -121.655045), ) .kelvin_to_fahrenheit() - .select_first_timestamp() + .data ) # Verify values and coordinates. @@ -131,7 +129,7 @@ def test_query_era5_latitude_slice(era5_temp2m_index): result = ( nwp.query(time="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045)) .kelvin_to_celsius() - .select_first_timestamp() + .data ) # Verify coordinates. @@ -188,7 +186,7 @@ def test_query_era5_time_slice_tuple(era5_temp2m_index): lon=13.43506, ) .kelvin_to_celsius() - .select_first_point() + .data ) # Verify values and coordinates. @@ -230,9 +228,7 @@ def test_query_era5_time_slice_range(era5_temp2m_index): # Temperatures for whole slice. result = ( - nwp.query(time=timerange, lat=52.51074, lon=13.43506) - .kelvin_to_celsius() - .select_first_point() + nwp.query(time=timerange, lat=52.51074, lon=13.43506).kelvin_to_celsius().data ) # Verify values and coordinates. From 6a689cb0ce28fbf9ddcb294ec58eeaa48dfbb8da Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 14 Jan 2023 00:56:08 +0100 Subject: [PATCH 6/9] NwpIndex: Persist original Dataset's schema as netCDF and JSON Because the engine breaks out the actual data from the Dataset into an ironArray, and its metadata will get lost, it is needed to collect the schema information about coordinates and dimensions and store it separately from the `.iarr` data. This will allow to reconstruct the original Dataset as close as possible when reading the data back. Currently, two files, `schema.nc`, and `schema.json` will be stored within the corresponding `.iarr` folder. --- herbie/index/core.py | 169 ++++++++++++++++++++++++--------------- herbie/index/loader.py | 11 +-- herbie/index/model.py | 98 +++++++++++++++++++++++ herbie/index/util.py | 12 ++- tests/test_index_era5.py | 89 +++++++++++---------- 5 files changed, 268 insertions(+), 111 deletions(-) create mode 100644 herbie/index/model.py diff --git a/herbie/index/core.py b/herbie/index/core.py index ab4096de..6d973022 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -13,7 +13,13 @@ from ndindex import Slice from scipy.constants import convert_temperature -from herbie.index.util import dataset_info, is_sequence, round_clipped +from herbie.index.model import DataSchema, QueryParameter +from herbie.index.util import ( + dataset_get_data_variable_names, + dataset_info, + is_sequence, + round_clipped, +) logger = logging.getLogger(__name__) @@ -24,6 +30,10 @@ class NwpIndex: - https://caterva.readthedocs.io/ - https://ironarray.io/docs/html/ + + TODO: Think about making this an xarray accessor, e.g. `ds.xindex`. + + - https://docs.xarray.dev/en/stable/internals/extending-xarray.html """ # Where the ironArray files (`.iarr`) will be stored. @@ -34,10 +44,10 @@ class NwpIndex: # Alternatively, just use the working directory for now. BASEDIR = Path(os.path.curdir) - # Default ironArray configuration. + # Configure ironArray. IA_CONFIG = dict( - codec=ia.Codec.LZ4, - clevel=9, + codec=ia.Codec.ZSTD, + clevel=1, # How to choose the best numbers? 
# https://ironarray.io/docs/html/tutorials/03.Slicing_Datasets_and_Creating_Views.html#Optimization-Tips chunks=(360, 360, 720), @@ -48,22 +58,45 @@ class NwpIndex: # nthreads=12, ) - def __init__(self, name, time_coordinate, resolution=None, data=None): - self.name = name - self.resolution = resolution - self.coordinate = Coordinate(time=time_coordinate) - if self.resolution: - self.coordinate.mkgrid(resolution=self.resolution) - self.data: ia.IArray = data + def __init__(self, name, resolution=None, schema=None, dataset=None, irondata=None): + self.name: str = name + self._resolution: float = resolution + self.dataset: xr.Dataset = dataset + self.irondata: ia.IArray = irondata + self.path = self.BASEDIR.joinpath(self.name).with_suffix(".iarr") + self.schema: DataSchema = schema or DataSchema(path=self.path) + + def exists(self): + return self.path.exists() + + @property + def resolution(self): + if self._resolution: + return self._resolution + elif self.schema.ds is not None: + return self.schema.get_resolution() + else: + raise ValueError("Resolution is required for querying the Dataset by geospatial coordinates") + + @resolution.setter + def resolution(self, value): + self._resolution = value def load(self): """ Load data from ironArray file. """ - self.data: ia.IArray = ia.open(str(self.path)) + + # Load data. + # TODO: Handle multiple variable names. + self.irondata: ia.IArray = ia.open(str(self.path)) logger.info(f"Loaded IArray from: {self.path}") - logger.debug(f"IArray info:\n{self.data.info}") + logger.debug(f"IArray info:\n{self.irondata.info}") + + # Load schema. + self.schema.load() + return self def save(self, dataset: xr.Dataset): @@ -79,7 +112,9 @@ def save(self, dataset: xr.Dataset): """ # Use data from first data variable within dataset. - data_variable = list(dataset.data_vars.keys())[0] + # TODO: Handle multiple variable names. + data_variables = dataset_get_data_variable_names(dataset) + data_variable = data_variables[0] logger.info(f"Discovered dataset variable: {data_variable}") logger.info(f"Storing and indexing to: {self.path}") logger.debug(f"Dataset info:\n{dataset_info(dataset)}") @@ -96,7 +131,10 @@ def save(self, dataset: xr.Dataset): ia_data[:] = data.values logger.info(f"IArray is ready") logger.debug(f"IArray info:\n{ia_data.info}") - self.data = ia_data + self.irondata = ia_data + + # Save schema. + self.schema.save(ds=dataset) def query(self, time=None, lat=None, lon=None) -> "Result": """ @@ -109,27 +147,45 @@ def query(self, time=None, lat=None, lon=None) -> "Result": lon_slice = self.geo_slice(coordinate="lon", value=lon) # Slice data. - data = self.data[time_slice, lat_slice, lon_slice] - - # Rebuild DataArray from result. - outdata = xr.DataArray( - data, - dims=("time", "lat", "lon"), - coords={ - "time": self.coordinate.time[time_slice.start : time_slice.stop], - "lat": self.coordinate.lat[lat_slice.start : lat_slice.stop], - "lon": self.coordinate.lon[lon_slice.start : lon_slice.stop], - }, - ) + data = self.irondata[time_slice, lat_slice, lon_slice] + + # Rebuild Dataset from result. + coords = { + "time": self.schema.ds.coords["time"][time_slice.start: time_slice.stop], + "lat": self.schema.ds.coords["lat"][lat_slice.start: lat_slice.stop], + "lon": self.schema.ds.coords["lon"][lon_slice.start: lon_slice.stop], + } + ds = self.to_dataset(data, coords=coords) + + return Result(qp=QueryParameter(time=time, lat=lat, lon=lon), ds=ds) + + def to_dataset(self, irondata, coords): + """ + Re-create Xarray Dataset from ironArray data and coordinates. 
+ + The intention is to emit a Dataset which has the same character + as the Dataset originally loaded from GRIB/netCDF/HDF5/Zarr. + """ + + # Re-create empty Dataset with original shape. + schema = self.schema.ds.copy(deep=True) + dataset = xr.Dataset(data_vars=schema.data_vars, coords=coords, attrs=schema.attrs) - return Result(qp=QueryParameter(time=time, lat=lat, lon=lon), da=outdata) + # Populate data. + # TODO: Handle more than one variable. + # TODO: Is there a faster operation than using `list(irondata)`? + variable0_info = self.schema.metadata["variables"][0] + variable0_name = variable0_info["name"] + dataset[variable0_name] = xr.DataArray(list(irondata), **variable0_info) + + return dataset def geo_slice(self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarray]): """ Compute slice for geolocation point or range (bbox). """ - coord = getattr(self.coordinate, coordinate) + coord = self.schema.ds.coords[coordinate] if value is None: idx = np.where(coord)[0] @@ -159,10 +215,10 @@ def time_slice( Compute slice for time or time range. """ - coord = getattr(self.coordinate, coordinate) + coord = self.schema.ds.coords[coordinate] if value is None: - idx = np.where(self.coordinate.time)[0] + idx = np.where(coord)[0] effective_slice = Slice(idx[0], idx[-1] + 1) elif isinstance(value, str): idx = np.where(coord == np.datetime64(value))[0][0] @@ -186,32 +242,6 @@ def round_location(self, value): return round_clipped(value, self.resolution) -@dataclasses.dataclass -class QueryParameter: - time: t.Optional[str] = None - lat: t.Optional[float] = None - lon: t.Optional[float] = None - - -@dataclasses.dataclass -class Coordinate: - """ - Manage data for all available coordinates. - - # TODO: How could this meta information be carried over from the source data? - """ - - time: t.Optional[np.ndarray] = None - lat: t.Optional[np.ndarray] = None - lon: t.Optional[np.ndarray] = None - - def mkgrid(self, resolution: float): - self.lat = np.arange(start=90.0, stop=-90.0, step=-resolution, dtype=np.float32) - self.lon = np.arange( - start=-180.0, stop=180.0, step=resolution, dtype=np.float32 - ) - - @dataclasses.dataclass class Result: """ @@ -219,23 +249,36 @@ class Result: """ qp: QueryParameter - da: xr.DataArray + ds: xr.Dataset + + @property + def pv(self): + """ + Return primary variable name. That is, the first one. + + # TODO: Handle multiple variable names. 
+ """ + return list(self.ds.data_vars.keys())[0] def select_first(self) -> xr.DataArray: - return self.da[0][0][0] + return self.ds[self.pv][0][0][0] def select_first_point(self): - return self.da.sel(lat=self.da["lat"][0], lon=self.da["lon"][0]) + da = self.ds[self.pv] + return da.sel(lat=da["lat"][0], lon=da["lon"][0]) def select_first_timestamp(self): - return self.da.sel(time=self.da["time"][0]) + da = self.ds[self.pv] + return da.sel(time=da["time"][0]) def kelvin_to_celsius(self): - self.da.values = convert_temperature(self.da.values, "Kelvin", "Celsius") + da = self.ds[self.pv] + da.values = convert_temperature(da.values, "Kelvin", "Celsius") return self def kelvin_to_fahrenheit(self): - self.da.values = convert_temperature(self.da.values, "Kelvin", "Fahrenheit") + da = self.ds[self.pv] + da.values = convert_temperature(da.values, "Kelvin", "Fahrenheit") return self @property diff --git a/herbie/index/loader.py b/herbie/index/loader.py index 5f9b7e4e..cf6e4f34 100644 --- a/herbie/index/loader.py +++ b/herbie/index/loader.py @@ -15,7 +15,7 @@ CACHE_BASEDIR = platformdirs.user_cache_path("herbie").joinpath("index-download") -def open_era5_zarr(parameter, year, month, datestart, dateend) -> xr.Dataset: +def open_era5_zarr(parameter, year, month, datestart=None, dateend=None) -> xr.Dataset: """ Load "ERA5 forecasts reanalysis" data from ECMWF, using Zarr. The ERA5 HRES atmospheric data has a resolution of 31km, 0.28125 degrees [1]. @@ -43,15 +43,16 @@ def open_era5_zarr(parameter, year, month, datestart, dateend) -> xr.Dataset: s3map = s3fs.S3Map(location, s3=fs) ds = xr.open_dataset(s3map, engine="zarr") - # The name of the `time` coordinate differs between datasets. + # The name of the `time` coordinate may be different between datasets. time_field_candidates = ["time0", "time1"] for candidate in time_field_candidates: if candidate in ds.coords: - time_field = candidate + ds = ds.rename({candidate: "time"}) # Select subset of data based on time range. - indexers = {time_field: slice(np.datetime64(datestart), np.datetime64(dateend))} - ds = ds.sel(indexers=indexers) + if datestart and dateend: + indexers = {"time": slice(np.datetime64(datestart), np.datetime64(dateend))} + ds = ds.sel(indexers=indexers) # Rearrange coordinates data from longitude 0 to 360 degrees (long3) to -180 to 180 degrees (long1). ds = ds.assign(lon=ds["lon"] - 180) diff --git a/herbie/index/model.py b/herbie/index/model.py new file mode 100644 index 00000000..800b2dc7 --- /dev/null +++ b/herbie/index/model.py @@ -0,0 +1,98 @@ +# MIT License +# (c) 2023 Andreas Motl +# https://github.com/earthobservations +import dataclasses +import json +import typing as t +from pathlib import Path + +import xarray as xr + +from herbie.index.util import dataset_get_data_variable_names, dataset_without_data + + +@dataclasses.dataclass +class DataSchema: + """ + Manage saving and loading an Xarray Dataset schema in netCDF format. + + That means, on saving, all data variables are dropped, but metadata + information about them is stored alongside the data. This information + is reused when re-creating the Dataset in the same shape. + """ + + path: Path + ds: xr.Dataset = None + metadata: t.Dict = None + nc_file: Path = dataclasses.field(init=False) + json_file: Path = dataclasses.field(init=False) + + def __post_init__(self): + self.nc_file = self.path.joinpath("schema.nc") + self.json_file = self.path.joinpath("schema.json") + + def load(self): + """ + Load metadata information for Dataset from netCDF file. 
+ """ + self.ds = xr.load_dataset(self.nc_file) + with open(self.json_file, "r") as fp: + self.metadata = json.load(fp) + + def save(self, ds: xr.Dataset): + """ + Strip data off Dataset, and save its metadata information into netCDF file. + """ + + self.ds = dataset_without_data(ds) + self.metadata = self.get_metadata(ds) + + self.ds.to_netcdf(self.nc_file) + with open(self.json_file, "w") as fp: + json.dump(self.metadata, fp, indent=2) + + @staticmethod + def get_metadata(ds: xr.Dataset): + """ + Get metadata from Dataset. + + This metadata is needed in order to save it for reconstructing the + complete Dataset later. + """ + result = [] + data_variables = dataset_get_data_variable_names(ds) + for variable in data_variables: + da: xr.DataArray = ds[variable] + item = { + "name": da.name, + "attrs": dict(da.attrs), + "dims": list(da.dims), + } + result.append(item) + return {"variables": result} + + def get_resolution(self): + """ + Derive resolution of grid from coordinates. + """ + lat_coord = self.ds.coords["lat"] + lon_coord = self.ds.coords["lon"] + lat_delta = lat_coord[1].values - lat_coord[0].values + lon_delta = lon_coord[1].values - lon_coord[0].values + if abs(lat_delta) == abs(lon_delta): + return abs(lat_delta) + else: + raise ValueError( + "Resolution computed from coordinates deviates between latitude and longitude" + ) + + +@dataclasses.dataclass +class QueryParameter: + """ + Manage query parameters. + """ + + time: t.Optional[str] = None + lat: t.Optional[float] = None + lon: t.Optional[float] = None diff --git a/herbie/index/util.py b/herbie/index/util.py index 9ac064e6..b5d97103 100644 --- a/herbie/index/util.py +++ b/herbie/index/util.py @@ -7,6 +7,7 @@ import typing as t import numpy as np +import xarray as xr def round_clipped(value, clipping): @@ -27,7 +28,7 @@ def setup_logging(level=logging.INFO): requests_log.setLevel(logging.INFO) -def dataset_info(ds) -> str: +def dataset_info(ds: xr.Dataset) -> str: buf = io.StringIO() ds.info(buf) buf.seek(0) @@ -36,3 +37,12 @@ def dataset_info(ds) -> str: def is_sequence(value): return not isinstance(value, str) and isinstance(value, (t.Sequence, np.ndarray)) + + +def dataset_get_data_variable_names(ds: xr.Dataset): + return list(ds.data_vars.keys()) + + +def dataset_without_data(ds: xr.Dataset): + data_variables = dataset_get_data_variable_names(ds) + return ds.drop_vars(names=data_variables) diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index 9d0a9f2c..cc6362c2 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -22,47 +22,49 @@ @pytest.fixture -def era5_temp2m_index(): - nwp = NwpIndex(name=TEMP2M, time_coordinate=TIMERANGE, resolution=0.25) - if not nwp.path.exists(): +def era5_temp2m(): + """ + Provide an instance of `NwpIndex` to the test cases. + """ + nwp = NwpIndex(name=TEMP2M) + if not nwp.exists(): nwp.save(dataset=open_era5_zarr(TEMP2M, 1987, 10, TIMERANGE[0], TIMERANGE[-1])) + nwp.load() return nwp -def test_query_era5_monterey_fahrenheit_single_spot(era5_temp2m_index): +def test_query_era5_monterey_fahrenheit_point_time(era5_temp2m): """ - Query indexed ERA5 NWP data for a specific point in space and time. + Query indexed ERA5 NWP data for a specific geopoint and time. """ - nwp = era5_temp2m_index.load() - # Temperatures in Monterey, in Fahrenheit. - first = ( - nwp.query(time="1987-10-01 08:00", lat=36.6083, lon=-121.8674) + item = ( + era5_temp2m.query(time="1987-10-01 08:00", lat=36.6083, lon=-121.8674) .kelvin_to_fahrenheit() .data ) # Verify values. 
- assert first.values == np.array(73.805008, dtype=np.float32) + assert item.values == np.array(73.805008, dtype=np.float32) # Verify coordinate. - assert dict(first.coords) == dict( + assert dict(item.coords) == dict( time=xr.DataArray(data=np.datetime64("1987-10-01 08:00"), name="time"), lat=xr.DataArray(data=np.float32(36.5), name="lat"), lon=xr.DataArray(data=np.float32(-121.75), name="lon"), ) -def test_query_era5_berlin_celsius_location_full_timerange(era5_temp2m_index): +def test_query_era5_berlin_celsius_point_timerange(era5_temp2m): """ - Query indexed ERA5 NWP data for the whole time range. + Query indexed ERA5 NWP data for the whole time range at a specific geopoint. """ - nwp = era5_temp2m_index.load() - # Temperatures in Berlin, in Celsius. - result = nwp.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().data + result = era5_temp2m.query(lat=52.51074, lon=13.43506).kelvin_to_celsius().data + assert len(result.data) == 3 + assert result.shape == (3,) # Verify values and coordinates. reference = xr.DataArray( @@ -77,18 +79,16 @@ def test_query_era5_berlin_celsius_location_full_timerange(era5_temp2m_index): assert_equal(result, reference) -def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): +def test_query_era5_bbox_time(era5_temp2m): """ - Query indexed ERA5 NWP data for a given area. + Query indexed ERA5 NWP data for a given area, defined by a bounding box. http://bboxfinder.com/ """ - nwp = era5_temp2m_index.load() - # Temperatures in Monterey area, in Fahrenheit. result = ( - nwp.query( + era5_temp2m.query( time="1987-10-01 08:00", lat=(36.450837, 36.700907), lon=(-122.166252, -121.655045), @@ -96,6 +96,8 @@ def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): .kelvin_to_fahrenheit() .data ) + assert len(result.data) == 2 + assert result.shape == (2, 3) # Verify values and coordinates. reference = xr.DataArray( @@ -118,19 +120,22 @@ def test_query_era5_monterey_fahrenheit_bbox_area(era5_temp2m_index): assert_equal(result, reference) -def test_query_era5_latitude_slice(era5_temp2m_index): +def test_query_era5_geoslice_time(era5_temp2m): """ - Query indexed ERA5 NWP data for a given area along the same longitude coordinates. + Query indexed ERA5 NWP data for a given slice on the latitude coordinate, + along the same longitude coordinates. """ - nwp = era5_temp2m_index.load() - # Temperatures for whole slice. result = ( - nwp.query(time="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045)) + era5_temp2m.query( + time="1987-10-01 08:00", lat=None, lon=(-122.166252, -121.655045) + ) .kelvin_to_celsius() .data ) + assert len(result.data) == 721 + assert result.shape == (721, 3) # Verify coordinates. 
reference = xr.DataArray( @@ -139,7 +144,7 @@ def test_query_era5_latitude_slice(era5_temp2m_index): coords=dict( time=xr.DataArray(data=np.datetime64("1987-10-01 08:00")), lat=xr.DataArray( - data=np.arange(start=90.0, stop=-90.0, step=-0.25, dtype=np.float32), + data=np.arange(start=90.0, stop=-90.01, step=-0.25, dtype=np.float32), dims=("lat",), ), lon=xr.DataArray( @@ -159,14 +164,14 @@ def test_query_era5_latitude_slice(era5_temp2m_index): assert result[0].coords["lat"] == 90 assert result[-1].values.tolist() == [ - -43.399993896484375, - -43.399993896484375, - -43.399993896484375, + -43.837493896484375, + -43.837493896484375, + -43.837493896484375, ] - assert result[-1].coords["lat"] == -89.75 + assert result[-1].coords["lat"] == -90.0 -def test_query_era5_time_slice_tuple(era5_temp2m_index): +def test_query_era5_point_timerange_tuple(era5_temp2m): """ Query indexed ERA5 NWP data within given time range. This variant uses a `tuple` for defining time range boundaries. @@ -175,12 +180,9 @@ def test_query_era5_time_slice_tuple(era5_temp2m_index): time range should only yield two records. """ - # Load data. - nwp = era5_temp2m_index.load() - # Temperatures for whole slice. result = ( - nwp.query( + era5_temp2m.query( time=(np.datetime64("1987-10-01 08:00"), np.datetime64("1987-10-01 09:05")), lat=52.51074, lon=13.43506, @@ -188,6 +190,8 @@ def test_query_era5_time_slice_tuple(era5_temp2m_index): .kelvin_to_celsius() .data ) + assert len(result.data) == 2 + assert result.shape == (2,) # Verify values and coordinates. timerange = np.arange( @@ -207,18 +211,15 @@ def test_query_era5_time_slice_tuple(era5_temp2m_index): assert_equal(result, reference) -def test_query_era5_time_slice_range(era5_temp2m_index): +def test_query_era5_point_timerange_numpy(era5_temp2m): """ Query indexed ERA5 NWP data within given time range. - This variant uses a `np.array` for defining time range boundaries. + This variant uses an `np.array` for defining time range boundaries. While the input dataset contains three records, filtering by time range should only yield two records. """ - # Load data. - nwp = era5_temp2m_index.load() - # Define timerange used for querying. timerange = np.arange( start=np.datetime64("1987-10-01 08:00"), @@ -228,8 +229,12 @@ def test_query_era5_time_slice_range(era5_temp2m_index): # Temperatures for whole slice. result = ( - nwp.query(time=timerange, lat=52.51074, lon=13.43506).kelvin_to_celsius().data + era5_temp2m.query(time=timerange, lat=52.51074, lon=13.43506) + .kelvin_to_celsius() + .data ) + assert len(result.data) == 2 + assert result.shape == (2,) # Verify values and coordinates. 
reference = xr.DataArray( From 72614d873a51857cded286fedeb6afdbd8f5baa8 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 14 Jan 2023 02:51:35 +0100 Subject: [PATCH 7/9] NwpIndex: Unlock querying time range by using pandas' DateTimeIndex --- herbie/index/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 6d973022..3b95df30 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -9,6 +9,7 @@ import iarray_community as ia import numpy as np +import pandas as pd import xarray as xr from ndindex import Slice from scipy.constants import convert_temperature @@ -223,7 +224,7 @@ def time_slice( elif isinstance(value, str): idx = np.where(coord == np.datetime64(value))[0][0] effective_slice = Slice(idx, idx + 2) - elif isinstance(value, (t.Sequence, np.ndarray)): + elif isinstance(value, (t.Sequence, np.ndarray, pd.DatetimeIndex)): idx = np.where( np.logical_and( coord >= np.datetime64(value[0]), From 8c05e4015e1668730c12be51fe4d70d96145204f Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 14 Jan 2023 02:53:41 +0100 Subject: [PATCH 8/9] NwpIndex: Unlock querying location defined by bounding box --- herbie/index/core.py | 16 ++++++++-- herbie/index/model.py | 15 +++++++++ tests/test_index_era5.py | 68 +++++++++++++++++++++++++++++++++++++++- 3 files changed, 96 insertions(+), 3 deletions(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 3b95df30..9947f68c 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -14,7 +14,7 @@ from ndindex import Slice from scipy.constants import convert_temperature -from herbie.index.model import DataSchema, QueryParameter +from herbie.index.model import BBox, DataSchema, QueryParameter from herbie.index.util import ( dataset_get_data_variable_names, dataset_info, @@ -137,11 +137,22 @@ def save(self, dataset: xr.Dataset): # Save schema. self.schema.save(ds=dataset) - def query(self, time=None, lat=None, lon=None) -> "Result": + def query(self, time=None, location: t.Union[BBox] = None, lat=None, lon=None) -> "Result": """ Query ironArray by multiple dimensions. """ + if location is not None: + + # Select location by bounding box. + # https://boundingbox.klokantech.com/ + if isinstance(location, BBox): + lat = [location.lat1, location.lat2] + lon = [location.lon1, location.lon2] + + else: + raise ValueError(f"Unable to process location={location}, type={type(location)}") + # Compute slices for time or time range, and geolocation point or range (bbox). time_slice = self.time_slice(coordinate="time", value=time) lat_slice = self.geo_slice(coordinate="lat", value=lat) @@ -195,6 +206,7 @@ def geo_slice(self, coordinate: str, value: t.Union[float, t.Sequence, np.ndarra idx = np.where(coord == self.round_location(value))[0][0] effective_slice = Slice(start=idx, stop=idx + 2) elif isinstance(value, (t.Sequence, np.ndarray)): + value = sorted(value) idx = np.where( np.logical_and( coord >= self.round_location(value[0]), diff --git a/herbie/index/model.py b/herbie/index/model.py index 800b2dc7..b2779ea9 100644 --- a/herbie/index/model.py +++ b/herbie/index/model.py @@ -96,3 +96,18 @@ class QueryParameter: time: t.Optional[str] = None lat: t.Optional[float] = None lon: t.Optional[float] = None + + +@dataclasses.dataclass +class BBox: + """ + Manage bounding box information. 
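+
+    Example (coordinates are illustrative):
+    BBox(lon1=13.0, lat1=52.3, lon2=13.6, lat2=52.7)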
+ + # min_x, min_y, max_x, max_y + # (lon1, lat1, lon2, lat2) = c.bbox + """ + + lon1: float + lat1: float + lon2: float + lat2: float diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index cc6362c2..a8ae57ae 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -5,12 +5,14 @@ from unittest import mock import numpy as np +import pandas as pd import pytest import xarray as xr from xarray.testing import assert_equal from herbie.index.core import NwpIndex from herbie.index.loader import open_era5_zarr +from herbie.index.model import BBox TEMP2M = "air_temperature_at_2_metres" @@ -213,7 +215,7 @@ def test_query_era5_point_timerange_tuple(era5_temp2m): def test_query_era5_point_timerange_numpy(era5_temp2m): """ - Query indexed ERA5 NWP data within given time range. + Query indexed ERA5 NWP data at a specific point within given time range. This variant uses an `np.array` for defining time range boundaries. While the input dataset contains three records, filtering by @@ -247,3 +249,67 @@ def test_query_era5_point_timerange_numpy(era5_temp2m): ) reference = reference.swap_dims(dim_0="time") assert_equal(result, reference) + + +def test_query_era5_bbox_timerange(era5_temp2m): + """ + Query indexed ERA5 NWP data within a given area and time range. + + This variant uses a pandas `DatetimeIndex` for defining the time range + boundaries, and a `BBox` instance for defining a geospatial bounding box. + """ + + data_var = "air_temperature_at_2_metres" + + # Temperatures in Berlin area, in Celsius. + ds = ( + era5_temp2m.query( + time=pd.date_range( + start="1987-10-01 08:00", end="1987-10-01 09:00", freq="H" + ), + location=BBox(lon1=13.000, lat1=52.700, lon2=13.600, lat2=52.300), + ) + .kelvin_to_celsius() + .ds + ) + assert len(ds) == 1 + assert ds[data_var].shape == (2, 3, 3) + assert ds[data_var].dims == ("time", "lat", "lon") + + # Verify values and coordinates. 
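+    # (Two hourly steps; the 0.6 x 0.4 degree box snaps to a 3 x 3 cell
+    # neighborhood on the 0.25 degree grid, matching the assertions above.)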
+ reference = xr.DataArray( + dims=("time", "lat", "lon"), + data=np.array( + [ + [ + [6.412506, 6.412506, 6.475006], + [6.537506, 6.537506, 6.600006], + [6.662506, 6.662506, 6.725006], + ], + [ + [6.350006, 6.412506, 6.412506], + [6.537506, 6.537506, 6.600006], + [6.662506, 6.662506, 6.662506], + ], + ], + dtype=np.float32, + ), + coords=dict( + time=xr.DataArray( + data=np.arange( + start=np.datetime64("1987-10-01 08:00:00"), + stop=np.datetime64("1987-10-01 09:00:01"), + step=datetime.timedelta(hours=1), + ), + name="time", + dims=("time",), + ), + lat=xr.DataArray( + data=np.array([52.75, 52.5, 52.25], dtype=np.float32), dims=("lat",) + ), + lon=xr.DataArray( + data=np.array([13.0, 13.25, 13.5], dtype=np.float32), dims=("lon",) + ), + ), + ) + assert_equal(ds[data_var], reference) From 7cdcac22cb42d2b20acde15f02f2ec9165dbb508 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Sat, 14 Jan 2023 05:28:44 +0100 Subject: [PATCH 9/9] NwpIndex: Unlock querying location defined by circular area --- herbie/index/core.py | 26 ++++++++++++-- herbie/index/model.py | 23 ++++++++++++ herbie/index/util.py | 3 ++ tests/test_index_era5.py | 75 ++++++++++++++++++++++++++++++++++++++-- 4 files changed, 123 insertions(+), 4 deletions(-) diff --git a/herbie/index/core.py b/herbie/index/core.py index 9947f68c..bef42bfe 100644 --- a/herbie/index/core.py +++ b/herbie/index/core.py @@ -10,16 +10,19 @@ import iarray_community as ia import numpy as np import pandas as pd +import shapely import xarray as xr from ndindex import Slice from scipy.constants import convert_temperature +from shapely.geometry import CAP_STYLE, Point, Polygon -from herbie.index.model import BBox, DataSchema, QueryParameter +from herbie.index.model import BBox, Circle, DataSchema, QueryParameter from herbie.index.util import ( dataset_get_data_variable_names, dataset_info, is_sequence, round_clipped, + unit, ) logger = logging.getLogger(__name__) @@ -137,13 +140,32 @@ def save(self, dataset: xr.Dataset): # Save schema. self.schema.save(ds=dataset) - def query(self, time=None, location: t.Union[BBox] = None, lat=None, lon=None) -> "Result": + def query(self, time=None, location: t.Union[BBox, Circle] = None, lat=None, lon=None) -> "Result": """ Query ironArray by multiple dimensions. """ if location is not None: + # Select location by circle (point and distance). + if isinstance(location, Circle): + circle: Circle = location + # At 38 degrees North latitude (which passes through Stockton California + # and Charlottesville Virginia), one degree of longitude equals 54.6 miles. + # => 0.25 degrees equal 13.65 miles. + # + # -- https://www.usgs.gov/faqs/how-much-distance-does-degree-minute-and-second-cover-your-maps + # + # FIXME: Verify this, and apply the correct conversion for other places on earth. + factor = 54.6 * self.resolution + distance = (circle.distance / (factor * unit.miles)).magnitude + + # Compute minimum bounding rectangle from circle. + point = Point([circle.point.longitude, circle.point.latitude]) \ + .buffer(distance, cap_style=CAP_STYLE.square) + bbox: Polygon = point.minimum_rotated_rectangle + location = BBox(*bbox.bounds) + # Select location by bounding box. 
# https://boundingbox.klokantech.com/ if isinstance(location, BBox): diff --git a/herbie/index/model.py b/herbie/index/model.py index b2779ea9..b8b7b50d 100644 --- a/herbie/index/model.py +++ b/herbie/index/model.py @@ -7,6 +7,7 @@ from pathlib import Path import xarray as xr +from pint import Quantity from herbie.index.util import dataset_get_data_variable_names, dataset_without_data @@ -98,6 +99,28 @@ class QueryParameter: lon: t.Optional[float] = None +@dataclasses.dataclass +class Point: + """ + Manage geopoint information. + """ + + longitude: float + latitude: float + + +@dataclasses.dataclass +class Circle: + """ + Manage geolocation circle information. + + Radius in kilometers. + """ + + point: Point + distance: Quantity + + @dataclasses.dataclass class BBox: """ diff --git a/herbie/index/util.py b/herbie/index/util.py index b5d97103..1d5f39ac 100644 --- a/herbie/index/util.py +++ b/herbie/index/util.py @@ -7,8 +7,11 @@ import typing as t import numpy as np +import pint import xarray as xr +unit = pint.UnitRegistry() + def round_clipped(value, clipping): """ diff --git a/tests/test_index_era5.py b/tests/test_index_era5.py index a8ae57ae..c4d427de 100644 --- a/tests/test_index_era5.py +++ b/tests/test_index_era5.py @@ -12,7 +12,8 @@ from herbie.index.core import NwpIndex from herbie.index.loader import open_era5_zarr -from herbie.index.model import BBox +from herbie.index.model import BBox, Circle, Point +from herbie.index.util import unit TEMP2M = "air_temperature_at_2_metres" @@ -253,7 +254,7 @@ def test_query_era5_point_timerange_numpy(era5_temp2m): def test_query_era5_bbox_timerange(era5_temp2m): """ - Query indexed ERA5 NWP data within a given area and time range. + Query indexed ERA5 NWP data within a given bounding box area and time range. This variant uses a pandas `DatetimeIndex` for defining the time range boundaries, and a `BBox` instance for defining a geospatial bounding box. @@ -313,3 +314,73 @@ def test_query_era5_bbox_timerange(era5_temp2m): ), ) assert_equal(ds[data_var], reference) + + +def test_query_era5_circle_timerange(era5_temp2m): + """ + Query indexed ERA5 NWP data within a given circular area and time range. + + This variant uses a pandas `DatetimeIndex` for defining the time range + boundaries, and a `Circle` instance for defining a geospatial bounding box. + """ + + data_var = "air_temperature_at_2_metres" + + # Temperatures in Monterey area, in Fahrenheit. + ds = ( + era5_temp2m.query( + time=pd.date_range( + start="1987-10-01 08:00", end="1987-10-01 09:00", freq="H" + ), + location=Circle( + Point(longitude=-121.8674, latitude=36.6083), distance=3.5 * unit.miles + ), + ) + .kelvin_to_fahrenheit() + .ds + ) + assert len(ds) == 1 + assert ds[data_var].shape == (2, 3, 3) + assert ds[data_var].dims == ("time", "lat", "lon") + + # Verify values and coordinates. 
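+    # (The 3.5 mile circle around Monterey is widened to its minimum bounding
+    # rectangle internally, covering a 3 x 3 cell neighborhood over two hourly steps.)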
+ reference = xr.DataArray( + dims=("time", "lat", "lon"), + data=np.array( + [ + [ + [71.89251, 70.88001, 69.64251], + [75.717514, 73.80501, 72.11751], + [78.530014, 77.29251, 76.280014], + ], + [ + [72.23001, 71.217514, 69.98001], + [76.61751, 74.48001, 72.68001], + [79.76751, 78.41751, 77.18001], + ], + ], + dtype=np.float32, + ), + coords=dict( + time=xr.DataArray( + data=np.arange( + start=np.datetime64("1987-10-01 08:00:00"), + stop=np.datetime64("1987-10-01 09:00:00.001"), + step=datetime.timedelta(hours=1), + ), + name="time", + dims=("time",), + ), + lat=xr.DataArray( + data=np.array([36.75, 36.5, 36.25], dtype=np.float32), + name="lat", + dims=("lat",), + ), + lon=xr.DataArray( + data=np.array([-122.0, -121.75, -121.5], dtype=np.float32), + name="lon", + dims=("lon",), + ), + ), + ) + assert_equal(ds[data_var], reference)
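
Taken together, patches 7-9 let a single `NwpIndex.query()` call combine a pandas `DatetimeIndex` time range with either a `BBox` or a `Circle` location. A minimal usage sketch in Python, assuming an index for the `air_temperature_at_2_metres` variable has already been built and loaded (for example like the `era5_temp2m` test fixture above); the `demo` function name and the `print` calls are illustrative and not part of the patches:

import pandas as pd

from herbie.index.core import NwpIndex
from herbie.index.model import BBox, Circle, Point
from herbie.index.util import unit


def demo(nwp: NwpIndex) -> None:
    # `nwp` is a previously loaded index, e.g. the `era5_temp2m` test fixture.

    # Hourly time range, expressed as a pandas DatetimeIndex (patch 7).
    timerange = pd.date_range(start="1987-10-01 08:00", end="1987-10-01 09:00", freq="H")

    # Query by bounding box (patch 8), converting the temperatures to Celsius.
    berlin = nwp.query(
        time=timerange,
        location=BBox(lon1=13.0, lat1=52.3, lon2=13.6, lat2=52.7),
    ).kelvin_to_celsius()

    # Query by circular area around a point (patch 9); the circle is reduced to
    # its minimum bounding rectangle internally.
    monterey = nwp.query(
        time=timerange,
        location=Circle(
            point=Point(longitude=-121.8674, latitude=36.6083),
            distance=3.5 * unit.miles,
        ),
    ).kelvin_to_fahrenheit()

    # Both results expose the selected data as an xarray Dataset.
    print(berlin.ds)
    print(monterey.ds)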