Skip to content

Commit

Permalink
TST: add test case for user-defined function taking correct path in g…
Browse files Browse the repository at this point in the history
…roupby transform (pandas-dev#29631)
  • Loading branch information
jbrockmendel authored and jacobaustin123 committed Nov 20, 2019
1 parent 0939732 commit 00cfdf5
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
6 changes: 3 additions & 3 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,7 @@ def _define_paths(self, func, *args, **kwargs):
)
return fast_path, slow_path

def _choose_path(self, fast_path, slow_path, group):
def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFrame):
path = slow_path
res = slow_path(group)

Expand All @@ -1392,8 +1392,8 @@ def _choose_path(self, fast_path, slow_path, group):
except AssertionError:
raise
except Exception:
# Hard to know ex-ante what exceptions `fast_path` might raise
# TODO: no test cases get here
# GH#29631 For user-defined function, we cant predict what may be
# raised; see test_transform.test_transform_fastpath_raises
return path, res

# verify fast path does not change columns (and names), otherwise
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,33 @@ def test_transform_lambda_with_datetimetz():
name="time",
)
tm.assert_series_equal(result, expected)


def test_transform_fastpath_raises():
# GH#29631 case where fastpath defined in groupby.generic _choose_path
# raises, but slow_path does not

df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]})
gb = df.groupby("A")

def func(grp):
# we want a function such that func(frame) fails but func.apply(frame)
# works
if grp.ndim == 2:
# Ensure that fast_path fails
raise NotImplementedError("Don't cross the streams")
return grp * 2

# Check that the fastpath raises, see _transform_general
obj = gb._obj_with_exclusions
gen = gb.grouper.get_iterator(obj, axis=gb.axis)
fast_path, slow_path = gb._define_paths(func)
_, group = next(gen)

with pytest.raises(NotImplementedError, match="Don't cross the streams"):
fast_path(group)

result = gb.transform(func)

expected = pd.DataFrame([2, -2, 2, 4], columns=["B"])
tm.assert_frame_equal(result, expected)

0 comments on commit 00cfdf5

Please sign in to comment.