Use getitem_with_mask in reindex_variables (#1847)
* WIP: use getitem_with_mask in reindex_variables

* Fix dtype promotion for where

* Add whats new

* Fix flake8

* Fix test_align_dtype and bool+str promotion

* tests and docstring for dtypes.result_type

* More dtype promotion fixes, including for concat
shoyer authored and fujiisoup committed Feb 14, 2018
1 parent 33660b7 commit 2aa5b8a
Showing 11 changed files with 234 additions and 77 deletions.
2 changes: 1 addition & 1 deletion asv_bench/asv.conf.json
@@ -63,7 +63,7 @@
        "netcdf4": [""],
        "scipy": [""],
        "bottleneck": ["", null],
        "dask": ["", null],
        "dask": [""],
    },


45 changes: 45 additions & 0 deletions asv_bench/benchmarks/reindexing.py
@@ -0,0 +1,45 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import xarray as xr

from . import requires_dask


class Reindex(object):
    def setup(self):
        data = np.random.RandomState(0).randn(1000, 100, 100)
        self.ds = xr.Dataset({'temperature': (('time', 'x', 'y'), data)},
                             coords={'time': np.arange(1000),
                                     'x': np.arange(100),
                                     'y': np.arange(100)})

    def time_1d_coarse(self):
        self.ds.reindex(time=np.arange(0, 1000, 5)).load()

    def time_1d_fine_all_found(self):
        self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest').load()

    def time_1d_fine_some_missing(self):
        self.ds.reindex(time=np.arange(0, 1000, 0.5), method='nearest',
                        tolerance=0.1).load()

    def time_2d_coarse(self):
        self.ds.reindex(x=np.arange(0, 100, 2), y=np.arange(0, 100, 2)).load()

    def time_2d_fine_all_found(self):
        self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5),
                        method='nearest').load()

    def time_2d_fine_some_missing(self):
        self.ds.reindex(x=np.arange(0, 100, 0.5), y=np.arange(0, 100, 0.5),
                        method='nearest', tolerance=0.1).load()


class ReindexDask(Reindex):
    def setup(self):
        requires_dask()
        super(ReindexDask, self).setup()
        self.ds = self.ds.chunk({'time': 100})
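For reviewers who want to try one of these cases outside of asv, a minimal standalone driver (a sketch: asv normally instantiates the class and calls setup() before timing each time_* method itself):

import timeit

bench = Reindex()  # the class defined above
bench.setup()
# each time_* method is a self-contained benchmark case
print(timeit.timeit(bench.time_1d_coarse, number=10))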
7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -81,6 +81,9 @@ Enhancements
  By `Zac Hatfield-Dodds <https://github.com/Zac-HD>`_.
- `.dt` accessor can now ceil, floor and round timestamps to specified frequency.
  By `Deepak Cherian <https://github.com/dcherian>`_.
- Reindexing/alignment with a dask array is now orders of magnitude faster
  when inserting missing values (:issue:`1847`).
  By `Stephan Hoyer <https://github.com/shoyer>`_.

.. _Zarr: http://zarr.readthedocs.io/

@@ -140,6 +143,10 @@ Bug fixes
  ``parse_coordinates`` kwarg has been added to :py:func:`~open_rasterio`
  (set to ``True`` per default).
  By `Fabien Maussion <https://github.com/fmaussion>`_.
- Fixed dtype promotion rules in :py:func:`where` and :py:func:`concat` to
  match pandas (:issue:`1847`). A combination of strings/numbers or
  unicode/bytes now promotes to object dtype, instead of strings or unicode.
  By `Stephan Hoyer <https://github.com/shoyer>`_.
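To illustrate the promotion change for reviewers (a sketch run against this commit, not part of the diff): concatenating numbers with strings now yields object dtype rather than a string dtype.

import xarray as xr

combined = xr.concat([xr.DataArray([1.0]), xr.DataArray(['a'])], dim='dim_0')
print(combined.dtype)  # object (previously a numpy string dtype)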

.. _whats-new.0.10.0:

101 changes: 34 additions & 67 deletions xarray/core/alignment.py
@@ -8,13 +8,11 @@

import numpy as np

from . import duck_array_ops
from . import dtypes
from . import utils
from .indexing import get_indexer_nd
from .pycompat import iteritems, OrderedDict, suppress
from .utils import is_full_slice, is_dict_like
from .variable import Variable, IndexVariable
from .variable import IndexVariable


def _get_joiner(join):
@@ -306,59 +304,51 @@ def reindex_variables(variables, sizes, indexes, indexers, method=None,
    from .dataarray import DataArray

    # build up indexers for assignment along each dimension
    to_indexers = {}
    from_indexers = {}
    int_indexers = {}
    targets = {}
    masked_dims = set()
    unchanged_dims = set()

    # size of reindexed dimensions
    new_sizes = {}

    for name, index in iteritems(indexes):
        if name in indexers:
            target = utils.safe_cast_to_index(indexers[name])
            if not index.is_unique:
                raise ValueError(
                    'cannot reindex or align along dimension %r because the '
                    'index has duplicate values' % name)
            indexer = get_indexer_nd(index, target, method, tolerance)

            target = utils.safe_cast_to_index(indexers[name])
            new_sizes[name] = len(target)
            # Note pandas uses negative values from get_indexer_nd to signify
            # values that are missing in the index
            # The non-negative values thus indicate the non-missing values
            to_indexers[name] = indexer >= 0
            if to_indexers[name].all():
                # If an indexer includes no negative values, then the
                # assignment can be to a full-slice (which is much faster,
                # and means we won't need to fill in any missing values)
                to_indexers[name] = slice(None)

            from_indexers[name] = indexer[to_indexers[name]]
            if np.array_equal(from_indexers[name], np.arange(len(index))):
                # If the indexer is equal to the original index, use a full
                # slice object to speed up selection and so we can avoid
                # unnecessary copies
                from_indexers[name] = slice(None)

            int_indexer = get_indexer_nd(index, target, method, tolerance)

            # We use negative values from get_indexer_nd to signify
            # values that are missing in the index.
            if (int_indexer < 0).any():
                masked_dims.add(name)
            elif np.array_equal(int_indexer, np.arange(len(index))):
                unchanged_dims.add(name)

            int_indexers[name] = int_indexer
            targets[name] = target

    for dim in sizes:
        if dim not in indexes and dim in indexers:
            existing_size = sizes[dim]
            new_size = utils.safe_cast_to_index(indexers[dim]).size
            new_size = indexers[dim].size
            if existing_size != new_size:
                raise ValueError(
                    'cannot reindex or align along dimension %r without an '
                    'index because its size %r is different from the size of '
                    'the new index %r' % (dim, existing_size, new_size))

    def any_not_full_slices(indexers):
        return any(not is_full_slice(idx) for idx in indexers)

    def var_indexers(var, indexers):
        return tuple(indexers.get(d, slice(None)) for d in var.dims)

    # create variables for the new dataset
    reindexed = OrderedDict()

    for dim, indexer in indexers.items():
        if isinstance(indexer, DataArray) and indexer.dims != (dim, ):
        if isinstance(indexer, DataArray) and indexer.dims != (dim,):
            warnings.warn(
                "Indexer has dimensions {0:s} that are different "
                "from that to be indexed along {1:s}. "
@@ -375,47 +365,24 @@ def var_indexers(var, indexers):

    for name, var in iteritems(variables):
        if name not in indexers:
            assign_to = var_indexers(var, to_indexers)
            assign_from = var_indexers(var, from_indexers)

            if any_not_full_slices(assign_to):
                # there are missing values to in-fill
                data = var[assign_from].data
                dtype, fill_value = dtypes.maybe_promote(var.dtype)

                if isinstance(data, np.ndarray):
                    shape = tuple(new_sizes.get(dim, size)
                                  for dim, size in zip(var.dims, var.shape))
                    new_data = np.empty(shape, dtype=dtype)
                    new_data[...] = fill_value
                    # create a new Variable so we can use orthogonal indexing
                    # use fastpath=True to avoid dtype inference
                    new_var = Variable(var.dims, new_data, var.attrs,
                                       fastpath=True)
                    new_var[assign_to] = data

                else:  # dask array
                    data = data.astype(dtype, copy=False)
                    for axis, indexer in enumerate(assign_to):
                        if not is_full_slice(indexer):
                            indices = np.cumsum(indexer)[~indexer]
                            data = duck_array_ops.insert(
                                data, indices, fill_value, axis=axis)
                    new_var = Variable(var.dims, data, var.attrs,
                                       fastpath=True)

            elif any_not_full_slices(assign_from):
                # type coercion is not necessary as there are no missing
                # values
                new_var = var[assign_from]

            else:
                # no reindexing is necessary
            key = tuple(slice(None)
                        if d in unchanged_dims
                        else int_indexers.get(d, slice(None))
                        for d in var.dims)
            needs_masking = any(d in masked_dims for d in var.dims)

            if needs_masking:
                new_var = var._getitem_with_mask(key)
            elif all(is_full_slice(k) for k in key):
                # no reindexing necessary
                # here we need to manually deal with copying data, since
                # we neither created a new ndarray nor used fancy indexing
                new_var = var.copy(deep=copy)
            else:
                new_var = var[key]

            reindexed[name] = new_var

    return reindexed
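The heavy lifting above moves into Variable._getitem_with_mask: rather than allocating an output array and assigning into it (or, for dask, repeatedly calling insert()), reindexing becomes a single indexing operation whose negative indexer entries are filled with the fill value. A conceptual numpy-only sketch of that idea, not xarray's actual implementation:

import numpy as np

values = np.array([10., 20., 30.])
indexer = np.array([0, -1, 2])  # -1 marks a target label missing from the index
result = values[np.where(indexer >= 0, indexer, 0)]  # index, clipping -1 to 0
result[indexer < 0] = np.nan  # then mask the missing positions
print(result)  # [10. nan 30.]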


37 changes: 37 additions & 0 deletions xarray/core/dtypes.py
@@ -7,6 +7,17 @@
NA = utils.ReprObject('<NA>')


# Pairs of types that, if both found, should be promoted to object dtype
# instead of following NumPy's own type-promotion rules. These type promotion
# rules match pandas instead. For reference, see the NumPy type hierarchy:
# https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
PROMOTE_TO_OBJECT = [
    {np.number, np.character},  # numpy promotes to character
    {np.bool_, np.character},  # numpy promotes to character
    {np.bytes_, np.unicode_},  # numpy promotes to unicode
]


def maybe_promote(dtype):
"""Simpler equivalent of pandas.core.common._maybe_promote
@@ -60,3 +71,29 @@ def is_datetime_like(dtype):
    """
    return (np.issubdtype(dtype, np.datetime64) or
            np.issubdtype(dtype, np.timedelta64))


def result_type(*arrays_and_dtypes):
    """Like np.result_type, but with type promotion rules matching pandas.

    Examples of changed behavior:
        number + string -> object (not string)
        bytes + unicode -> object (not unicode)

    Parameters
    ----------
    *arrays_and_dtypes : list of arrays and dtypes
        The dtype is extracted from both numpy and dask arrays.

    Returns
    -------
    numpy.dtype for the result.
    """
    types = {np.result_type(t).type for t in arrays_and_dtypes}

    for left, right in PROMOTE_TO_OBJECT:
        if (any(issubclass(t, left) for t in types) and
                any(issubclass(t, right) for t in types)):
            return np.dtype(object)

    return np.result_type(*arrays_and_dtypes)
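A quick comparison against numpy's own promotion (a sketch; assumes xarray at this commit, where the module lives at xarray.core.dtypes):

import numpy as np
from xarray.core import dtypes

print(np.result_type(np.array([1]), np.array(['a'])))      # a string dtype such as <U21
print(dtypes.result_type(np.array([1]), np.array(['a'])))  # object, matching pandas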
32 changes: 29 additions & 3 deletions xarray/core/duck_array_ops.py
@@ -82,13 +82,13 @@ def isnull(data):


transpose = _dask_or_eager_func('transpose')
where = _dask_or_eager_func('where', n_array_args=3)
_where = _dask_or_eager_func('where', n_array_args=3)
insert = _dask_or_eager_func('insert')
take = _dask_or_eager_func('take')
broadcast_to = _dask_or_eager_func('broadcast_to')

concatenate = _dask_or_eager_func('concatenate', list_of_args=True)
stack = _dask_or_eager_func('stack', list_of_args=True)
_concatenate = _dask_or_eager_func('concatenate', list_of_args=True)
_stack = _dask_or_eager_func('stack', list_of_args=True)

array_all = _dask_or_eager_func('all')
array_any = _dask_or_eager_func('any')
@@ -100,6 +100,17 @@ def asarray(data):
    return data if isinstance(data, dask_array_type) else np.asarray(data)


def as_shared_dtype(scalars_or_arrays):
    """Cast arrays to a shared dtype using xarray's type promotion rules."""
    arrays = [asarray(x) for x in scalars_or_arrays]
    # Pass arrays directly instead of dtypes to result_type so scalars
    # get handled properly.
    # Note that result_type() safely gets the dtype from dask arrays without
    # evaluating them.
    out_type = dtypes.result_type(*arrays)
    return [x.astype(out_type, copy=False) for x in arrays]


def as_like_arrays(*data):
    if all(isinstance(d, dask_array_type) for d in data):
        return data
@@ -151,6 +162,11 @@ def count(data, axis=None):
    return sum(~isnull(data), axis=axis)


def where(condition, x, y):
    """Three argument where() with better dtype promotion rules."""
    return _where(condition, *as_shared_dtype([x, y]))
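For example, mixing strings with a NaN fill value now produces an object array instead of coercing NaN to the string 'nan' (a sketch, assuming xarray.core.duck_array_ops at this commit):

import numpy as np
from xarray.core import duck_array_ops

masked = duck_array_ops.where(np.array([True, False]), np.array(['a', 'b']), np.nan)
print(masked, masked.dtype)  # ['a' nan] object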


def where_method(data, cond, other=dtypes.NA):
    if other is dtypes.NA:
        other = dtypes.get_fill_value(data.dtype)
@@ -161,6 +177,16 @@ def fillna(data, other):
    return where(isnull(data), other, data)


def concatenate(arrays, axis=0):
    """concatenate() with better dtype promotion rules."""
    return _concatenate(as_shared_dtype(arrays), axis=axis)


def stack(arrays, axis=0):
    """stack() with better dtype promotion rules."""
    return _stack(as_shared_dtype(arrays), axis=axis)
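Likewise for the concatenate() wrapper, where a bytes/unicode mix now promotes to object rather than unicode (same sketch assumptions as above):

import numpy as np
from xarray.core import duck_array_ops

joined = duck_array_ops.concatenate([np.array([b'x']), np.array([u'y'])])
print(joined.dtype)  # object, where plain np.concatenate would give unicode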


@contextlib.contextmanager
def _ignore_warnings_if(condition):
    if condition:
2 changes: 0 additions & 2 deletions xarray/core/variable.py
@@ -1273,8 +1273,6 @@ def concat(cls, variables, dim='concat_dim', positions=None,

        arrays = [v.data for v in variables]

        # TODO: use our own type promotion rules to ensure that
        # [str, float] -> object, not str like numpy
        if dim in first_var.dims:
            axis = first_var.get_axis_num(dim)
            dims = first_var.dims
6 changes: 6 additions & 0 deletions xarray/tests/test_dataarray.py
@@ -1717,6 +1717,12 @@ def test_where(self):
        actual = arr.where(arr.x < 2, drop=True)
        assert_identical(actual, expected)

    def test_where_string(self):
        array = DataArray(['a', 'b'])
        expected = DataArray(np.array(['a', np.nan], dtype=object))
        actual = array.where([True, False])
        assert_identical(actual, expected)

    def test_cumops(self):
        coords = {'x': [-1, -2], 'y': ['ab', 'cd', 'ef'],
                  'lat': (['x', 'y'], [[1, 2, 3], [-1, -2, -3]]),