Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: add test case for user-defined function taking correct path in groupby transform #29631

Merged
merged 7 commits into from
Nov 20, 2019
6 changes: 3 additions & 3 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1380,7 +1380,7 @@ def _define_paths(self, func, *args, **kwargs):
)
return fast_path, slow_path

def _choose_path(self, fast_path, slow_path, group):
def _choose_path(self, fast_path: Callable, slow_path: Callable, group: DataFrame):
path = slow_path
res = slow_path(group)

Expand All @@ -1390,8 +1390,8 @@ def _choose_path(self, fast_path, slow_path, group):
except AssertionError:
raise
except Exception:
# Hard to know ex-ante what exceptions `fast_path` might raise
# TODO: no test cases get here
# GH#29631 For user-defined function, we cant predict what may be
# raised; see test_transform.test_transform_fastpath_raises
return path, res

# verify fast path does not change columns (and names), otherwise
Expand Down
30 changes: 30 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1073,3 +1073,33 @@ def test_transform_lambda_with_datetimetz():
name="time",
)
tm.assert_series_equal(result, expected)


def test_transform_fastpath_raises():
# GH#29631 case where fastpath defined in groupby.generic _choose_path
# raises, but slow_path does not

df = pd.DataFrame({"A": [1, 1, 2, 2], "B": [1, -1, 1, 2]})
gb = df.groupby("A")

def func(grp):
# we want a function such that func(frame) fails but func.apply(frame)
# works
if grp.ndim == 2:
# Ensure that fast_path fails
raise NotImplementedError("Don't cross the streams")
return grp * 2

# Check that the fastpath raises, see _transform_general
obj = gb._obj_with_exclusions
gen = gb.grouper.get_iterator(obj, axis=gb.axis)
fast_path, slow_path = gb._define_paths(func)
_, group = next(gen)

with pytest.raises(NotImplementedError, match="Don't cross the streams"):
fast_path(group)

result = gb.transform(func)

expected = pd.DataFrame([2, -2, 2, 4], columns=["B"])
tm.assert_frame_equal(result, expected)