From 5f2bff93b66b018b8afb8687057cd5ca3f5c7002 Mon Sep 17 00:00:00 2001 From: Stephan Loyd Date: Sat, 3 Apr 2021 22:03:52 +0800 Subject: [PATCH] EHN: multi-column explode (#39240) --- doc/source/whatsnew/v1.3.0.rst | 1 + pandas/core/frame.py | 97 +++++++++++++++++----- pandas/tests/frame/methods/test_explode.py | 93 ++++++++++++++++++++- 3 files changed, 167 insertions(+), 24 deletions(-) diff --git a/doc/source/whatsnew/v1.3.0.rst b/doc/source/whatsnew/v1.3.0.rst index 1a5a9980e5e96..46ecce77798ef 100644 --- a/doc/source/whatsnew/v1.3.0.rst +++ b/doc/source/whatsnew/v1.3.0.rst @@ -274,6 +274,7 @@ Other enhancements - Add keyword ``dropna`` to :meth:`DataFrame.value_counts` to allow counting rows that include ``NA`` values (:issue:`41325`) - :meth:`Series.replace` will now cast results to ``PeriodDtype`` where possible instead of ``object`` dtype (:issue:`41526`) - Improved error message in ``corr`` and ``cov`` methods on :class:`.Rolling`, :class:`.Expanding`, and :class:`.ExponentialMovingWindow` when ``other`` is not a :class:`DataFrame` or :class:`Series` (:issue:`41741`) +- :meth:`DataFrame.explode` now supports exploding multiple columns. Its ``column`` argument now also accepts a list of str or tuples for exploding on multiple columns at the same time (:issue:`39240`) .. --------------------------------------------------------------------------- diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 2edad9f6626bb..b3c090a918b24 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -8151,7 +8151,11 @@ def stack(self, level: Level = -1, dropna: bool = True): return result.__finalize__(self, method="stack") - def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame: + def explode( + self, + column: str | tuple | list[str | tuple], + ignore_index: bool = False, + ) -> DataFrame: """ Transform each element of a list-like to a row, replicating index values. @@ -8159,8 +8163,15 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame: Parameters ---------- - column : str or tuple - Column to explode. + column : str or tuple or list thereof + Column(s) to explode. + For multiple columns, specify a non-empty list with each element + be str or tuple, and all specified columns their list-like data + on same row of the frame must have matching length. + + .. versionadded:: 1.3.0 + Multi-column explode + ignore_index : bool, default False If True, the resulting index will be labeled 0, 1, …, n - 1. @@ -8175,7 +8186,10 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame: Raises ------ ValueError : - if columns of the frame are not unique. + * If columns of the frame are not unique. + * If specified columns to explode is empty list. + * If specified columns to explode have not matching count of + elements rowwise in the frame. See Also -------- @@ -8194,32 +8208,69 @@ def explode(self, column: str | tuple, ignore_index: bool = False) -> DataFrame: Examples -------- - >>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1}) + >>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]], + ... 'B': 1, + ... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]}) >>> df - A B - 0 [1, 2, 3] 1 - 1 foo 1 - 2 [] 1 - 3 [3, 4] 1 + A B C + 0 [0, 1, 2] 1 [a, b, c] + 1 foo 1 NaN + 2 [] 1 [] + 3 [3, 4] 1 [d, e] + + Single-column explode. >>> df.explode('A') - A B - 0 1 1 - 0 2 1 - 0 3 1 - 1 foo 1 - 2 NaN 1 - 3 3 1 - 3 4 1 - """ - if not (is_scalar(column) or isinstance(column, tuple)): - raise ValueError("column must be a scalar") + A B C + 0 0 1 [a, b, c] + 0 1 1 [a, b, c] + 0 2 1 [a, b, c] + 1 foo 1 NaN + 2 NaN 1 [] + 3 3 1 [d, e] + 3 4 1 [d, e] + + Multi-column explode. + + >>> df.explode(list('AC')) + A B C + 0 0 1 a + 0 1 1 b + 0 2 1 c + 1 foo 1 NaN + 2 NaN 1 NaN + 3 3 1 d + 3 4 1 e + """ if not self.columns.is_unique: raise ValueError("columns must be unique") + columns: list[str | tuple] + if is_scalar(column) or isinstance(column, tuple): + assert isinstance(column, (str, tuple)) + columns = [column] + elif isinstance(column, list) and all( + map(lambda c: is_scalar(c) or isinstance(c, tuple), column) + ): + if not column: + raise ValueError("column must be nonempty") + if len(column) > len(set(column)): + raise ValueError("column must be unique") + columns = column + else: + raise ValueError("column must be a scalar, tuple, or list thereof") + df = self.reset_index(drop=True) - result = df[column].explode() - result = df.drop([column], axis=1).join(result) + if len(columns) == 1: + result = df[columns[0]].explode() + else: + mylen = lambda x: len(x) if is_list_like(x) else -1 + counts0 = self[columns[0]].apply(mylen) + for c in columns[1:]: + if not all(counts0 == self[c].apply(mylen)): + raise ValueError("columns must have matching element counts") + result = DataFrame({c: df[c].explode() for c in columns}) + result = df.drop(columns, axis=1).join(result) if ignore_index: result.index = ibase.default_index(len(result)) else: diff --git a/pandas/tests/frame/methods/test_explode.py b/pandas/tests/frame/methods/test_explode.py index bd0901387eeed..6fdf5d806ac6b 100644 --- a/pandas/tests/frame/methods/test_explode.py +++ b/pandas/tests/frame/methods/test_explode.py @@ -9,7 +9,12 @@ def test_error(): df = pd.DataFrame( {"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1} ) - with pytest.raises(ValueError, match="column must be a scalar"): + with pytest.raises( + ValueError, match="column must be a scalar, tuple, or list thereof" + ): + df.explode([list("AA")]) + + with pytest.raises(ValueError, match="column must be unique"): df.explode(list("AA")) df.columns = list("AA") @@ -17,6 +22,37 @@ def test_error(): df.explode("A") +@pytest.mark.parametrize( + "input_subset, error_message", + [ + ( + list("AC"), + "columns must have matching element counts", + ), + ( + [], + "column must be nonempty", + ), + ( + list("AC"), + "columns must have matching element counts", + ), + ], +) +def test_error_multi_columns(input_subset, error_message): + # GH 39240 + df = pd.DataFrame( + { + "A": [[0, 1, 2], np.nan, [], (3, 4)], + "B": 1, + "C": [["a", "b", "c"], "foo", [], ["d", "e", "f"]], + }, + index=list("abcd"), + ) + with pytest.raises(ValueError, match=error_message): + df.explode(input_subset) + + def test_basic(): df = pd.DataFrame( {"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1} @@ -180,3 +216,58 @@ def test_explode_sets(): result = df.explode(column="a").sort_values(by="a") expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1]) tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input_subset, expected_dict, expected_index", + [ + ( + list("AC"), + { + "A": pd.Series( + [0, 1, 2, np.nan, np.nan, 3, 4, np.nan], + index=list("aaabcdde"), + dtype=object, + ), + "B": 1, + "C": ["a", "b", "c", "foo", np.nan, "d", "e", np.nan], + }, + list("aaabcdde"), + ), + ( + list("A"), + { + "A": pd.Series( + [0, 1, 2, np.nan, np.nan, 3, 4, np.nan], + index=list("aaabcdde"), + dtype=object, + ), + "B": 1, + "C": [ + ["a", "b", "c"], + ["a", "b", "c"], + ["a", "b", "c"], + "foo", + [], + ["d", "e"], + ["d", "e"], + np.nan, + ], + }, + list("aaabcdde"), + ), + ], +) +def test_multi_columns(input_subset, expected_dict, expected_index): + # GH 39240 + df = pd.DataFrame( + { + "A": [[0, 1, 2], np.nan, [], (3, 4), np.nan], + "B": 1, + "C": [["a", "b", "c"], "foo", [], ["d", "e"], np.nan], + }, + index=list("abcde"), + ) + result = df.explode(input_subset) + expected = pd.DataFrame(expected_dict, expected_index) + tm.assert_frame_equal(result, expected)