diff --git a/satpy/readers/safe_sar_c.py b/satpy/readers/safe_sar_c.py index 3f634d267c..0f61b78e7f 100644 --- a/satpy/readers/safe_sar_c.py +++ b/satpy/readers/safe_sar_c.py @@ -27,7 +27,7 @@ import numpy as np from osgeo import gdal -from dask.array import Array +import dask.array as da from xarray import DataArray import xarray.ufuncs as xu from dask.base import tokenize @@ -179,46 +179,29 @@ def interpolate_xarray(xpoints, ypoints, values, shape, kind='cubic', for j, hcs in enumerate(hchunks) } - res = Array(dskx, name, shape=list(shape), - chunks=(blocksize, blocksize), - dtype=values.dtype) + res = da.Array(dskx, name, shape=list(shape), + chunks=(blocksize, blocksize), + dtype=values.dtype) return DataArray(res, dims=('y', 'x')) -def interpolate_xarray_linear(xpoints, ypoints, values, shape, - blocksize=CHUNK_SIZE): +def interpolate_xarray_linear(xpoints, ypoints, values, shape): """Interpolate linearly, generating a dask array.""" from scipy.interpolate.interpnd import (LinearNDInterpolator, _ndim_coords_from_arrays) - - vblocksize, hblocksize = blocksize, blocksize - - vchunks = range(0, shape[0], vblocksize) - hchunks = range(0, shape[1], hblocksize) - points = _ndim_coords_from_arrays(np.vstack((np.asarray(ypoints), np.asarray(xpoints))).T) - token = tokenize(blocksize, points, values, shape) - name = 'interpolate2-' + token - interpolator = LinearNDInterpolator(points, values) - def intp(slice_rows, slice_cols, interpolator): - grid_x, grid_y = np.mgrid[slice_rows, slice_cols] - return interpolator((grid_x, grid_y)) - - dskx = {(name, i, j): (intp, - slice(vcs, min(vcs + vblocksize, shape[0])), - slice(hcs, min(hcs + hblocksize, shape[1])), - interpolator) - for i, vcs in enumerate(vchunks) - for j, hcs in enumerate(hchunks) - } + def intp(grid_x, grid_y, interpolator): + return interpolator((grid_y, grid_x)) - res = Array(dskx, name, shape=list(shape), - chunks=(vblocksize, hblocksize), - dtype=values.dtype) + grid_x, grid_y = da.meshgrid(da.arange(shape[1], chunks=CHUNK_SIZE), + da.arange(shape[0], chunks=CHUNK_SIZE)) + # workaround for non-thread-safe first call of the interpolator: + interpolator((0, 0)) + res = da.map_blocks(intp, grid_x, grid_y, interpolator=interpolator) return DataArray(res, dims=('y', 'x')) @@ -316,9 +299,9 @@ def read_band(self, blocksize=CHUNK_SIZE): for j, hcs in enumerate(hchunks) } - res = Array(dskx, name, shape=list(shape), - chunks=(blocksize, blocksize), - dtype=np.uint16) + res = da.Array(dskx, name, shape=list(shape), + chunks=(blocksize, blocksize), + dtype=np.uint16) return DataArray(res, dims=('y', 'x')) def get_lonlats(self):