Skip to content

Commit

Permalink
Add 'out' keyword to argmin/argmax methods - allow numpy call signature
Browse files Browse the repository at this point in the history
When np.argmin(da) is called, numpy passes an 'out' keyword argument to
argmin/argmax. Need to allow this argument to avoid errors (but an
exception is thrown if out is not None).
  • Loading branch information
johnomotani committed Apr 10, 2020
1 parent cb6742d commit ab480b5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
12 changes: 10 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3728,6 +3728,7 @@ def argmin(
axis: Union[int, None] = None,
keep_attrs: bool = None,
skipna: bool = None,
out=None,
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
"""Indices of the minimum of the DataArray over one or more dimensions. Result
returned as dict of DataArrays, which can be passed directly to isel().
Expand All @@ -3752,6 +3753,9 @@ def argmin(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
out : None
'out' should not be passed - provided for compatibility with numpy function
signature
Returns
-------
Expand Down Expand Up @@ -3812,7 +3816,7 @@ def argmin(
array([ 1, -5, 1])
Dimensions without coordinates: y
"""
result = self.variable.argmin(dim, axis, keep_attrs, skipna)
result = self.variable.argmin(dim, axis, keep_attrs, skipna, out)
if isinstance(result, dict):
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
else:
Expand All @@ -3824,6 +3828,7 @@ def argmax(
axis: Union[int, None] = None,
keep_attrs: bool = None,
skipna: bool = None,
out=None,
) -> Union["DataArray", Dict[Hashable, "DataArray"]]:
"""Indices of the maximum of the DataArray over one or more dimensions. Result
returned as dict of DataArrays, which can be passed directly to isel().
Expand All @@ -3848,6 +3853,9 @@ def argmax(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
out : None
'out' should not be passed - provided for compatibility with numpy function
signature
Returns
-------
Expand Down Expand Up @@ -3909,7 +3917,7 @@ def argmax(
array([3, 5, 3])
Dimensions without coordinates: y
"""
result = self.variable.argmax(dim, axis, keep_attrs, skipna)
result = self.variable.argmax(dim, axis, keep_attrs, skipna, out)
if isinstance(result, dict):
return {k: self._replace_maybe_drop_dims(v) for k, v in result.items()}
else:
Expand Down
17 changes: 13 additions & 4 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,6 +2086,7 @@ def _unravel_argminmax(
axis: Union[int, None],
keep_attrs: Optional[bool],
skipna: Optional[bool],
out,
) -> Union["Variable", Dict[Hashable, "Variable"]]:
"""Apply argmin or argmax over one or more dimensions, returning the result as a
dict of DataArray that can be passed directly to isel.
Expand All @@ -2110,7 +2111,7 @@ def _unravel_argminmax(
# Return int index if single dimension is passed, and is not part of a
# sequence
return getattr(self, "_injected_" + str(argminmax))(
dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna
dim=dim, axis=axis, keep_attrs=keep_attrs, skipna=skipna, out=out
)

# Get a name for the new dimension that does not conflict with any existing
Expand All @@ -2127,7 +2128,7 @@ def _unravel_argminmax(
reduce_shape = tuple(self.sizes[d] for d in dim)

result_flat_indices = getattr(stacked, "_injected_" + str(argminmax))(
axis=-1, skipna=skipna
axis=-1, skipna=skipna, out=out
)

result_unravelled_indices = np.unravel_index(result_flat_indices, reduce_shape)
Expand All @@ -2151,6 +2152,7 @@ def argmin(
axis: Union[int, None] = None,
keep_attrs: bool = None,
skipna: bool = None,
out=None,
) -> Union["Variable", Dict[Hashable, "Variable"]]:
"""Indices of the minimum of the DataArray over one or more dimensions. Result
returned as dict of DataArrays, which can be passed directly to isel().
Expand All @@ -2175,6 +2177,9 @@ def argmin(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
out : None
'out' should not be passed - provided for compatibility with numpy function
signature
Returns
-------
Expand All @@ -2184,14 +2189,15 @@ def argmin(
--------
DataArray.argmin, DataArray.idxmin
"""
return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna)
return self._unravel_argminmax("argmin", dim, axis, keep_attrs, skipna, out)

def argmax(
self,
dim: Union[Hashable, Sequence[Hashable]] = None,
axis: Union[int, None] = None,
keep_attrs: bool = None,
skipna: bool = None,
out=None,
) -> Union["Variable", Dict[Hashable, "Variable"]]:
"""Indices of the maximum of the DataArray over one or more dimensions. Result
returned as dict of DataArrays, which can be passed directly to isel().
Expand All @@ -2216,6 +2222,9 @@ def argmax(
skips missing values for float dtypes; other dtypes either do not
have a sentinel missing value (int) or skipna=True has not been
implemented (object, datetime64 or timedelta64).
out : None
'out' should not be passed - provided for compatibility with numpy function
signature
Returns
-------
Expand All @@ -2225,7 +2234,7 @@ def argmax(
--------
DataArray.argmax, DataArray.idxmax
"""
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna)
return self._unravel_argminmax("argmax", dim, axis, keep_attrs, skipna, out)


ops.inject_all_ops_and_reduce_methods(Variable)
Expand Down

0 comments on commit ab480b5

Please sign in to comment.