Feature request: Implement interp for interpolating between chunks of data (dask) #4078

Closed
pums974 opened this issue May 19, 2020 · 6 comments · Fixed by #4155

@pums974
Contributor

pums974 commented May 19, 2020

In a project of mine I need to interpolate a dask-backed xarray across chunk boundaries.

I made it work using monkey patching. I'm pretty sure it can be written better, but I made it as good as I could.

I hope that what I wrote can help you implement it properly.

from typing import Union, Tuple, Callable, Any, List

import dask.array as da
import numpy as np
import xarray as xr
import xarray.core.missing as m

def interp_func(var: Union[np.ndarray, da.Array],
                x: Tuple[xr.DataArray, ...],
                new_x: Tuple[xr.DataArray, ...],
                method: str,
                kwargs: Any) -> da.Array:
    """
    multi-dimensional interpolation for array-like. Interpolated axes should be
    located in the last position.

    Parameters
    ----------
    var: np.ndarray or dask.array.Array
        Array to be interpolated. The final dimension is interpolated.
    x: a list of 1d arrays
        Original coordinates. Should not contain NaN.
    new_x: a list of 1d arrays
        New coordinates. Should not contain NaN.
    method: string
        {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'} for
        1-dimensional interpolation.
        {'linear', 'nearest'} for multidimensional interpolation
    **kwargs:
        Optional keyword arguments to be passed to scipy.interpolator

    Returns
    -------
    interpolated: array
        Interpolated array

    Note
    ----
    This requires scipy to be installed.

    See Also
    --------
    scipy.interpolate.interp1d
    """

    try:
        # try the official interp_func first
        res = official_interp_func(var, x, new_x, method, kwargs)
        return res
    except NotImplementedError:
        # may fail if interpolating between chunks
        pass

    if len(x) == 1:
        func, _kwargs = m._get_interpolator(method, vectorizeable_only=True,
                                            **kwargs)
    else:
        func, _kwargs = m._get_interpolator_nd(method, **kwargs)

    # reshape new_x (TODO REMOVE ?)
    current_dims = [_x.name for _x in x]
    new_x = tuple([_x.set_dims(current_dims) for _x in new_x])

    # number of non interpolated dimensions
    nconst = var.ndim - len(x)

    # duplicate the ghost cells of the array
    bnd = {i: "none" for i in range(len(var.shape))}
    depth = {i: 0 if i < nconst else 1 for i in range(len(var.shape))}
    var_with_ghost = da.overlap.overlap(var, depth=depth, boundary=bnd)

    # chunks x and duplicate the ghost cells of x
    x = tuple(da.from_array(_x, chunks=chunks) for _x, chunks in zip(x, var.chunks[nconst:]))
    x_with_ghost = tuple(da.overlap.overlap(_x, depth={0: 1}, boundary={0: "none"})
                         for _x in x)

    # compute final chunks
    chunks_end = [np.cumsum(sizes) - 1 for _x in x
                                       for sizes in _x.chunks]
    chunks_end_with_ghost = [np.cumsum(sizes) - 1 for _x in x_with_ghost
                                                  for sizes in _x.chunks]
    total_chunks = []
    for dim, ce in enumerate(zip(chunks_end, chunks_end_with_ghost)):
        l_new_x_ends: List[np.ndarray] = []
        for iend, iend_with_ghost in zip(*ce):

            arr = np.moveaxis(new_x[dim].data, dim, -1)
            arr = arr[tuple([0] * (len(arr.shape) - 1))]

            n_no_ghost = (arr <= x[dim][iend]).sum()
            n_ghost = (arr <= x_with_ghost[dim][iend_with_ghost]).sum()

            equil = np.ceil(0.5 * (n_no_ghost + n_ghost)).astype(int)

            l_new_x_ends.append(equil)

        new_x_ends = np.array(l_new_x_ends)
        chunks = new_x_ends[0], *(new_x_ends[1:] - new_x_ends[:-1])
        total_chunks.append(tuple(chunks))
    final_chunks = var.chunks[:-len(x)] + tuple(total_chunks)

    # chunks new_x
    new_x = tuple(da.from_array(_x, chunks=total_chunks) for _x in new_x)

    # reshape x_with_ghost (TODO REMOVE ?)
    x_with_ghost = da.meshgrid(*x_with_ghost, indexing='ij')

    # compute on chunks (TODO use drop_axis and new_axis ?)
    res = da.map_blocks(_myinterpnd, var_with_ghost, func, _kwargs, len(x_with_ghost), *x_with_ghost, *new_x,
                        dtype=var.dtype, chunks=final_chunks)

    # reshape res and remove empty chunks (TODO REMOVE ?)
    res = res.squeeze()
    new_chunks = tuple([tuple([chunk for chunk in chunks if chunk > 0]) for chunks in res.chunks])
    res = res.rechunk(new_chunks)
    return res


def _myinterpnd(var: da.Array,
                func: Callable[..., Any],
                kwargs: Any,
                nx: int,
                *arrs: da.Array) -> da.Array:
    _old_x, _new_x = arrs[:nx], arrs[nx:]

    # reshape x (TODO REMOVE ?)
    old_x = tuple([np.moveaxis(tmp, dim, -1)[tuple([0] * (len(tmp.shape) - 1))]
                   for dim, tmp in enumerate(_old_x)])

    new_x = tuple([xr.DataArray(_x) for _x in _new_x])

    return m._interpnd(var, old_x, new_x, func, kwargs)


official_interp_func = m.interp_func
m.interp_func = interp_func
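
A minimal sketch of how the patch would be exercised (assuming the module above has been imported first; a fuller failing/working example is given in a later comment):

import dask.array as da
import numpy as np
import xarray as xr

# assumes the monkey-patching module above was imported before this point
data = xr.DataArray(data=da.from_array(np.arange(0, 4), chunks=2),
                    coords={"x": np.linspace(0, 1, 4)},
                    dims="x")

# plain xarray 0.15.1 raises NotImplementedError here because "x" is chunked;
# with the patch applied the interpolation crosses the chunk boundary
res = data.interp(x=np.linspace(0, 1))
print(res.load())
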
@pums974
Contributor Author

pums974 commented Jun 12, 2020

Any feedback?
Is my issue useful?
Should I write a pull request?

@jkmacc-LANL

This looks interesting, but it’s hard to see what it does from the code block. Could you please include a minimal example that includes inputs and outputs, similar to how the xarray docs do it?

@dcherian
Contributor

Thanks @pums974 we would gladly take a pull request. As @jkmacc-LANL suggests, some tests and comments would be nice in the PR.

It's hard to see exactly what this code is doing, but the general idea of using overlap and map_blocks is sound. Is there a reason you didn't use map_overlap?

@pums974
Contributor Author

pums974 commented Jun 14, 2020

Thanks, I'll try to put together tests and an example quickly.

As for map_overlap, I tried it, but something went wrong (I don't remember what, though). Maybe I didn't try hard enough.

EDIT: I cannot use map_overlap because it doesn't pass *args through to map_blocks
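
A rough sketch of the pattern used instead (a hedged illustration, assuming da.map_overlap at the time only wrapped a single array): build the ghost cells explicitly with da.overlap.overlap and pass every array to da.map_blocks, which hands matching blocks of all of them to the function.

import dask.array as da

var = da.arange(8, chunks=4)
coord = da.arange(8, chunks=4)

# add one ghost cell on each internal chunk edge, as the patch does along
# the interpolated axes
var_ghost = da.overlap.overlap(var, depth={0: 1}, boundary={0: "none"})
coord_ghost = da.overlap.overlap(coord, depth={0: 1}, boundary={0: "none"})

# map_blocks forwards extra dask arrays block by block, so the coordinate
# array can travel alongside the data array
res = da.map_blocks(lambda v, c: v + c, var_ghost, coord_ghost, dtype=var.dtype)
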

@pums974
Contributor Author

pums974 commented Jun 15, 2020

When using the current official interp function (xarray v0.15.1), the code:

    datax = xr.DataArray(data=da.from_array(np.arange(0, 4), chunks=2),
                         coords={"x": np.linspace(0, 1, 4)},
                         dims="x")
    datay = xr.DataArray(data=da.from_array(np.arange(0, 4), chunks=2),
                         coords={"y": np.linspace(0, 1, 4)},
                         dims="y")
    data = datax * datay

    # both of these interp calls fail
    res = datax.interp(x=np.linspace(0, 1))
    print(res.load())

    res = data.interp(x=np.linspace(0, 1), y=0.5)
    print(res.load())

fails with NotImplementedError: Chunking along the dimension to be interpolated (0) is not yet supported.

but succeeds with the monkey-patched version.

EDIT: added the second interp call to show a more general use case.

@pums974
Contributor Author

pums974 commented Jun 15, 2020

I also want to point out that my version does not work with "advanced interpolation" (as shown in the xarray documentation).
Also, my version cannot be used to make interpolate_na work with chunked data.
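
For illustration, the kind of call this refers to (a hypothetical sketch, not taken from the thread): filling NaNs along a dimension that is chunked.

import dask.array as da
import numpy as np
import xarray as xr

arr = xr.DataArray(data=da.from_array(np.array([0.0, np.nan, 2.0, np.nan]), chunks=2),
                   coords={"x": np.linspace(0, 1, 4)},
                   dims="x")

# per the note above, this chunked interpolate_na case is still not covered
# by the patch
filled = arr.interpolate_na(dim="x")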
