Skip to content

Commit

Permalink
Merge pull request #302 from pytroll/feature-multicore-sar-geography
Browse files Browse the repository at this point in the history
Workaround the LinearNDInterpolator thread-safety issue for Sentinel 1 SAR geolocation
  • Loading branch information
mraspaud authored May 22, 2018
2 parents 40ea3e2 + b712e53 commit b6e13e9
Showing 1 changed file with 15 additions and 32 deletions.
47 changes: 15 additions & 32 deletions satpy/readers/safe_sar_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

import numpy as np
from osgeo import gdal
from dask.array import Array
import dask.array as da
from xarray import DataArray
import xarray.ufuncs as xu
from dask.base import tokenize
Expand Down Expand Up @@ -179,46 +179,29 @@ def interpolate_xarray(xpoints, ypoints, values, shape, kind='cubic',
for j, hcs in enumerate(hchunks)
}

res = Array(dskx, name, shape=list(shape),
chunks=(blocksize, blocksize),
dtype=values.dtype)
res = da.Array(dskx, name, shape=list(shape),
chunks=(blocksize, blocksize),
dtype=values.dtype)
return DataArray(res, dims=('y', 'x'))


def interpolate_xarray_linear(xpoints, ypoints, values, shape,
blocksize=CHUNK_SIZE):
def interpolate_xarray_linear(xpoints, ypoints, values, shape):
"""Interpolate linearly, generating a dask array."""
from scipy.interpolate.interpnd import (LinearNDInterpolator,
_ndim_coords_from_arrays)

vblocksize, hblocksize = blocksize, blocksize

vchunks = range(0, shape[0], vblocksize)
hchunks = range(0, shape[1], hblocksize)

points = _ndim_coords_from_arrays(np.vstack((np.asarray(ypoints),
np.asarray(xpoints))).T)

token = tokenize(blocksize, points, values, shape)
name = 'interpolate2-' + token

interpolator = LinearNDInterpolator(points, values)

def intp(slice_rows, slice_cols, interpolator):
grid_x, grid_y = np.mgrid[slice_rows, slice_cols]
return interpolator((grid_x, grid_y))

dskx = {(name, i, j): (intp,
slice(vcs, min(vcs + vblocksize, shape[0])),
slice(hcs, min(hcs + hblocksize, shape[1])),
interpolator)
for i, vcs in enumerate(vchunks)
for j, hcs in enumerate(hchunks)
}
def intp(grid_x, grid_y, interpolator):
return interpolator((grid_y, grid_x))

res = Array(dskx, name, shape=list(shape),
chunks=(vblocksize, hblocksize),
dtype=values.dtype)
grid_x, grid_y = da.meshgrid(da.arange(shape[1], chunks=CHUNK_SIZE),
da.arange(shape[0], chunks=CHUNK_SIZE))
# workaround for non-thread-safe first call of the interpolator:
interpolator((0, 0))
res = da.map_blocks(intp, grid_x, grid_y, interpolator=interpolator)

return DataArray(res, dims=('y', 'x'))

Expand Down Expand Up @@ -316,9 +299,9 @@ def read_band(self, blocksize=CHUNK_SIZE):
for j, hcs in enumerate(hchunks)
}

res = Array(dskx, name, shape=list(shape),
chunks=(blocksize, blocksize),
dtype=np.uint16)
res = da.Array(dskx, name, shape=list(shape),
chunks=(blocksize, blocksize),
dtype=np.uint16)
return DataArray(res, dims=('y', 'x'))

def get_lonlats(self):
Expand Down

0 comments on commit b6e13e9

Please sign in to comment.