Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix in vectorized item assignment #1746

Merged
merged 14 commits into from
Dec 9, 2017
Merged
18 changes: 18 additions & 0 deletions xarray/core/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,21 @@ def __getitem__(self, key):

def __unicode__(self):
    """Return the text form of this object built by ``formatting.indexes_repr``."""
    text = formatting.indexes_repr(self)
    return text


def assert_coordinate_consistent(obj, coords):
""" Make sure the dimension coordinate of obj is
consistent with coords.

obj: DataArray or Dataset
coords: Dict-like of variables
"""
for k in obj.dims:
# make sure there are no conflicts in dimension coordinates
if (k in coords and k in obj.coords):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you can drop the extra parentheses here inside if

coord = getattr(coords[k], 'variable', coords[k]) # Variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to insist that coords always has the same type (e.g., a dict with Variable values).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's really reasonable. Updated.

if not coord.equals(obj[k].variable):
raise IndexError(
'dimension coordinate {!r} conflicts between '
'indexed and indexing objects:\n{}\nvs.\n{}'
.format(k, obj[k], coords[k]))
6 changes: 5 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .alignment import align, reindex_like_indexers
from .common import AbstractArray, BaseDataObject
from .coordinates import (DataArrayCoordinates, LevelCoordinatesSource,
Indexes)
Indexes, assert_coordinate_consistent)
from .dataset import Dataset, merge_indexes, split_indexes
from .pycompat import iteritems, basestring, OrderedDict, zip, range
from .variable import (as_variable, Variable, as_compatible_data,
Expand Down Expand Up @@ -484,6 +484,10 @@ def __setitem__(self, key, value):
if isinstance(key, basestring):
self.coords[key] = value
else:
# Coordinates in key, value and self[key] should be consistent.
obj = self[key]
if isinstance(value, DataArray):
assert_coordinate_consistent(value, obj.coords)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was actually thinking of checking the consistency of coords on each DataArray argument in key instead. I guess we should probably check both!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think if you use obj.coords.variables here we can skip the awkward getattr(coords[k], 'variable', coords[k]) above.

# DataArray key -> Variable key
key = {k: v.variable if isinstance(v, DataArray) else v
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to enforce consistency for coordinates here? My inclination would be that we should support exactly the same keys in setitem as are valid in getitem. Ideally we should also reuse the same code. That means we should raise errors if there are multiple indexers with inconsistent alignment.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My inclination would be that we should support exactly the same keys in setitem as are valid in getitem.

Reasonable. I will add a validation.

for k, v in self._item_key_to_dict(key).items()}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to check that coordinates are consistent on the key?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is done a few lines above, obj = self[key], where in .isel we check the coordinates in the key.

But I am wondering about this unnecessary indexing, though I think this implementation is the simplest.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. This could indeed be a significant performance hit. That said, I'm OK leaving this for now, with a TODO note to optimize it later.

Expand Down
13 changes: 3 additions & 10 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
from . import duck_array_ops
from .. import conventions
from .alignment import align
from .coordinates import DatasetCoordinates, LevelCoordinatesSource, Indexes
from .coordinates import (DatasetCoordinates, LevelCoordinatesSource, Indexes,
assert_coordinate_consistent)
from .common import ImplementsDatasetReduce, BaseDataObject
from .dtypes import is_datetime_like
from .merge import (dataset_update_method, dataset_merge_method,
Expand Down Expand Up @@ -1305,15 +1306,7 @@ def _get_indexers_coordinates(self, indexers):
# we don't need to call align() explicitly, because merge_variables
# already checks for exact alignment between dimension coordinates
coords = merge_variables(coord_list)

for k in self.dims:
# make sure there are not conflict in dimension coordinates
if (k in coords and k in self._variables and
not coords[k].equals(self._variables[k])):
raise IndexError(
'dimension coordinate {!r} conflicts between '
'indexed and indexing objects:\n{}\nvs.\n{}'
.format(k, self._variables[k], coords[k]))
assert_coordinate_consistent(self, coords)

attached_coords = OrderedDict()
for k, v in coords.items(): # silently drop the conflicted variables.
Expand Down
55 changes: 55 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,61 @@ def test_setitem_fancy(self):
expected = DataArray([[0, 0], [0, 0], [1, 1]], dims=['x', 'y'])
self.assertVariableIdentical(expected, da)

def test_setitem_dataarray(self):
    """Check coordinate validation when assigning through ``__setitem__``.

    Assigning with a DataArray indexer, or assigning a DataArray value,
    must raise ``IndexError`` when a dimension coordinate conflicts with
    the indexed array, and must succeed when the dimension coordinates
    agree. A value whose non-dimension coordinate differs is accepted,
    and the coordinates of the target array are left unchanged.
    """
    def get_data():
        return DataArray(np.ones((4, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': np.arange(4), 'y': ['a', 'b', 'c'],
                                 'non-dim': ('x', [1, 3, 4, 2])})

    da = get_data()
    # indexer with inconsistent coordinates.
    ind = DataArray(np.arange(1, 4), dims=['x'],
                    coords={'x': np.random.randn(3)})
    with raises_regex(IndexError, "dimension coordinate 'x'"):
        da[dict(x=ind)] = 0

    # indexer with consistent coordinates.
    ind = DataArray(np.arange(1, 4), dims=['x'],
                    coords={'x': np.arange(1, 4)})
    da[dict(x=ind)] = 0  # should not raise
    assert np.allclose(da[dict(x=ind)].values, 0)
    self.assertDataArrayIdentical(da['x'], get_data()['x'])
    self.assertDataArrayIdentical(da['non-dim'], get_data()['non-dim'])

    da = get_data()
    # conflict in the assigning values
    value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': [0, 1, 2],
                                 'non-dim': ('x', [0, 2, 4])})
    with raises_regex(IndexError, "dimension coordinate 'x'"):
        da[dict(x=ind)] = value

    # consistent coordinate in the assigning values
    value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': [1, 2, 3],
                                 'non-dim': ('x', [0, 2, 4])})
    da[dict(x=ind)] = value
    assert np.allclose(da[dict(x=ind)].values, 0)
    self.assertDataArrayIdentical(da['x'], get_data()['x'])
    self.assertDataArrayIdentical(da['non-dim'], get_data()['non-dim'])

    # Conflict in the non-dimension coordinate
    value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': [1, 2, 3],
                                 'non-dim': ('x', [0, 2, 4])})
    # Previously this value was built and then overwritten without being
    # used; actually perform the assignment so the case is exercised.
    da[dict(x=ind)] = value  # should not raise
    self.assertDataArrayIdentical(da['non-dim'], get_data()['non-dim'])

    # conflict in the assigning values
    value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': [0, 1, 2],
                                 'non-dim': ('x', [0, 2, 4])})
    with raises_regex(IndexError, "dimension coordinate 'x'"):
        da[dict(x=ind)] = value

    # consistent coordinate in the assigning values
    value = xr.DataArray(np.zeros((3, 3, 2)), dims=['x', 'y', 'z'],
                         coords={'x': [1, 2, 3],
                                 'non-dim': ('x', [0, 2, 4])})
    da[dict(x=ind)] = value  # should not raise

