Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prototype of bounded dataarray functionality #737

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion xcdat/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from xcdat import bounds as bounds_accessor # noqa: F401
from xcdat._logger import _setup_custom_logger
from xcdat.axis import CFAxisKey, _get_all_coord_keys, swap_lon_axis
from xcdat.axis import CFAxisKey, _get_all_coord_keys, get_dim_keys, swap_lon_axis
from xcdat.axis import center_times as center_times_func

logger = _setup_custom_logger(__name__)
Expand Down Expand Up @@ -746,3 +746,107 @@ def _get_data_var(dataset: xr.Dataset, key: str) -> xr.DataArray:
raise KeyError(f"The data variable '{key}' does not exist in the Dataset.")

return dv.copy()


def get_bounded_dataarray(ds: xr.Dataset, key: str) -> xr.DataArray:
    """
    Convert a dataset to a dataarray with the bounds embedded as coordinates
    (i.e., a bounded DataArray).

    For every axis that ``cf_xarray`` reports on the data variable, the axis
    coordinate is attached and, when bounds can be retrieved, the bounds are
    attached as a 1-D coordinate along the axis dimension. Each element of
    that coordinate is a structured value with ``lower`` and ``upper`` fields
    that share the dtype of the axis coordinate.

    Parameters
    ----------
    ds : xr.Dataset
        The Dataset.
    key : str
        The data variable key.

    Returns
    -------
    xr.DataArray
        The bounded DataArray.

    Raises
    ------
    KeyError
        If the data variable does not exist in the Dataset.

    Notes
    -----
    Axes are processed on a best-effort basis. If anything fails while
    handling an axis, the failure is logged at debug level and the loop moves
    on to the next axis; a coordinate that was already attached for that axis
    is kept.
    """
    da = ds.get(key)
    if da is None:
        raise KeyError(f"The data variable '{key}' does not exist in the Dataset.")

    coords = {}
    for c_key in da.cf.axes:
        try:
            # Dimension key for this axis (e.g., "time", "lat", "lon").
            dim_key = get_dim_keys(ds, c_key)
            dim_coord = ds[dim_key]
            coords[dim_key] = dim_coord

            # Structured dtype so one (lower, upper) pair fits in a single
            # element of a 1-D coordinate along the axis dimension.
            bounds_dtype = np.dtype(
                [("lower", dim_coord.dtype), ("upper", dim_coord.dtype)]
            )

            bnds = ds.bounds.get_bounds(axis=c_key)
            # Structured arrays are built from a list of tuples, one per row.
            bnds_rows = [tuple(row) for row in bnds.values]
            coords[bnds.name] = (dim_key, np.array(bnds_rows, dtype=bounds_dtype))
        except Exception as err:
            # Best effort: skip axes whose bounds cannot be embedded, but do
            # not swallow KeyboardInterrupt/SystemExit and leave a trace.
            logger.debug(
                "Could not embed bounds for axis %r of data variable %r: %s",
                c_key,
                key,
                err,
            )
            continue

    # The rebuilt array is bound to ``da`` for the function's return statement.
    da = xr.DataArray(da, coords=coords, dims=da.dims, attrs=da.attrs)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pasting the lat_bnds coordinate / dataarray that is embedded in the "bounded dataarray" – I'm wondering if this could take on a nicer form than this random dtype: ('lower', '<f8'), ('upper', '<f8'). I'm guessing if we had other bound types, we could make this more generic, e.g., ('lower_left', '<f8'), ('lower_right', '<f8'), ('upper_right', '<f8'), ('upper_left', '<f8').

da.lat_bnds

