diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index 484b01f2c04f0f..46a0bd3483884e 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -7925,7 +7925,9 @@ def stack(self, level: Level = -1, dropna: bool = True):
         return result.__finalize__(self, method="stack")
 
     def explode(
-        self, column: Union[str, Tuple], ignore_index: bool = False
+        self,
+        column: Union[str, Tuple, List[Union[str, Tuple]]],
+        ignore_index: bool = False,
     ) -> DataFrame:
         """
         Transform each element of a list-like to a row, replicating index values.
@@ -7934,8 +7936,11 @@ def explode(
 
         Parameters
         ----------
-        column : str or tuple
-            Column to explode.
+        column : str or tuple or list thereof
+            Column(s) to explode. When a list is given it must be non-empty
+            and free of duplicates, and on every row the selected columns
+            must hold the same number of elements (a scalar, NaN or empty
+            list-like counts as one element).
         ignore_index : bool, default False
             If True, the resulting index will be labeled 0, 1, …, n - 1.
 
@@ -7969,32 +7974,67 @@ def explode(
 
         Examples
         --------
-        >>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
+        >>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
+        ...                    'B': 1,
+        ...                    'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
         >>> df
-                   A  B
-        0  [1, 2, 3]  1
-        1        foo  1
-        2         []  1
-        3     [3, 4]  1
+                   A  B          C
+        0  [0, 1, 2]  1  [a, b, c]
+        1        foo  1        NaN
+        2         []  1         []
+        3     [3, 4]  1     [d, e]
 
         >>> df.explode('A')
-             A  B
-        0    1  1
-        0    2  1
-        0    3  1
-        1  foo  1
-        2  NaN  1
-        3    3  1
-        3    4  1
-        """
-        if not (is_scalar(column) or isinstance(column, tuple)):
-            raise ValueError("column must be a scalar")
+             A  B          C
+        0    0  1  [a, b, c]
+        0    1  1  [a, b, c]
+        0    2  1  [a, b, c]
+        1  foo  1        NaN
+        2  NaN  1         []
+        3    3  1     [d, e]
+        3    4  1     [d, e]
+
+        >>> df.explode(list('AC'))
+             A  B    C
+        0    0  1    a
+        0    1  1    b
+        0    2  1    c
+        1  foo  1  NaN
+        2  NaN  1  NaN
+        3    3  1    d
+        3    4  1    e
+        """
         if not self.columns.is_unique:
             raise ValueError("columns must be unique")
 
+        if is_scalar(column) or isinstance(column, tuple):
+            columns = [column]
+        elif isinstance(column, list) and all(
+            is_scalar(c) or isinstance(c, tuple) for c in column
+        ):
+            if not column:
+                # An empty list would otherwise reach ``columns[0]`` below and
+                # surface as an unhelpful IndexError.
+                raise ValueError("column must be nonempty")
+            if len(column) > len(set(column)):
+                raise ValueError("column must be unique")
+            # mypy: Incompatible types in assignment (expression has type
+            # "List[Union[str, Tuple[Any, ...]]]", variable has type
+            # "List[Union[str, Tuple[Any, ...], List[Union[str, Tuple[Any, ...]]]]]")
+            columns = column  # type: ignore[assignment]
+        else:
+            raise ValueError("column must be a scalar, tuple, or list thereof")
+
+        # Series.explode emits exactly one row for a scalar, NaN or empty
+        # list-like, so each of those counts as 1 when checking alignment.
+        mylen = lambda x: len(x) if (is_list_like(x) and len(x) > 0) else 1
+        counts0 = self[columns[0]].apply(mylen)
+        for c in columns[1:]:
+            if not all(counts0 == self[c].apply(mylen)):
+                raise ValueError("columns must have matching element counts")
         df = self.reset_index(drop=True)
-        result = df[column].explode()
-        result = df.drop([column], axis=1).join(result)
+        result = DataFrame({c: df[c].explode() for c in columns})
+        result = df.drop(columns, axis=1).join(result)
         if ignore_index:
             result.index = ibase.default_index(len(result))
         else:
diff --git a/pandas/tests/frame/methods/test_explode.py b/pandas/tests/frame/methods/test_explode.py
index bd0901387eeedc..56c3038f472afc 100644
--- a/pandas/tests/frame/methods/test_explode.py
+++ b/pandas/tests/frame/methods/test_explode.py
@@ -9,13 +9,26 @@ def test_error():
     df = pd.DataFrame(
         {"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
     )
-    with pytest.raises(ValueError, match="column must be a scalar"):
+    with pytest.raises(
+        ValueError, match="column must be a scalar, tuple, or list thereof"
+    ):
+        df.explode([list("AA")])
+
+    with pytest.raises(ValueError, match="column must be nonempty"):
+        df.explode([])
+
+    with pytest.raises(ValueError, match="column must be unique"):
         df.explode(list("AA"))
 
     df.columns = list("AA")
     with pytest.raises(ValueError, match="columns must be unique"):
         df.explode("A")
 
+    df1 = df.assign(C=[["a", "b", "c"], "foo", [], ["d", "e", "f"]])
+    df1.columns = list("ABC")
+    with pytest.raises(ValueError, match="columns must have matching element counts"):
+        df1.explode(list("AC"))
+
 
 def test_basic():
     df = pd.DataFrame(
@@ -180,3 +193,45 @@ def test_explode_sets():
     result = df.explode(column="a").sort_values(by="a")
     expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
     tm.assert_frame_equal(result, expected)
+
+
+def test_multi_columns():
+    df = pd.DataFrame(
+        {
+            "A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")),
+            "B": 1,
+            "C": [["a", "b", "c"], "foo", [], ["d", "e"]],
+        }
+    )
+    result = df.explode(list("AC"))
+    expected = pd.DataFrame(
+        {
+            "A": pd.Series(
+                [0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
+            ),
+            "B": 1,
+            "C": ["a", "b", "c", "foo", np.nan, "d", "e"],
+        }
+    )
+    tm.assert_frame_equal(result, expected)
+
+
+def test_multi_columns_nan_empty():
+    # NaN and an empty list each explode to a single row, so a row pairing
+    # them must not be rejected as a count mismatch
+    df = pd.DataFrame(
+        {
+            "A": [[0, 1], np.nan, []],
+            "B": 1,
+            "C": [["a", "b"], [], np.nan],
+        }
+    )
+    result = df.explode(list("AC"))
+    expected = pd.DataFrame(
+        {
+            "A": pd.Series([0, 1, np.nan, np.nan], index=[0, 0, 1, 2], dtype=object),
+            "B": 1,
+            "C": ["a", "b", np.nan, np.nan],
+        }
+    )
+    tm.assert_frame_equal(result, expected)