Skip to content

Commit

Permalink
Infer 1D bounds for nD variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Jun 3, 2022
1 parent f24c1c3 commit a63851d
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,20 +455,28 @@ def wrapper(obj: DataArray | Dataset, key: str):
}


def _guess_bounds_dim(da):
def _guess_bounds_dim(da, dim=None):
"""
Guess bounds values given a 1D coordinate variable.
Assumes equal spacing on either side of the coordinate label.
"""
assert da.ndim == 1
if dim is None:
if da.ndim != 1:
raise ValueError(
f"If dim is None, variable {da.name} must be 1D. Received {da.ndim} dimensions instead."
)
(dim,) = da.dims
if dim not in da.dims:
(dim,) = da.cf.axes[dim]

dim = da.dims[0]
diff = da.diff(dim)
lower = da - diff / 2
upper = da + diff / 2
bounds = xr.concat([lower, upper], dim="bounds")

first = (bounds.isel({dim: 0}) - diff[0]).assign_coords({dim: da[dim][0]})
first = (bounds.isel({dim: 0}) - diff.isel({dim: 0})).assign_coords(
{dim: da[dim][0]}
)
result = xr.concat([first, bounds], dim=dim)

return result
Expand Down Expand Up @@ -2097,7 +2105,7 @@ def get_bounds_dim_name(self, key: str) -> str:
assert self._obj.sizes[bounds_dim] in [2, 4]
return bounds_dim

def add_bounds(self, keys: str | Iterable[str]):
def add_bounds(self, keys: str | Iterable[str], *, dim=None):
"""
Returns a new object with bounds variables. The bounds values are guessed assuming
equal spacing on either side of a coordinate label.
Expand All @@ -2106,6 +2114,9 @@ def add_bounds(self, keys: str | Iterable[str]):
----------
keys : str or Iterable[str]
Either a single variable name or a list of variable names.
dim : str, optional
Core dimension along whch to estimate bounds. If None, ``keys``
must refer to 1D variables only.
Returns
-------
Expand Down Expand Up @@ -2151,7 +2162,9 @@ def add_bounds(self, keys: str | Iterable[str]):
bname = f"{var}_bounds"
if bname in obj.variables:
raise ValueError(f"Bounds variable name {bname!r} will conflict!")
obj.coords[bname] = _guess_bounds_dim(obj[var].reset_coords(drop=True))
obj.coords[bname] = _guess_bounds_dim(
obj[var].reset_coords(drop=True), dim=dim
)
obj[var].attrs["bounds"] = bname

return self._maybe_to_dataarray(obj)
Expand Down

0 comments on commit a63851d

Please sign in to comment.