Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update regridder.grid property to ensure cf compliant coordinates #736

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,21 @@ def test_grid(self):
):
ds_multi.regridder.grid # noqa: B018

def test_grid_adds_cf_attributes_to_non_cf_compliant_coords(self):
ds = fixtures.generate_dataset(
decode_times=True, cf_compliant=False, has_bounds=True
)
# Remove "axis" and "standard_name" attributes.
ds["lat"].attrs = {"units": "degrees_east"}
ds["lon"].attrs = {"units": "degrees_north"}

grid = ds.regridder.grid

assert grid["lat"].attrs["axis"] == "Y"
assert grid["lat"].attrs["standard_name"] == "latitude"
assert grid["lon"].attrs["axis"] == "X"
assert grid["lon"].attrs["standard_name"] == "longitude"

def test_grid_raises_error_when_dataset_has_multiple_dims_for_an_axis(self):
ds_bounds = fixtures.generate_dataset(
decode_times=True, cf_compliant=True, has_bounds=True
Expand Down
36 changes: 35 additions & 1 deletion xcdat/regridder/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import xarray as xr

from xcdat.axis import CFAxisKey, get_dim_coords
from xcdat.axis import CF_ATTR_MAP, CFAxisKey, get_dim_coords
from xcdat.regridder import regrid2, xesmf, xgcm
from xcdat.regridder.grid import _validate_grid_has_single_axis_dim

Expand Down Expand Up @@ -109,13 +109,47 @@ def _get_axis_data(

_validate_grid_has_single_axis_dim(name, coord_var)

coord_var = self._ensure_cf_compliance(coord_var, name) # type: ignore

try:
bounds_var = self._ds.bounds.get_bounds(name, coord_var.name)
except KeyError:
bounds_var = None

return coord_var, bounds_var

def _ensure_cf_compliance(
self, coord_var: xr.DataArray, name: CFAxisKey
) -> xr.DataArray:
"""Ensure that the coordinate variable is CF-compliant.

This function adds the "axis" and "standard_name" attributes to the
coordinates if they are not already present. Coordinates must be
CF-compliant in order for xESMF to interpret them using CF-xarray.

Parameters
----------
coords : xr.DataArray
Coordinates to make CF compliant.
name : CFAxisKey
Name of the axis.

Returns
-------
xr.DataArray
CF compliant coordinates.
"""
coord_var_new = coord_var.copy()
cf_attrs = CF_ATTR_MAP[name]

if "axis" not in coord_var_new.attrs:
coord_var_new.attrs["axis"] = cf_attrs["axis"]

if "standard_name" not in coord_var_new.attrs:
coord_var_new.attrs["standard_name"] = cf_attrs["coordinate"]

return coord_var_new

def horizontal(
self,
data_var: str,
Expand Down
Loading