Skip to content

Commit

Permalink
Backport PR pandas-dev#54341: PERF: axis=1 reductions with EA dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored and meeseeksmachine committed Aug 13, 2023
1 parent b215522 commit 7213ad4
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 1 deletion.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ Performance improvements
- :class:`Period`'s default formatter (`period_format`) is now significantly (~twice) faster. This improves performance of ``str(Period)``, ``repr(Period)``, and :meth:`Period.strftime(fmt=None)`, as well as ``PeriodArray.strftime(fmt=None)``, ``PeriodIndex.strftime(fmt=None)`` and ``PeriodIndex.format(fmt=None)``. Finally, ``to_csv`` operations involving :class:`PeriodArray` or :class:`PeriodIndex` with default ``date_format`` are also significantly accelerated. (:issue:`51459`)
- Performance improvement accessing :attr:`arrays.IntegerArray.dtype` & :attr:`arrays.FloatingArray.dtype` (:issue:`52998`)
- Performance improvement for :class:`DataFrameGroupBy`/:class:`SeriesGroupBy` aggregations (e.g. :meth:`DataFrameGroupBy.sum`) with ``engine="numba"`` (:issue:`53731`)
- Performance improvement in :class:`DataFrame` reductions with ``axis=1`` and extension dtypes (:issue:`54341`)
- Performance improvement in :class:`DataFrame` reductions with ``axis=None`` and extension dtypes (:issue:`54308`)
- Performance improvement in :class:`MultiIndex` and multi-column operations (e.g. :meth:`DataFrame.sort_values`, :meth:`DataFrame.groupby`, :meth:`Series.unstack`) when index/column values are already sorted (:issue:`53806`)
- Performance improvement in :class:`Series` reductions (:issue:`52341`)
Expand Down
26 changes: 26 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11172,6 +11172,32 @@ def _get_data() -> DataFrame:
).iloc[:0]
result.index = df.index
return result

# kurtosis excluded since groupby does not implement it
if df.shape[1] and name != "kurt":
dtype = find_common_type([arr.dtype for arr in df._mgr.arrays])
if isinstance(dtype, ExtensionDtype):
# GH 54341: fastpath for EA-backed axis=1 reductions
# This flattens the frame into a single 1D array while keeping
# track of the row and column indices of the original frame. Once
# flattened, grouping by the row indices and aggregating should
# be equivalent to transposing the original frame and aggregating
# with axis=0.
name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name)
df = df.astype(dtype, copy=False)
arr = concat_compat(list(df._iter_column_arrays()))
nrows, ncols = df.shape
row_index = np.tile(np.arange(nrows), ncols)
col_index = np.repeat(np.arange(ncols), nrows)
ser = Series(arr, index=col_index, copy=False)
result = ser.groupby(row_index).agg(name, **kwds)
result.index = df.index
if not skipna and name not in ("any", "all"):
mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1)
other = -1 if name in ("idxmax", "idxmin") else lib.no_default
result = result.mask(mask, other)
return result

df = df.T

# After possibly _get_data and transposing, we are now in the
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class providing the base-class of operations.
ExtensionArray,
FloatingArray,
IntegerArray,
SparseArray,
)
from pandas.core.base import (
PandasObject,
Expand Down Expand Up @@ -1909,7 +1910,10 @@ def array_func(values: ArrayLike) -> ArrayLike:
# and non-applicable functions
# try to python agg
# TODO: shouldn't min_count matter?
if how in ["any", "all", "std", "sem"]:
# TODO: avoid special casing SparseArray here
if how in ["any", "all"] and isinstance(values, SparseArray):
pass
elif how in ["any", "all", "std", "sem"]:
raise # TODO: re-raise as TypeError? should not be reached
else:
return result
Expand Down
77 changes: 77 additions & 0 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1938,3 +1938,80 @@ def test_fails_on_non_numeric(kernel):
msg = "|".join([msg1, msg2])
with pytest.raises(TypeError, match=msg):
getattr(df, kernel)(*args)


@pytest.mark.parametrize(
    "method",
    [
        "all",
        "any",
        "count",
        "idxmax",
        "idxmin",
        "kurt",
        "kurtosis",
        "max",
        "mean",
        "median",
        "min",
        "nunique",
        "prod",
        "product",
        "sem",
        "skew",
        "std",
        "sum",
        "var",
    ],
)
@pytest.mark.parametrize("min_count", [0, 2])
def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype):
    """
    Check ``axis=1`` reductions on a frame of extension-dtype columns.

    The reduction named by ``method`` is run twice: once on a frame whose
    columns use ``any_numeric_ea_dtype`` (with one ``pd.NA``), and once on a
    reference frame built from plain floats (with ``np.nan`` in the same
    position). The reference result is cast to the dtype expected for the
    extension-backed result (except for ``idxmax``/``idxmin``) and the two
    Series are compared.
    """
    # GH 54341
    index_methods = ("idxmax", "idxmin")
    float_result_methods = (
        "kurt",
        "kurtosis",
        "mean",
        "median",
        "sem",
        "skew",
        "std",
        "var",
    )

    ea_frame = DataFrame(
        {
            "a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype),
            "b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype),
        },
    )
    reference_frame = DataFrame(
        {
            "a": [0.0, 1.0, 2.0, 3.0],
            "b": [0.0, 1.0, np.nan, 3.0],
        },
    )

    # Keyword arguments forwarded to both reduction calls.
    call_kwargs = {}
    if method not in ("count", "nunique", "quantile"):
        call_kwargs["skipna"] = skipna
    if method in ("prod", "product", "sum"):
        call_kwargs["min_count"] = min_count

    # Dtype the reference result is cast to before the comparison.
    if method in ("count", "nunique"):
        target_dtype = "int64"
    elif method in ("all", "any"):
        target_dtype = "boolean"
    elif method in float_result_methods and not any_numeric_ea_dtype.startswith(
        "Float"
    ):
        target_dtype = "Float64"
    else:
        target_dtype = any_numeric_ea_dtype

    # idxmax/idxmin with skipna=False are expected to emit a FutureWarning
    # on both frames; every other combination must be warning-free.
    expects_warning = method in index_methods and not skipna
    warn = FutureWarning if expects_warning else None
    msg = (
        f"The behavior of DataFrame.{method} with all-NA values"
        if expects_warning
        else None
    )

    outcomes = []
    for frame in (ea_frame, reference_frame):
        with tm.assert_produces_warning(warn, match=msg):
            outcomes.append(getattr(frame, method)(axis=1, **call_kwargs))
    result, expected = outcomes

    if method not in index_methods:
        expected = expected.astype(target_dtype)
    tm.assert_series_equal(result, expected)

0 comments on commit 7213ad4

Please sign in to comment.