diff --git a/xesmf/backend.py b/xesmf/backend.py index 2fa8d96..df152e8 100644 --- a/xesmf/backend.py +++ b/xesmf/backend.py @@ -52,7 +52,7 @@ def warn_lat_range(lat): warnings.warn("Latitude is outside of [-90, 90]") -def esmf_grid(lon, lat, periodic=False): +def esmf_grid(lon, lat, periodic=False, mask=None): ''' Create an ESMF.Grid object, for contrusting ESMF.Field and ESMF.Regrid @@ -70,6 +70,10 @@ def esmf_grid(lon, lat, periodic=False): Periodic in longitude? Default to False. Only useful for source grid. + mask : 2D numpy array, optional + Grid mask. Follows SCRIP convention where 1 is unmasked and 0 is + masked. + Returns ------- grid : ESMF.Grid object @@ -111,6 +115,20 @@ def esmf_grid(lon, lat, periodic=False): lon_pointer[...] = lon lat_pointer[...] = lat + # Follows SCRIP convention where 1 is unmasked and 0 is masked. + # See https://github.com/NCPP/ocgis/blob/61d88c60e9070215f28c1317221c2e074f8fb145/src/ocgis/regrid/base.py#L391-L404 + if mask is not None: + grid_mask = np.swapaxes(mask.astype(np.int32), 0, 1) + grid_mask = np.where(grid_mask == 0, 0, 1) + if not (grid_mask.shape == lon.shape): + raise ValueError( + "mask must have the same shape as the latitude/longitude" + "coordinates, got: mask.shape = %s, lon.shape = %s" % + (mask.shape, lon.shape)) + grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, + from_file=False) + grid.mask[0][:] = grid_mask + return grid @@ -240,7 +258,8 @@ def esmf_regrid_build(sourcegrid, destgrid, method, # if the destination grid is larger than the source grid. regrid = ESMF.Regrid(sourcefield, destfield, filename=filename, regrid_method=esmf_regrid_method, - unmapped_action=ESMF.UnmappedAction.IGNORE) + unmapped_action=ESMF.UnmappedAction.IGNORE, + src_mask_values=[0], dst_mask_values=[0]) return regrid diff --git a/xesmf/frontend.py b/xesmf/frontend.py index e2ced7d..2b4565e 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -54,8 +54,14 @@ def ds_to_ESMFgrid(ds, need_bounds=False, periodic=None, append=None): lat = np.asarray(ds['lat']) lon, lat = as_2d_mesh(lon, lat) + if 'mask' in ds: + mask = np.asarray(ds['mask']) + print(mask.shape) + else: + mask = None + # tranpose the arrays so they become Fortran-ordered - grid = esmf_grid(lon.T, lat.T, periodic=periodic) + grid = esmf_grid(lon.T, lat.T, periodic=periodic, mask=mask) if need_bounds: lon_b = np.asarray(ds['lon_b']) @@ -83,6 +89,9 @@ def __init__(self, ds_in, ds_out, method, periodic=False, or 2D (Ny, Nx) for general curvilinear grids. Shape of bounds should be (N+1,) or (Ny+1, Nx+1). + If either dataset includes a 2d mask variable, that will also be + used to inform the regridding. + method : str Regridding method. Options are diff --git a/xesmf/tests/test_frontend.py b/xesmf/tests/test_frontend.py index 5b9d151..1609ebc 100644 --- a/xesmf/tests/test_frontend.py +++ b/xesmf/tests/test_frontend.py @@ -170,3 +170,20 @@ def test_regrid_with_1d_grid(): # clean-up regridder.clean_weight_file() + + +def test_build_regridder_with_masks(): + ds_in['mask'] = xr.DataArray( + np.random.randint(2, size=ds_in['data'].shape), + dims=('y', 'x')) + print(ds_in) + # 'patch' is too slow to test + for method in ['bilinear', 'conservative', 'nearest_s2d', 'nearest_d2s']: + regridder = xe.Regridder(ds_in, ds_out, method) + + # check screen output + assert repr(regridder) == str(regridder) + assert 'xESMF Regridder' in str(regridder) + assert method in str(regridder) + + regridder.clean_weight_file()