Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: DataFrame reductions losing EA dtypes #52261

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pandas/core/array_algos/masked_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def mean(
skipna: bool = True,
axis: AxisInt | None = None,
):
if not values.size or mask.all():
if (not values.size or mask.all()) and (values.ndim == 1 or axis is None):
return libmissing.NA
return _reductions(np.mean, values=values, mask=mask, skipna=skipna, axis=axis)

Expand All @@ -168,7 +168,7 @@ def var(
axis: AxisInt | None = None,
ddof: int = 1,
):
if not values.size or mask.all():
if (not values.size or mask.all()) and (values.ndim == 1 or axis is None):
return libmissing.NA

return _reductions(
Expand All @@ -184,7 +184,7 @@ def std(
axis: AxisInt | None = None,
ddof: int = 1,
):
if not values.size or mask.all():
if (not values.size or mask.all()) and (values.ndim == 1 or axis is None):
return libmissing.NA

return _reductions(
Expand Down
7 changes: 7 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,7 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
------
TypeError : subclass does not define reductions
"""
keepdims = kwargs.pop("keepdims", False)
pa_type = self._pa_array.type

data_to_reduce = self._pa_array
Expand Down Expand Up @@ -1289,6 +1290,12 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
f"upgrading pyarrow."
)
raise TypeError(msg) from err

if keepdims:
# TODO: is there a way to do this without .as_py()
result = pa.array([result.as_py()], type=result.type)
return type(self)(result)

if pc.is_null(result).as_py():
return self.dtype.na_value

Expand Down
10 changes: 10 additions & 0 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2049,6 +2049,11 @@ def min(self, *, skipna: bool = True, **kwargs):
-------
min : the minimum of this `Categorical`, NA value if empty
"""
keepdims = kwargs.pop("keepdims", False)
if keepdims:
result = self.min(skipna=skipna, **kwargs)
return type(self)([result], dtype=self.dtype)

nv.validate_minmax_axis(kwargs.get("axis", 0))
nv.validate_min((), kwargs)
self.check_for_ordered("min")
Expand Down Expand Up @@ -2081,6 +2086,11 @@ def max(self, *, skipna: bool = True, **kwargs):
-------
max : the maximum of this `Categorical`, NA if array is empty
"""
keepdims = kwargs.pop("keepdims", False)
if keepdims:
result = self.max(skipna=skipna, **kwargs)
return type(self)([result], dtype=self.dtype)

nv.validate_minmax_axis(kwargs.get("axis", 0))
nv.validate_max((), kwargs)
self.check_for_ordered("max")
Expand Down
14 changes: 11 additions & 3 deletions pandas/core/arrays/masked.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,20 +1074,27 @@ def _quantile(
# Reductions

def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
keepdims = kwargs.pop("keepdims", False)
if keepdims and "axis" not in kwargs:
return self.reshape(-1, 1)._reduce(name=name, skipna=skipna, **kwargs)

if name in {"any", "all", "min", "max", "sum", "prod", "mean", "var", "std"}:
return getattr(self, name)(skipna=skipna, **kwargs)

data = self._data
mask = self._mask

# median, skew, kurt, sem
axis = kwargs.pop("axis", 0)
op = getattr(nanops, f"nan{name}")
result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
result = op(data, axis=axis, skipna=skipna, mask=mask, **kwargs)

if np.isnan(result):
return libmissing.NA

return result
return self._wrap_reduction_result(
name=name, result=result, skipna=skipna, axis=axis
)

def _wrap_reduction_result(self, name: str, result, skipna, **kwargs):
if isinstance(result, np.ndarray):
Expand All @@ -1098,7 +1105,8 @@ def _wrap_reduction_result(self, name: str, result, skipna, **kwargs):
else:
mask = self._mask.any(axis=axis)

return self._maybe_mask_result(result, mask)
if name not in ["argmin", "argmax"]:
return self._maybe_mask_result(result, mask)
return result

def sum(
Expand Down
22 changes: 17 additions & 5 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10906,13 +10906,25 @@ def func(values: np.ndarray):
# We only use this in the case that operates on self.values
return op(values, axis=axis, skipna=skipna, **kwds)

is_am = isinstance(self._mgr, ArrayManager)

def blk_func(values, axis: Axis = 1):
if isinstance(values, ExtensionArray):
if not is_1d_only_ea_dtype(values.dtype) and not isinstance(
self._mgr, ArrayManager
):
return values._reduce(name, axis=1, skipna=skipna, **kwds)
return values._reduce(name, skipna=skipna, **kwds)
if not is_1d_only_ea_dtype(values.dtype):
if is_am:
# error: "ExtensionArray" has no attribute "reshape";
# maybe "shape"?
vals2d = values.reshape(1, -1) # type: ignore[attr-defined]
return vals2d._reduce(name, axis=1, skipna=skipna, **kwds)
else:
return values._reduce(name, axis=1, skipna=skipna, **kwds)

try:
return values._reduce(name, skipna=skipna, keepdims=True, **kwds)
except (TypeError, ValueError):
# no keepdims keyword yet; ValueError gets raised by
# util validator functions
return values._reduce(name, skipna=skipna, **kwds)
Comment on lines +10922 to +10927
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will just be the try portion when fully implemented, yea?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right. there would need to be a deprecation cycle to allow 3rd party EAs to catch up

else:
return op(values, axis=axis, skipna=skipna, **kwds)

Expand Down
22 changes: 11 additions & 11 deletions pandas/core/internals/array_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,24 +976,24 @@ def reduce(self, func: Callable) -> Self:
-------
ArrayManager
"""
result_arrays: list[np.ndarray] = []
result_arrays: list[ArrayLike] = []
for i, arr in enumerate(self.arrays):
res = func(arr, axis=0)

# TODO NaT doesn't preserve dtype, so we need to ensure to create
# a timedelta result array if original was timedelta
# what if datetime results in timedelta? (eg std)
dtype = arr.dtype if res is NaT else None
result_arrays.append(
sanitize_array([res], None, dtype=dtype) # type: ignore[arg-type]
)
if isinstance(res, (np.ndarray, ExtensionArray)):
# keepdims worked!
result_arrays.append(res)
else:
# TODO NaT doesn't preserve dtype, so we need to ensure to create
# a timedelta result array if original was timedelta
# what if datetime results in timedelta? (eg std)
dtype = arr.dtype if res is NaT else None
result_arrays.append(sanitize_array([res], None, dtype=dtype))
Comment on lines +983 to +991
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar - is the plan to be able to remove this else?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


index = Index._simple_new(np.array([None], dtype=object)) # placeholder
columns = self.items

# error: Argument 1 to "ArrayManager" has incompatible type "List[ndarray]";
# expected "List[Union[ndarray, ExtensionArray]]"
new_mgr = type(self)(result_arrays, [index, columns]) # type: ignore[arg-type]
new_mgr = type(self)(result_arrays, [index, columns])
return new_mgr

def operate_blockwise(self, other: ArrayManager, array_op) -> ArrayManager:
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,11 @@ def reduce(self, func) -> list[Block]:

if self.values.ndim == 1:
# TODO(EA2D): special case not needed with 2D EAs
res_values = np.array([[result]])
if isinstance(result, (np.ndarray, ExtensionArray)):
# keepdims=True worked
res_values = result
else:
res_values = np.array([[result]])
else:
res_values = result.reshape(-1, 1)

Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/arrays/categorical/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_numpy_min_max_raises(self, method):
with pytest.raises(TypeError, match=re.escape(msg)):
method(cat)

@pytest.mark.parametrize("kwarg", ["axis", "out", "keepdims"])
@pytest.mark.parametrize("kwarg", ["axis", "out"])
@pytest.mark.parametrize("method", ["min", "max"])
def test_numpy_min_max_unsupported_kwargs_raises(self, method, kwarg):
cat = Categorical(["a", "b", "c", "b"], ordered=True)
Expand Down
27 changes: 18 additions & 9 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,14 +693,7 @@ def test_std_timedelta64_skipna_false(self):
def test_std_datetime64_with_nat(
self, values, skipna, using_array_manager, request
):
# GH#51335
if using_array_manager and (
not skipna or all(value is pd.NaT for value in values)
):
mark = pytest.mark.xfail(
reason="GH#51446: Incorrect type inference on NaT in reduction result"
)
request.node.add_marker(mark)
# GH#51335, GH#51446
df = DataFrame({"a": to_datetime(values)})
result = df.std(skipna=skipna)
if not skipna or all(value is pd.NaT for value in values):
Expand Down Expand Up @@ -918,7 +911,7 @@ def test_mean_extensionarray_numeric_only_true(self):
arr = np.random.randint(1000, size=(10, 5))
df = DataFrame(arr, dtype="Int64")
result = df.mean(numeric_only=True)
expected = DataFrame(arr).mean()
expected = DataFrame(arr).mean().astype("Float64")
tm.assert_series_equal(result, expected)

def test_stats_mixed_type(self, float_string_frame):
Expand Down Expand Up @@ -1726,3 +1719,19 @@ def test_fails_on_non_numeric(kernel):
)
with pytest.raises(TypeError, match=msg):
getattr(df, kernel)(*args)


@pytest.mark.parametrize(
"dtype", ["Int64", pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow"))]
)
def test_Int64_mean_preserves_dtype(dtype):
# GH#42895
arr = np.random.randn(4, 3).astype("int64")
df = DataFrame(arr).astype(dtype)
df.iloc[:, 1] = pd.NA
assert (df.dtypes == dtype).all()

res = df.mean()
exp_dtype = "Float64" if dtype == "Int64" else "float64[pyarrow]"
expected = Series([arr[:, 0].mean(), pd.NA, arr[:, 2].mean()], dtype=exp_dtype)
tm.assert_series_equal(res, expected)