Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix in vectorized item assignment #1746

Merged
merged 14 commits into from
Dec 9, 2017
Merged
11 changes: 5 additions & 6 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,20 +637,19 @@ def __setitem__(self, key, value):
"""
dims, index_tuple, new_order = self._broadcast_indexes(key)

if isinstance(value, Variable):
value = value.set_dims(dims).data
else: # first broadcast value
if not isinstance(value, Variable):
value = as_compatible_data(value)
if value.ndim > len(dims):
raise ValueError(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need this special case error message now that we call set_dims below?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think value = Variable(dims[-value.ndim:], value) fails if value.ndim > len(dims).

'shape mismatch: value array of shape %s could not be'
'broadcast to indexing result with %s dimensions'
% (value.shape, len(dims)))

if value.ndim == 0:
value = Variable((), value).set_dims(dims).data
value = Variable((), value)
else:
value = Variable(dims[-value.ndim:], value).set_dims(dims).data
value = Variable(dims[-value.ndim:], value)
# broadcast to become assignable
value = value.set_dims(dims).data
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to revert nputils.NumpyVindexAdapter to the master version and to make a broadcasting here.


if new_order:
value = duck_array_ops.asarray(value)
Expand Down
12 changes: 0 additions & 12 deletions xarray/tests/test_nputils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,3 @@ def test_vindex():
vindex[[0, 1], [0, 1], :] = vindex[[0, 1], [0, 1], :]
vindex[[0, 1], :, [0, 1]] = vindex[[0, 1], :, [0, 1]]
vindex[:, [0, 1], [0, 1]] = vindex[:, [0, 1], [0, 1]]

def test_vindex_4d():
x = np.arange(3 * 4 * 5 * 6).reshape((3, 4, 5, 6))
vindex = NumpyVIndexAdapter(x)

# getitem
assert_array_equal(vindex[0], x[0])
assert_array_equal(vindex[[1, 2], [1, 2]], x[[1, 2], [1, 2]])
assert vindex[[0, 1], [0, 1], :].shape == (2, 5, 6)
assert vindex[[0, 1], :, [0, 1]].shape == (2, 4, 6)
assert vindex[:, [0, 1], [0, 1]].shape == (2, 3, 6)
assert vindex[:, [0, 1], :, [0, 1]].shape == (2, 3, 5)