-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Prototype of bounded dataarray functionality #737
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,7 +26,7 @@ | |
get_dim_coords, | ||
get_dim_keys, | ||
) | ||
from xcdat.dataset import _get_data_var | ||
from xcdat.dataset import _get_data_var, boundedDataArray_to_dataset | ||
from xcdat.utils import _if_multidim_dask_array_then_load | ||
|
||
#: Type alias for a dictionary of axis keys mapped to their bounds. | ||
|
@@ -173,7 +173,7 @@ def average( | |
Using custom weights for averaging: | ||
|
||
>>> # The shape of the weights must align with the data var. | ||
>>> self.weights = xr.DataArray( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This documentation update is more generic than this PR and should probably be implemented separately... |
||
>>> weights = xr.DataArray( | ||
>>> data=np.ones((4, 4)), | ||
>>> coords={"lat": self.ds.lat, "lon": self.ds.lon}, | ||
>>> dims=["lat", "lon"], | ||
|
@@ -742,3 +742,187 @@ def _averager( | |
weighted_mean = data_var.cf.weighted(weights).mean(dim=dim) | ||
|
||
return weighted_mean | ||
|
||
|
||
# %% dataset accessors | ||
@xr.register_dataarray_accessor("spatial") | ||
class SpatialAccessorDa: | ||
def __init__(self, dataarray: xr.DataArray): | ||
self._dataarray: xr.DataArray = dataarray | ||
|
||
def average( | ||
self, | ||
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...] = ("X", "Y"), | ||
weights: Union[Literal["generate"], xr.DataArray] = "generate", | ||
keep_weights: bool = False, | ||
lat_bounds: Optional[RegionAxisBounds] = None, | ||
lon_bounds: Optional[RegionAxisBounds] = None, | ||
) -> xr.DataArray: | ||
""" | ||
Calculates the spatial average for a rectilinear grid over an optionally | ||
specified regional domain. | ||
|
||
Operations include: | ||
|
||
- If a regional boundary is specified, check to ensure it is within the | ||
data variable's domain boundary. | ||
- If axis weights are not provided, get axis weights for standard axis | ||
domains specified in ``axis``. | ||
- Adjust weights to conform to the specified regional boundary. | ||
- Compute spatial weighted average. | ||
|
||
This method requires that the dataarray's coordinates have the 'axis' | ||
attribute set to the keys in ``axis``. For example, the latitude | ||
coordinates should have its 'axis' attribute set to 'Y' (which is also | ||
CF-compliant). This 'axis' attribute is used to retrieve the related | ||
coordinates via `cf_xarray`. Refer to this method's examples for more | ||
information. | ||
|
||
Parameters | ||
---------- | ||
axis : List[SpatialAxis] | ||
List of axis dimensions to average over, by default ("X", "Y"). | ||
Valid axis keys include "X" and "Y". | ||
weights : {"generate", xr.DataArray}, optional | ||
If "generate", then weights are generated. Otherwise, pass a | ||
DataArray containing the regional weights used for weighted | ||
averaging. ``weights`` must include the same spatial axis dimensions | ||
and have the same dimensional sizes as the data variable, by default | ||
"generate". | ||
keep_weights : bool, optional | ||
If calculating averages using weights, keep the weights in the | ||
final dataset output, by default False. | ||
lat_bounds : Optional[RegionAxisBounds], optional | ||
A tuple of floats/ints for the regional latitude lower and upper | ||
boundaries. This arg is used when calculating axis weights, but is | ||
ignored if ``weights`` are supplied. The lower bound cannot be | ||
larger than the upper bound, by default None. | ||
lon_bounds : Optional[RegionAxisBounds], optional | ||
A tuple of floats/ints for the regional longitude lower and upper | ||
boundaries. This arg is used when calculating axis weights, but is | ||
ignored if ``weights`` are supplied. The lower bound can be larger | ||
than the upper bound (e.g., across the prime meridian, dateline), by | ||
default None. | ||
|
||
Returns | ||
------- | ||
xr.DataArray | ||
Dataset with the spatially averaged variable. | ||
|
||
Examples | ||
-------- | ||
|
||
Check the 'axis' attribute is set on the required coordinates: | ||
|
||
>>> da.lat.attrs["axis"] | ||
>>> Y | ||
>>> | ||
>>> da.lon.attrs["axis"] | ||
>>> X | ||
|
||
Set the 'axis' attribute for the required coordinates if it isn't: | ||
|
||
>>> da.lat.attrs["axis"] = "Y" | ||
>>> da.lon.attrs["axis"] = "X" | ||
|
||
Call spatial averaging method: | ||
|
||
>>> da.spatial.average(...) | ||
|
||
Get global average time series: | ||
|
||
>>> ts_global = da.spatial.average(axis=["X", "Y"])["tas"] | ||
|
||
Get time series in Nino 3.4 domain: | ||
|
||
>>> ts_n34 = da.spatial.average(axis=["X", "Y"], | ||
>>> lat_bounds=(-5, 5), | ||
>>> lon_bounds=(-170, -120))["ts"] | ||
|
||
Get zonal mean time series: | ||
|
||
>>> ts_zonal = da.spatial.average(axis=["X"])["tas"] | ||
|
||
Using custom weights for averaging: | ||
|
||
>>> # The shape of the weights must align with the data var. | ||
>>> weights = xr.DataArray( | ||
>>> data=np.ones((4, 4)), | ||
>>> coords={"lat": self.ds.lat, "lon": self.ds.lon}, | ||
>>> dims=["lat", "lon"], | ||
>>> ) | ||
>>> | ||
>>> ts_global = ds.spatial.average("tas", axis=["X", "Y"], | ||
>>> weights=weights)["tas"] | ||
""" | ||
# convert dataarray to a dataset | ||
da = self._dataarray.copy() | ||
ds = boundedDataArray_to_dataset(da) | ||
# get data_var key | ||
data_var = da.name | ||
# pass on call to spatial averager | ||
ds_sa = ds.spatial.average( | ||
data_var, | ||
axis=axis, | ||
weights=weights, | ||
keep_weights=keep_weights, | ||
lat_bounds=lat_bounds, | ||
lon_bounds=lon_bounds, | ||
Comment on lines +859 to +870
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is the template: take your "bounded dataarray", convert it back to a dataset, and then call the existing xcdat API (passing along all the underlying arguments). A shortfall, beyond replicating most (but not all) documentation, is that if we change the dataset API, we need to make sure to update it here, too. Maybe there is a fancy way to pass in |
||
) | ||
return ds_sa[data_var] | ||
|
||
def get_weights( | ||
self, | ||
axis: List[SpatialAxis] | Tuple[SpatialAxis, ...], | ||
lat_bounds: Optional[RegionAxisBounds] = None, | ||
lon_bounds: Optional[RegionAxisBounds] = None, | ||
data_var: Optional[str] = None, | ||
) -> xr.DataArray: | ||
""" | ||
Get area weights for specified axis keys and an optional target domain. | ||
|
||
This method first determines the weights for an individual axis based on | ||
the difference between the upper and lower bound. For latitude the | ||
weight is determined by the difference of sine(latitude). All axis | ||
weights are then combined to form a DataArray of weights that can be | ||
used to perform a weighted (spatial) average. | ||
|
||
If ``lat_bounds`` or ``lon_bounds`` are supplied, then grid cells | ||
outside this selected regional domain are given zero weight. Grid cells | ||
that are partially in this domain are given partial weight. | ||
|
||
Parameters | ||
---------- | ||
axis : List[SpatialAxis] | Tuple[SpatialAxis, ...] | ||
List of axis dimensions to average over. | ||
lat_bounds : Optional[RegionAxisBounds] | ||
Tuple of latitude boundaries for regional selection, by default | ||
None. | ||
lon_bounds : Optional[RegionAxisBounds] | ||
Tuple of longitude boundaries for regional selection, by default | ||
None. | ||
|
||
Returns | ||
------- | ||
xr.DataArray | ||
A DataArray containing the region weights to use during averaging. | ||
``weights`` are 1-D and correspond to the specified axes (``axis``) | ||
in the region. | ||
|
||
Notes | ||
----- | ||
This method was developed for rectilinear grids only. ``get_weights()`` | ||
recognizes and operate on latitude and longitude, but could be extended | ||
to work with other standard geophysical dimensions (e.g., time, depth, | ||
and pressure). | ||
""" | ||
# convert dataarray to a dataset | ||
da = self._dataarray.copy() | ||
ds = boundedDataArray_to_dataset(da) | ||
# get data_var key | ||
data_var = da.name | ||
# pass on call to get_weights | ||
weights = ds.spatial.get_weights( | ||
axis=axis, lat_bounds=lat_bounds, lon_bounds=lon_bounds, data_var=data_var | ||
) | ||
return weights |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm pasting the `lat_bnds` coordinate / dataarray that is embedded in the "bounded dataarray" – I'm wondering if this could take on a nicer form than this random dtype: `('lower', '<f8'), ('upper', '<f8')`. I'm guessing if we had other bound types, we could make this more generic, e.g., `('lower_left', '<f8'), ('lower_right', '<f8'), ('upper_right', '<f8'), ('upper_left', '<f8')`.
da.lat_bnds