Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with xarray>=2022.9.0 #276

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_set_attrs_on_all_vars,
_set_as_coord,
_1d_coord_from_spacing,
_maybe_rename_dimension,
)

REGISTERED_GEOMETRIES = {}
Expand Down Expand Up @@ -386,12 +387,12 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
],
)

if "t" in ds.dims:
if coordinates["t"] != "t":
# Rename 't' if user requested it
ds = ds.rename(t=coordinates["t"])
ds = _maybe_rename_dimension(ds, "t", coordinates["t"])

# Change names of dimensions to Orthogonal Toroidal ones
ds = ds.rename(y=coordinates["y"])
ds = _maybe_rename_dimension(ds, "y", coordinates["y"])

# TODO automatically make this coordinate 1D in simplified cases?
ds = ds.rename(psixy=coordinates["x"])
Expand All @@ -407,7 +408,7 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

# If full data (not just grid file) then toroidal dim will be present
if "z" in ds.dims:
ds = ds.rename(z=coordinates["z"])
ds = _maybe_rename_dimension(ds, "z", coordinates["z"])

# Record which dimension 'z' was renamed to.
ds.metadata["bout_zdim"] = coordinates["z"]
Expand Down Expand Up @@ -482,7 +483,7 @@ def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None):
ds["r"] = ds["hthe"].isel({ycoord: 0}).squeeze(drop=True)
ds["r"].attrs["units"] = "m"
ds = ds.set_coords("r")
ds = ds.rename(x="r")
ds = ds.swap_dims(x="r")
ds.metadata["bout_xdim"] = "r"

if hthe_from_grid:
Expand Down
16 changes: 12 additions & 4 deletions xbout/region.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def __init__(
ref_yind = ylower_ind
dx = ds["dx"].isel({self.ycoord: ref_yind})
dx_cumsum = dx.cumsum()
self.xinner = dx_cumsum[xinner_ind] - dx[xinner_ind]
self.xouter = dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1]
self.xinner = (dx_cumsum[xinner_ind] - dx[xinner_ind]).values
self.xouter = (dx_cumsum[xouter_ind - 1] + dx[xouter_ind - 1]).values

# dy is constant in the x-direction, so convert to a 1d array
# Define ref_xind so that we avoid using values from the corner cells, which
Expand All @@ -136,8 +136,8 @@ def __init__(
ref_xind = xinner_ind
dy = ds["dy"].isel(**{self.xcoord: ref_xind})
dy_cumsum = dy.cumsum()
self.ylower = dy_cumsum[ylower_ind] - dy[ylower_ind]
self.yupper = dy_cumsum[yupper_ind - 1]
self.ylower = (dy_cumsum[ylower_ind] - dy[ylower_ind]).values
self.yupper = (dy_cumsum[yupper_ind - 1]).values

def __repr__(self):
result = "<xbout.region.Region>\n"
Expand Down Expand Up @@ -1355,7 +1355,9 @@ def _concat_inner_guards(da, da_global, mxg):
# https://github.com/pydata/xarray/issues/4393
# da_inner = da_inner.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord})
da_inner[xcoord].data[...] = new_xcoord.data
da_inner = da_inner.reset_index(xcoord).set_xindex(xcoord)
da_inner[ycoord].data[...] = new_ycoord.data
da_inner = da_inner.reset_index(ycoord).set_xindex(ycoord)

save_regions = da.bout._regions
da = xr.concat((da_inner, da), xcoord, join="exact")
Expand Down Expand Up @@ -1466,7 +1468,9 @@ def _concat_outer_guards(da, da_global, mxg):
# https://github.com/pydata/xarray/issues/4393
# da_outer = da_outer.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord})
da_outer[xcoord].data[...] = new_xcoord.data
da_outer = da_outer.reset_index(xcoord).set_xindex(xcoord)
da_outer[ycoord].data[...] = new_ycoord.data
da_outer = da_outer.reset_index(ycoord).set_xindex(ycoord)

save_regions = da.bout._regions
da = xr.concat((da, da_outer), xcoord, join="exact")
Expand Down Expand Up @@ -1566,7 +1570,9 @@ def _concat_lower_guards(da, da_global, mxg, myg):
# https://github.com/pydata/xarray/issues/4393
# da_lower = da_lower.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord})
da_lower[xcoord].data[...] = new_xcoord.data
da_lower = da_lower.reset_index(xcoord).set_xindex(xcoord)
da_lower[ycoord].data[...] = new_ycoord.data
da_lower = da_lower.reset_index(ycoord).set_xindex(ycoord)

if "poloidal_distance" in da.coords and myg > 0:
# Special handling for core regions to deal with branch cut
Expand Down Expand Up @@ -1682,7 +1688,9 @@ def _concat_upper_guards(da, da_global, mxg, myg):
# https://github.com/pydata/xarray/issues/4393
# da_upper = da_upper.assign_coords(**{xcoord: new_xcoord, ycoord: new_ycoord})
da_upper[xcoord].data[...] = new_xcoord.data
da_upper = da_upper.reset_index(xcoord).set_xindex(xcoord)
da_upper[ycoord].data[...] = new_ycoord.data
da_upper = da_upper.reset_index(ycoord).set_xindex(ycoord)

if "poloidal_distance" in da.coords and myg > 0:
# Special handling for core regions to deal with branch cut
Expand Down
11 changes: 11 additions & 0 deletions xbout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,3 +880,14 @@ def _set_as_coord(ds, name):
except ValueError:
pass
return ds


def _maybe_rename_dimension(ds, old_name, new_name):
if old_name in ds.dims and new_name != old_name:
# Rename dimension
ds = ds.swap_dims({old_name: new_name})
if old_name in ds:
# Rename coordinate if it exists
ds = ds.rename({old_name: new_name})

return ds