Regridding xarray dataset with chunked dask-backed arrays #222
Comments
Sadly, not for now.
If anyone has some ideas to solve this, that would be a great contribution.
See https://discourse.pangeo.io/t/conservative-region-aggregation-with-xarray-geopandas-and-sparse/2715 for a possible solution.
We're hoping to have an intern work on this next summer. If anyone has tips to share, please leave them here.
Here's how to do it.

Read (or convert) weights as pydata/sparse:

def read_xesmf_weights_file(filename):
    import numpy as np
    import sparse
    import xarray as xr

    weights = xr.open_dataset(filename)

    # input variable shape
    in_shape = weights.src_grid_dims.load().data
    # output variable shape
    out_shape = weights.dst_grid_dims.load().data.tolist()[::-1]
    print(f"Regridding from {in_shape} to {out_shape}")

    rows = weights["row"] - 1  # row indices (1-based in the file, converted to 0-based)
    cols = weights["col"] - 1  # col indices (1-based in the file, converted to 0-based)

    # Construct a sparse array and reshape to 3D: (lat, lon, ncol).
    # This reshaping should allow optional chunking along lat, lon later.
    sparse_array_data = sparse.COO(
        coords=np.stack([rows.data, cols.data]),
        data=weights.S.data,
        shape=(weights.sizes["n_b"], weights.sizes["n_a"]),
        fill_value=0,
    ).reshape((*out_shape, -1))

    # Create a DataArray with sparse weights and the output coordinates
    xsparse_wgts = xr.DataArray(
        sparse_array_data,
        dims=("lat", "lon", "ncol"),
        # Add useful coordinate information; this will get propagated to the output
        coords={
            "lat": ("lat", weights.yc_b.data.reshape(out_shape)[:, 0]),
            "lon": ("lon", weights.xc_b.data.reshape(out_shape)[0, :]),
        },
        # propagate useful information like the regridding algorithm
        attrs=weights.attrs,
    )
    return xsparse_wgts
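A quick sanity check on the weights returned by this function can be useful before applying them; this is just a sketch, and the file name is a made-up placeholder:

wgts = read_xesmf_weights_file("map_source_to_target.nc")  # hypothetical weights file
print(wgts.dims, wgts.shape)  # ("lat", "lon", "ncol") and the output-grid shape
print(wgts.data.density)      # sparse.COO exposes the fraction of nonzero weights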
xsparse_wgts = read_xesmf_weights_file(map_path + map_file)

Apply weights using opt_einsum (https://dgasmith.github.io/opt_einsum/):

import numpy as np
import opt_einsum
import xarray as xr

def apply_weights(dataset, weights):
    def _apply(da):
        # 🐵 🔧 monkey-patch: swap xarray's einsum for the optimized opt_einsum version
        xr.core.duck_array_ops.einsum = opt_einsum.contract
        ans = xr.dot(
            da,
            weights,
            # This dimension will be "contracted",
            # i.e. summed over after multiplying by the weights
            dims="ncol",
        )
        # 🐵 🔧 restore back to the original implementation
        xr.core.duck_array_ops.einsum = np.einsum
        return ans

    vars_with_ncol = [
        name for name, array in dataset.variables.items()
        if "ncol" in array.dims and name not in weights.coords
    ]
    regridded = dataset[vars_with_ncol].map(_apply)

    # merge in other variables, but skip those that are already set, like lat, lon
    return xr.merge([dataset.drop_vars(regridded.variables), regridded])

apply_weights(psfile, xsparse_wgts.chunk())

Gainzzzz
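The comment in read_xesmf_weights_file notes that the reshape is meant to allow chunking along lat and lon; a minimal, untested sketch of that idea follows (the chunk sizes are arbitrary, and psfile and the output path are placeholders carried over from the snippet above):

# chunk the sparse weights along the output grid so xr.dot builds a lazy,
# spatially chunked result instead of a single large array
wgts_chunked = xsparse_wgts.chunk({"lat": 90, "lon": 180})
out = apply_weights(psfile, wgts_chunked)
out.compute()  # or out.to_zarr("regridded.zarr") to stream the result to disk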
See pydata/xarray#7764 for the upstream issue to avoid the monkey-patch.
@charlesgauthier-udm Here's the "parallelize the application" issue.
Is there a way I can reliably regrid an xarray.Dataset object to a lower/higher resolution if it has variables with dask-backed chunked arrays? Every time I try to use the output of the call to xesmf.Regridder to regrid the input data, I get an exception. To get it to work, I have to force the datasets to have only a single chunk with .chunk(-1). This can cause tasks to fail when the dask graph is computed, since a single chunk for large datasets can consume a lot of memory. Is there any workaround for this without using a single chunk?
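For reference, a minimal sketch of the single-chunk workaround described above; ds_in, ds_out, and the regridding method are hypothetical placeholders:

import xesmf as xe

regridder = xe.Regridder(ds_in, ds_out, "bilinear")
ds_regridded = regridder(ds_in.chunk(-1))  # works, but every dask variable becomes one in-memory chunk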