Skip to content

Commit

Permalink
Correctly shaped bounds for add_bounds method (#347)
Browse files Browse the repository at this point in the history
* Correctly shaped bounds for add_bounds method

* Update cf_xarray/accessor.py

saver transposing

Co-authored-by: Deepak Cherian <[email protected]>

* test_accessor: transpose expected added bounds

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* test_helpers: transpose expected added bounds

* fix helpers for bounds as last dim

* fix 2D bounds with bounds as last dim

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Pascal Bourgault <[email protected]>
  • Loading branch information
4 people authored Nov 15, 2022
1 parent 205e673 commit 1298277
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def _guess_bounds_dim(da, dim=None, out_dim="bounds"):
daXY.isel(Xbnds=1, Ybnds=0),
],
out_dim,
)
).transpose(..., "bounds")
else:
dim = dim[0]
if dim not in da.dims:
Expand All @@ -507,7 +507,7 @@ def _guess_bounds_dim(da, dim=None, out_dim="bounds"):
first = (bounds.isel({dim: 0}) - diff.isel({dim: 0})).assign_coords(
{dim: da[dim][0]}
)
result = xr.concat([first, bounds], dim=dim)
result = xr.concat([first, bounds], dim=dim).transpose(..., "bounds")

return result

Expand Down
20 changes: 10 additions & 10 deletions cf_xarray/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@

def _create_mollw_ds():
# Dataset with random data on a grid that is some sort of Mollweide projection
XX, YY = np.mgrid[:11, :11] * 5 - 25
XX_bnds, YY_bnds = np.mgrid[:12, :12] * 5 - 27.5
YY, XX = np.mgrid[:11, :11] * 5 - 25
YY_bnds, XX_bnds = np.mgrid[:12, :12] * 5 - 27.5

R = 50
theta = np.arcsin(YY / (R * np.sqrt(2)))
Expand All @@ -179,7 +179,7 @@ def _create_mollw_ds():
lon_vertices[1:, 1:],
lon_vertices[1:, :-1],
),
axis=0,
axis=-1,
)
lat_bounds = np.stack(
(
Expand All @@ -188,31 +188,31 @@ def _create_mollw_ds():
lat_vertices[1:, 1:],
lat_vertices[1:, :-1],
),
axis=0,
axis=-1,
)

mollwds = xr.Dataset(
coords=dict(
lon=xr.DataArray(
lon,
dims=("x", "y"),
dims=("y", "x"),
attrs={"units": "degrees_east", "bounds": "lon_bounds"},
),
lat=xr.DataArray(
lat,
dims=("x", "y"),
dims=("y", "x"),
attrs={"units": "degrees_north", "bounds": "lat_bounds"},
),
),
data_vars=dict(
lon_bounds=xr.DataArray(
lon_bounds, dims=("bounds", "x", "y"), attrs={"units": "degrees_east"}
lon_bounds, dims=("y", "x", "bounds"), attrs={"units": "degrees_east"}
),
lat_bounds=xr.DataArray(
lat_bounds, dims=("bounds", "x", "y"), attrs={"units": "degrees_north"}
lat_bounds, dims=("y", "x", "bounds"), attrs={"units": "degrees_north"}
),
lon_vertices=xr.DataArray(lon_vertices, dims=("x_vertices", "y_vertices")),
lat_vertices=xr.DataArray(lat_vertices, dims=("x_vertices", "y_vertices")),
lon_vertices=xr.DataArray(lon_vertices, dims=("y_vertices", "x_vertices")),
lat_vertices=xr.DataArray(lat_vertices, dims=("y_vertices", "x_vertices")),
),
)

Expand Down
6 changes: 4 additions & 2 deletions cf_xarray/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def _bounds_helper(values, n_core_dims, nbounds, order):
top_left = values[..., -1:, :, 3]
vertex_vals = np.block([[bot_left, bot_right], [top_left, top_right]])
if order is None: # We verify if the ccw version works.
calc_bnds = np.moveaxis(vertices_to_bounds(vertex_vals).values, 0, -1)
calc_bnds = vertices_to_bounds(vertex_vals).values
order = (
"counterclockwise" if np.allclose(calc_bnds, values) else "clockwise"
)
Expand Down Expand Up @@ -155,4 +155,6 @@ def vertices_to_bounds(
raise ValueError(
f"vertices format not understood. Got {vertices.dims} with shape {vertices.shape}."
)
return xr.DataArray(bnd_vals, dims=out_dims[: vertices.ndim + 1])
return xr.DataArray(bnd_vals, dims=out_dims[: vertices.ndim + 1]).transpose(
..., out_dims[0]
)
6 changes: 4 additions & 2 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,9 @@ def test_add_bounds(dims):
name = f"{dim}_bounds"
assert name in added.coords
assert added[dim].attrs["bounds"] == name
assert_allclose(added[name].reset_coords(drop=True), expected[dim])
assert_allclose(
added[name].reset_coords(drop=True), expected[dim].transpose(..., "bounds")
)

_check_unchanged(original, ds)

Expand Down Expand Up @@ -824,7 +826,7 @@ def test_add_bounds_nd_variable():
)

actual = ds.cf.add_bounds("z", dim="x").z_bounds.reset_coords(drop=True)
xr.testing.assert_identical(expected, actual)
xr.testing.assert_identical(expected.transpose(..., "bounds"), actual)

with pytest.raises(NotImplementedError):
ds.drop_vars("x").cf.add_bounds("z", dim="x")
Expand Down
12 changes: 6 additions & 6 deletions cf_xarray/tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_bounds_to_vertices():
lat_c = cfxr.bounds_to_vertices(ds.lat_bounds, bounds_dim="bounds")
assert_array_equal(ds.lat.values + 1.25, lat_c.values[:-1])

# 2D case, CF- order
# 2D case
lat_ccw = cfxr.bounds_to_vertices(
mollwds.lat_bounds, bounds_dim="bounds", order="counterclockwise"
)
Expand All @@ -35,13 +35,13 @@ def test_bounds_to_vertices():
assert_equal(lon_no, lon_ccw)

# Transposing the array changes the bounds direction
ds = mollwds.transpose("bounds", "y", "x", "y_vertices", "x_vertices")
lon_c = cfxr.bounds_to_vertices(
ds = mollwds.transpose("x", "y", "x_vertices", "y_vertices", "bounds")
lon_cw = cfxr.bounds_to_vertices(
ds.lon_bounds, bounds_dim="bounds", order="clockwise"
)
lon_c2 = cfxr.bounds_to_vertices(ds.lon_bounds, bounds_dim="bounds", order=None)
assert_equal(ds.lon_vertices, lon_c)
assert_equal(ds.lon_vertices, lon_c2)
lon_no2 = cfxr.bounds_to_vertices(ds.lon_bounds, bounds_dim="bounds", order=None)
assert_equal(ds.lon_vertices, lon_cw)
assert_equal(ds.lon_vertices, lon_no2)

# Preserves dask-backed arrays
if DaskArray is not None:
Expand Down

0 comments on commit 1298277

Please sign in to comment.