<xarray.DataArray 'lat_bnds' (lat: 64)> Size: 1kB
array([(-90. , -86.57774751), (-86.57774751, -83.75702878),
(-83.75702878, -80.95502019), (-80.95502019, -78.15834785),
...
( 78.15834785, 80.95502019), ( 80.95502019, 83.75702878),
( 83.75702878, 86.57774751), ( 86.57774751, 90. )],
dtype=[('lower', '<f8'), ('upper', '<f8')])
Coordinates:

  • lat (lat) float64 512B -87.86 -85.1 -82.31 -79.53 ... 82.31 85.1 87.86
    lat_bnds (lat) [('lower', '<f8'), ('upper', '<f8')] 1kB (-90.0, -86.5777...

return da


def boundedDataArray_to_dataset(bda: xr.DataArray) -> xr.Dataset:
    """
    Convert a bounded dataarray to a dataset.

    The dataarray is converted with ``to_dataset()``. Then, for every axis
    that ``cf_xarray`` reports on the result, the bounds coordinate named by
    the axis coordinate's ``bounds`` attribute is replaced by a 2-D bounds
    variable with dimensions ``(<axis dim>, "bnds")``.

    Parameters
    ----------
    bda : xr.DataArray
        The bounded dataarray.

    Returns
    -------
    xr.Dataset
        The dataset.

    Notes
    -----
    Note that the .name attribute must be set in the dataarray.

    Axes are processed on a best-effort basis. If anything fails while
    handling an axis, the failure is logged at debug level and that axis is
    left as it was.
    """
    ds = bda.to_dataset()

    for c_key in ds.cf.axes:
        try:
            # Dimension key for this axis (e.g., "time", "lat", "lon").
            dim_key = get_dim_keys(ds, c_key)
            # Read the attribute explicitly so a coordinate or accessor that
            # happens to be called "bounds" cannot be picked up instead.
            bnds_key = ds[dim_key].attrs["bounds"]
            # Unpack each two-field element into a [lower, upper] row.
            bnds_rows = [[b[0], b[1]] for b in ds[bnds_key].to_numpy()]
            # Build the replacement first so the embedded bounds are only
            # dropped once there is something to put in their place.
            bnds_da = xr.DataArray(
                data=bnds_rows,
                dims=[dim_key, "bnds"],
                coords={dim_key: ds[dim_key]},
            )
            del ds[bnds_key]
            ds[bnds_key] = bnds_da
        except Exception as err:
            # Best effort: skip axes without convertible bounds, but do not
            # swallow KeyboardInterrupt/SystemExit and leave a trace.
            logger.debug(
                "Could not convert embedded bounds for axis %r: %s", c_key, err
            )
            continue

    return ds


# Monkey-patch xr.Dataset at import time so that calling a Dataset, e.g.
# ``ds("tas")``, dispatches to ``get_bounded_dataarray(ds, "tas")``.
xr.Dataset.__call__ = get_bounded_dataarray
188 changes: 186 additions & 2 deletions xcdat/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
get_dim_coords,
get_dim_keys,
)
from xcdat.dataset import _get_data_var
from xcdat.dataset import _get_data_var, boundedDataArray_to_dataset
from xcdat.utils import _if_multidim_dask_array_then_load

