Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ESMpy mask handling #23

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def warn_lat_range(lat):
warnings.warn("Latitude is outside of [-90, 90]")


def esmf_grid(lon, lat, periodic=False):
def esmf_grid(lon, lat, periodic=False, mask=None):
'''
Create an ESMF.Grid object, for contrusting ESMF.Field and ESMF.Regrid

Expand All @@ -70,6 +70,10 @@ def esmf_grid(lon, lat, periodic=False):
Periodic in longitude? Default to False.
Only useful for source grid.

mask : 2D numpy array, optional
Grid mask. Follows SCRIP convention where 1 is unmasked and 0 is
masked.

Returns
-------
grid : ESMF.Grid object
Expand Down Expand Up @@ -111,6 +115,20 @@ def esmf_grid(lon, lat, periodic=False):
lon_pointer[...] = lon
lat_pointer[...] = lat

# Follows SCRIP convention where 1 is unmasked and 0 is masked.
# See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404
if mask is not None:
grid_mask = np.swapaxes(mask.astype(np.int32), 0, 1)
grid_mask = np.where(grid_mask == 0, 0, 1)
if not (grid_mask.shape == lon.shape):
raise ValueError(
"mask must have the same shape as the latitude/longitude"
"coordinates, got: mask.shape = %s, lon.shape = %s" %
(mask.shape, lon.shape))
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER,
from_file=False)
grid.mask[0][:] = grid_mask

return grid


Expand Down Expand Up @@ -240,7 +258,8 @@ def esmf_regrid_build(sourcegrid, destgrid, method,
# if the destination grid is larger than the source grid.
regrid = ESMF.Regrid(sourcefield, destfield, filename=filename,
regrid_method=esmf_regrid_method,
unmapped_action=ESMF.UnmappedAction.IGNORE)
unmapped_action=ESMF.UnmappedAction.IGNORE,
src_mask_values=[0], dst_mask_values=[0])

return regrid

Expand Down
11 changes: 10 additions & 1 deletion xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,14 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None):
lat = np.asarray(ds['lat'])
lon, lat = as_2d_mesh(lon, lat)

if 'mask' in ds:
mask = np.asarray(ds['mask'])
print(mask.shape)
else:
mask = None

# tranpose the arrays so they become Fortran-ordered
grid = esmf_grid(lon.T, lat.T, periodic=periodic)
grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=mask)

if need_bounds:
lon_b = np.asarray(ds['lon_b'])
Expand Down Expand Up @@ -83,6 +89,9 @@ def __init__(self, ds_in, ds_out, method, periodic=False,
or 2D (Ny, Nx) for general curvilinear grids.
Shape of bounds should be (N+1,) or (Ny+1, Nx+1).

If either dataset includes a 2d mask variable, that will also be
used to inform the regridding.

method : str
Regridding method. Options are

Expand Down
17 changes: 17 additions & 0 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,20 @@ def test_regrid_with_1d_grid():

# clean-up
regridder.clean_weight_file()


def test_build_regridder_with_masks():
ds_in['mask'] = xr.DataArray(
np.random.randint(2, size=ds_in['data'].shape),
dims=('y', 'x'))
print(ds_in)
# 'patch' is too slow to test
for method in ['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s']:
regridder = xe.Regridder(ds_in, ds_out, method)

# check screen output
assert repr(regridder) == str(regridder)
assert 'xESMF Regridder' in str(regridder)
assert method in str(regridder)

regridder.clean_weight_file()