def test_contains(self):
data_array = DataArray(1, coords={'x': 2})
with pytest.warns(FutureWarning):
Expand Down
15 changes: 6 additions & 9 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,15 +1001,12 @@ def test_isel_dataarray(self):
# Conflict in the dimension coordinate
indexing_da = DataArray(np.arange(1, 4), dims=['dim2'],
coords={'dim2': np.random.randn(3)})
with raises_regex(
IndexError, "dimension coordinate 'dim2'"):
with raises_regex(IndexError, "dimension coordinate 'dim2'"):
actual = data.isel(dim2=indexing_da)
# Also the case for DataArray
with raises_regex(
IndexError, "dimension coordinate 'dim2'"):
with raises_regex(IndexError, "dimension coordinate 'dim2'"):
actual = data['var2'].isel(dim2=indexing_da)
with raises_regex(
IndexError, "dimension coordinate 'dim2'"):
with raises_regex(IndexError, "dimension coordinate 'dim2'"):
data['dim2'].isel(dim2=indexing_da)

# same name coordinate which does not conflict
Expand Down Expand Up @@ -1502,7 +1499,7 @@ def test_reindex_like(self):

expected = data.copy(deep=True)
expected['dim3'] = ('dim3', list('cdefghijkl'))
expected['var3'][:-2] = expected['var3'][2:]
expected['var3'][:-2] = expected['var3'][2:].values
expected['var3'][-2:] = np.nan
expected['letters'] = expected['letters'].astype(object)
expected['letters'][-2:] = np.nan
Expand Down Expand Up @@ -1614,9 +1611,9 @@ def test_align(self):
left = create_test_data()
right = left.copy(deep=True)
right['dim3'] = ('dim3', list('cdefghijkl'))
right['var3'][:-2] = right['var3'][2:]
right['var3'][:-2] = right['var3'][2:].values
right['var3'][-2:] = np.random.randn(*right['var3'][-2:].shape)
right['numbers'][:-2] = right['numbers'][2:]
right['numbers'][:-2] = right['numbers'][2:].values
right['numbers'][-2:] = -10

intersection = list('cdefghij')
Expand Down