#: Type alias for a dictionary of axis keys mapped to their bounds.
Expand Down Expand Up @@ -173,7 +173,7 @@ def average(
Using custom weights for averaging:

>>> # The shape of the weights must align with the data var.
>>> self.weights = xr.DataArray(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This documentation update is more generic than this PR and should probably be implemented separately...

>>> weights = xr.DataArray(
>>> data=np.ones((4, 4)),
>>> coords={"lat": self.ds.lat, "lon": self.ds.lon},
>>> dims=["lat", "lon"],
Expand Down Expand Up @@ -742,3 +742,187 @@ def _averager(
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim)

return weighted_mean


# %% dataarray accessors
@xr.register_dataarray_accessor("spatial")
class SpatialAccessorDa:
# DataArray-level "spatial" accessor (registered under the name "spatial").
# Its methods convert the wrapped bounded DataArray to a Dataset with
# ``boundedDataArray_to_dataset`` and delegate to the Dataset ``spatial``
# accessor.
def __init__(self, dataarray: xr.DataArray):
# Keep a reference to the DataArray this accessor was created for.
self._dataarray: xr.DataArray = dataarray

def average(
self,
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"),
weights: Union[Literal["generate"], xr.DataArray] = "generate",
keep_weights: bool = False,
lat_bounds: Optional[RegionAxisBounds] = None,
lon_bounds: Optional[RegionAxisBounds] = None,
) -> xr.DataArray:
"""
Calculates the spatial average for a rectilinear grid over an optionally
specified regional domain.

This is a thin wrapper: the bounded DataArray is converted to a Dataset
with ``boundedDataArray_to_dataset`` and all arguments are forwarded to the
Dataset ``spatial.average`` method. The operations below are those
documented for that method:

- If a regional boundary is specified, check to ensure it is within the
data variable's domain boundary.
- If axis weights are not provided, get axis weights for standard axis
domains specified in ``axis``.
- Adjust weights to conform to the specified regional boundary.
- Compute spatial weighted average.

This method requires that the dataarray's coordinates have the 'axis'
attribute set to the keys in ``axis``. For example, the latitude
coordinates should have its 'axis' attribute set to 'Y' (which is also
CF-compliant). This 'axis' attribute is used to retrieve the related
coordinates via `cf_xarray`. Refer to this method's examples for more
information.

The dataarray's ``.name`` must be set, because it is used as the data
variable key in the intermediate Dataset.

Parameters
----------
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
List of axis dimensions to average over, by default ("X", "Y").
Valid axis keys include "X" and "Y".
weights : {"generate", xr.DataArray}, optional
If "generate", then weights are generated. Otherwise, pass a
DataArray containing the regional weights used for weighted
averaging. ``weights`` must include the same spatial axis dimensions
and have the same dimensional sizes as the data variable, by default
"generate".
keep_weights : bool, optional
Forwarded to the Dataset method, where it keeps the weights in the
Dataset output, by default False.
NOTE(review): only the averaged data variable is extracted from that
Dataset and returned, so the kept weights may not be reachable from
the returned DataArray -- confirm.
lat_bounds : Optional[RegionAxisBounds], optional
A tuple of floats/ints for the regional latitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound cannot be
larger than the upper bound, by default None.
lon_bounds : Optional[RegionAxisBounds], optional
A tuple of floats/ints for the regional longitude lower and upper
boundaries. This arg is used when calculating axis weights, but is
ignored if ``weights`` are supplied. The lower bound can be larger
than the upper bound (e.g., across the prime meridian, dateline), by
default None.

Returns
-------
xr.DataArray
The spatially averaged data variable.

Examples
--------

Check the 'axis' attribute is set on the required coordinates:

>>> da.lat.attrs["axis"]
>>> Y
>>>
>>> da.lon.attrs["axis"]
>>> X

Set the 'axis' attribute for the required coordinates if it isn't:

>>> da.lat.attrs["axis"] = "Y"
>>> da.lon.attrs["axis"] = "X"

Call spatial averaging method:

>>> da.spatial.average(...)

Get global average time series:

>>> ts_global = da.spatial.average(axis=["X", "Y"])

Get time series in Nino 3.4 domain:

>>> ts_n34 = da.spatial.average(axis=["X", "Y"],
>>> lat_bounds=(-5, 5),
>>> lon_bounds=(-170, -120))

Get zonal mean time series:

>>> ts_zonal = da.spatial.average(axis=["X"])

Using custom weights for averaging:

>>> # The shape of the weights must align with the data var.
>>> weights = xr.DataArray(
>>> data=np.ones((4, 4)),
>>> coords={"lat": da.lat, "lon": da.lon},
>>> dims=["lat", "lon"],
>>> )
>>>
>>> ts_global = da.spatial.average(axis=["X", "Y"],
>>> weights=weights)
"""
# convert the bounded dataarray (a copy of it) to a dataset
da = self._dataarray.copy()
ds = boundedDataArray_to_dataset(da)
# the dataarray's name is the data variable key in the dataset
data_var = da.name
# pass on call to the Dataset spatial averager, forwarding every argument
ds_sa = ds.spatial.average(
data_var,
axis=axis,
weights=weights,
keep_weights=keep_weights,
lat_bounds=lat_bounds,
lon_bounds=lon_bounds,
Comment on lines +859 to +870
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the template: take your "bounded dataarray", convert it back to a dataset, and then call the existing xcdat API (passing along all the underlying arguments).

A shortfall, beyond replicating most (but not all) documentation, is that if we change the dataset API, we need to make sure to update it here, too. Maybe there is a fancy way to pass in **kwargs that can do this automatically (probably not)?

)
return ds_sa[data_var]

def get_weights(
    self,
    axis: List[SpatialAxis] | Tuple[SpatialAxis, ...],
    lat_bounds: Optional[RegionAxisBounds] = None,
    lon_bounds: Optional[RegionAxisBounds] = None,
    data_var: Optional[str] = None,
) -> xr.DataArray:
    """
    Get area weights for specified axis keys and an optional target domain.

    This is a thin wrapper: the bounded DataArray is converted to a Dataset
    with ``boundedDataArray_to_dataset`` and the call is forwarded to the
    Dataset ``spatial.get_weights`` method, which is documented as follows.

    The weights for an individual axis are based on the difference between
    the upper and lower bound. For latitude the weight is determined by the
    difference of sine(latitude). All axis weights are then combined to form
    a DataArray of weights that can be used to perform a weighted (spatial)
    average.

    If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells
    outside this selected regional domain are given zero weight. Grid cells
    that are partially in this domain are given partial weight.

    Parameters
    ----------
    axis : List[SpatialAxis] | Tuple[SpatialAxis, ...]
        List of axis dimensions to average over.
    lat_bounds : Optional[RegionAxisBounds]
        Tuple of latitude boundaries for regional selection, by default
        None.
    lon_bounds : Optional[RegionAxisBounds]
        Tuple of longitude boundaries for regional selection, by default
        None.
    data_var : Optional[str]
        Not used. The DataArray's own ``.name`` is always forwarded as the
        data variable key, whatever is passed here.

    Returns
    -------
    xr.DataArray
        A DataArray containing the region weights to use during averaging.
        ``weights`` are 1-D and correspond to the specified axes (``axis``)
        in the region.

    Notes
    -----
    This method was developed for rectilinear grids only. ``get_weights()``
    recognizes and operate on latitude and longitude, but could be extended
    to work with other standard geophysical dimensions (e.g., time, depth,
    and pressure).
    """
    # Work on a copy of the wrapped array and expand it into a Dataset.
    bounded = self._dataarray.copy()
    as_dataset = boundedDataArray_to_dataset(bounded)

    # Delegate to the Dataset accessor, keyed by the array's own name.
    return as_dataset.spatial.get_weights(
        axis=axis,
        lat_bounds=lat_bounds,
        lon_bounds=lon_bounds,
        data_var=bounded.name,
    )
Loading