Skip to content

Commit

Permalink
multi-column explode
Browse files Browse the repository at this point in the history
  • Loading branch information
iynehz committed Apr 4, 2021
1 parent 4554635 commit 34a099d
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 23 deletions.
75 changes: 53 additions & 22 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7925,7 +7925,9 @@ def stack(self, level: Level = -1, dropna: bool = True):
return result.__finalize__(self, method="stack")

def explode(
self, column: Union[str, Tuple], ignore_index: bool = False
self,
column: Union[str, Tuple, List[Union[str, Tuple]]],
ignore_index: bool = False,
) -> DataFrame:
"""
Transform each element of a list-like to a row, replicating index values.
Expand All @@ -7934,8 +7936,8 @@ def explode(
Parameters
----------
column : str or tuple
Column to explode.
column : str or tuple or list thereof
Column(s) to explode.
ignore_index : bool, default False
If True, the resulting index will be labeled 0, 1, …, n - 1.
Expand Down Expand Up @@ -7969,32 +7971,61 @@ def explode(
Examples
--------
>>> df = pd.DataFrame({'A': [[1, 2, 3], 'foo', [], [3, 4]], 'B': 1})
>>> df = pd.DataFrame({'A': [[0, 1, 2], 'foo', [], [3, 4]],
... 'B': 1,
... 'C': [['a', 'b', 'c'], np.nan, [], ['d', 'e']]})
>>> df
A B
0 [1, 2, 3] 1
1 foo 1
2 [] 1
3 [3, 4] 1
A B C
0 [0, 1, 2] 1 [a, b, c]
1 foo 1 NaN
2 [] 1 []
3 [3, 4] 1 [d, e]
>>> df.explode('A')
A B
0 1 1
0 2 1
0 3 1
1 foo 1
2 NaN 1
3 3 1
3 4 1
"""
if not (is_scalar(column) or isinstance(column, tuple)):
raise ValueError("column must be a scalar")
A B C
0 0 1 [a, b, c]
0 1 1 [a, b, c]
0 2 1 [a, b, c]
1 foo 1 NaN
2 NaN 1 []
3 3 1 [d, e]
3 4 1 [d, e]
>>> df.explode(list('AC'))
A B C
0 0 1 a
0 1 1 b
0 2 1 c
1 foo 1 NaN
2 NaN 1 NaN
3 3 1 d
3 4 1 e
"""
if not self.columns.is_unique:
raise ValueError("columns must be unique")
if is_scalar(column) or isinstance(column, tuple):
columns = [column]
elif isinstance(column, list) and all(
map(lambda c: is_scalar(c) or isinstance(c, tuple), column)
):
if len(column) > len(set(column)):
raise ValueError("column must be unique")
# mypy: Incompatible types in assignment (expression has type
# "List[Union[str, Tuple[Any, ...]]]", variable has type
# "List[Union[str, Tuple[Any, ...], List[Union[str, Tuple[Any, ...]]]]]")
columns = column # type: ignore[assignment]
else:
raise ValueError("column must be a scalar, tuple, or list thereof")

mylen = lambda x: len(x) if is_list_like(x) else -1
counts0 = self[columns[0]].apply(mylen)
for c in columns[1:]:
if not all(counts0 == self[c].apply(mylen)):
raise ValueError("columns must have matching element counts")

df = self.reset_index(drop=True)
result = df[column].explode()
result = df.drop([column], axis=1).join(result)
result = DataFrame({c: df[c].explode() for c in columns})
result = df.drop(columns, axis=1).join(result)
if ignore_index:
result.index = ibase.default_index(len(result))
else:
Expand Down
33 changes: 32 additions & 1 deletion pandas/tests/frame/methods/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,23 @@ def test_error():
df = pd.DataFrame(
{"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")), "B": 1}
)
with pytest.raises(ValueError, match="column must be a scalar"):
with pytest.raises(
ValueError, match="column must be a scalar, tuple, or list thereof"
):
df.explode([list("AA")])

with pytest.raises(ValueError, match="column must be unique"):
df.explode(list("AA"))

df.columns = list("AA")
with pytest.raises(ValueError, match="columns must be unique"):
df.explode("A")

df1 = df.assign(C=[["a", "b", "c"], "foo", [], ["d", "e", "f"]])
df1.columns = list("ABC")
with pytest.raises(ValueError, match="columns must have matching element counts"):
df1.explode(list("AC"))


def test_basic():
df = pd.DataFrame(
Expand Down Expand Up @@ -180,3 +190,24 @@ def test_explode_sets():
result = df.explode(column="a").sort_values(by="a")
expected = pd.DataFrame({"a": ["x", "y"], "b": [1, 1]}, index=[1, 1])
tm.assert_frame_equal(result, expected)


def test_multi_columns():
df = pd.DataFrame(
{
"A": pd.Series([[0, 1, 2], np.nan, [], (3, 4)], index=list("abcd")),
"B": 1,
"C": [["a", "b", "c"], "foo", [], ["d", "e"]],
}
)
result = df.explode(list("AC"))
expected = pd.DataFrame(
{
"A": pd.Series(
[0, 1, 2, np.nan, np.nan, 3, 4], index=list("aaabcdd"), dtype=object
),
"B": 1,
"C": ["a", "b", "c", "foo", np.nan, "d", "e"],
}
)
tm.assert_frame_equal(result, expected)

0 comments on commit 34a099d

Please sign in to comment.