Skip to content

Commit 5512bc4

Browse files
Backport PR #43172: BUG: Pass index data correctly in groupby.transform/agg w/ engine=numba (#43250)
Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 5c2c116 commit 5512bc4

File tree

4 files changed, with 34 additions (+34) and 2 deletions (-2).

doc/source/whatsnew/v1.3.3.rst

+1 -1
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ Fixed regressions
2525

2626
Bug fixes
2727
~~~~~~~~~
28-
-
28+
- Bug in :meth:`.DataFrameGroupBy.agg` and :meth:`.DataFrameGroupBy.transform` with ``engine="numba"`` where ``index`` data was not being correctly passed into ``func`` (:issue:`43133`)
2929
-
3030

3131
.. ---------------------------------------------------------------------------

pandas/core/groupby/groupby.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -1143,9 +1143,15 @@ def _numba_prep(self, func, data):
11431143
sorted_ids = algorithms.take_nd(ids, sorted_index, allow_fill=False)
11441144

11451145
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
1146+
sorted_index_data = data.index.take(sorted_index).to_numpy()
11461147

11471148
starts, ends = lib.generate_slices(sorted_ids, ngroups)
1148-
return starts, ends, sorted_index, sorted_data
1149+
return (
1150+
starts,
1151+
ends,
1152+
sorted_index_data,
1153+
sorted_data,
1154+
)
11491155

11501156
@final
11511157
def _transform_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):

pandas/tests/groupby/aggregate/test_numba.py

+14
Original file line number | Diff line number | Diff line change
@@ -173,3 +173,17 @@ def sum_last(values, index, n):
173173
result = grouped_x.agg(sum_last, 2, engine="numba")
174174
expected = Series([2.0] * 2, name="x", index=Index([0, 1], name="id"))
175175
tm.assert_series_equal(result, expected)
176+
177+
178+
@td.skip_if_no("numba", "0.46.0")
179+
def test_index_data_correctly_passed():
180+
# GH 43133
181+
def f(values, index):
182+
return np.mean(index)
183+
184+
df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
185+
result = df.groupby("group").aggregate(f, engine="numba")
186+
expected = DataFrame(
187+
[-1.5, -3.0], columns=["v"], index=Index(["A", "B"], name="group")
188+
)
189+
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/transform/test_numba.py

+12
Original file line number | Diff line number | Diff line change
@@ -164,3 +164,15 @@ def sum_last(values, index, n):
164164
result = grouped_x.transform(sum_last, 2, engine="numba")
165165
expected = Series([2.0] * 4, name="x")
166166
tm.assert_series_equal(result, expected)
167+
168+
169+
@td.skip_if_no("numba", "0.46.0")
170+
def test_index_data_correctly_passed():
171+
# GH 43133
172+
def f(values, index):
173+
return index - 1
174+
175+
df = DataFrame({"group": ["A", "A", "B"], "v": [4, 5, 6]}, index=[-1, -2, -3])
176+
result = df.groupby("group").transform(f, engine="numba")
177+
expected = DataFrame([-4.0, -3.0, -2.0], columns=["v"], index=[-1, -2, -3])
178+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)