diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 06fd3c1eae0067..31563e4bccbb76 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs): def transform(self, func, *args, **kwargs): func = self._get_cython_func(func) or func - if isinstance(func, str): - if not (func in base.transform_kernel_whitelist): - msg = "'{func}' is not a valid function name for transform(name)" - raise ValueError(msg.format(func=func)) - if func in base.cythonized_kernels: - # cythonized transform or canned "agg+broadcast" - return getattr(self, func)(*args, **kwargs) - else: - # If func is a reduction, we need to broadcast the - # result to the whole group. Compute func result - # and deal with possible broadcasting below. - return self._transform_fast( - lambda: getattr(self, func)(*args, **kwargs), func - ) + if not isinstance(func, str): + return self._transform_general(func, *args, **kwargs) + + elif func not in base.transform_kernel_whitelist: + msg = f"'{func}' is not a valid function name for transform(name)" + raise ValueError(msg) + elif func in base.cythonized_kernels: + # cythonized transform or canned "agg+broadcast" + return getattr(self, func)(*args, **kwargs) - # reg transform + # If func is a reduction, we need to broadcast the + # result to the whole group. Compute func result + # and deal with possible broadcasting below. + result = getattr(self, func)(*args, **kwargs) + return self._transform_fast(result, func) + + def _transform_general(self, func, *args, **kwargs): + """ + Transform with a non-str `func`. + """ klass = self._selected_obj.__class__ + results = [] - wrapper = lambda x: func(x, *args, **kwargs) for name, group in self: object.__setattr__(group, "name", name) - res = wrapper(group) + res = func(group, *args, **kwargs) if isinstance(res, (ABCDataFrame, ABCSeries)): res = res._values indexer = self._get_index(name) - s = klass(res, indexer) - results.append(s) + ser = klass(res, indexer) + results.append(ser) # check for empty "results" to avoid concat ValueError if results: @@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs): result = Series() # we will only try to coerce the result type if - # we have a numeric dtype, as these are *always* udfs + # we have a numeric dtype, as these are *always* user-defined funcs # the cython take a different path (and casting) dtype = self._selected_obj.dtype if is_numeric_dtype(dtype): @@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs): result.index = self._selected_obj.index return result - def _transform_fast(self, func, func_nm) -> Series: + def _transform_fast(self, result, func_nm: str) -> Series: """ fast version of transform, only applicable to builtin/cythonizable functions """ - if isinstance(func, str): - func = getattr(self, func) - ids, _, ngroup = self.grouper.group_info cast = self._transform_should_cast(func_nm) - out = algorithms.take_1d(func()._values, ids) + out = algorithms.take_1d(result._values, ids) if cast: out = self._try_cast(out, self.obj) return Series(out, index=self.obj.index, name=self.obj.name) @@ -1333,21 +1334,21 @@ def transform(self, func, *args, **kwargs): # optimized transforms func = self._get_cython_func(func) or func - if isinstance(func, str): - if not (func in base.transform_kernel_whitelist): - msg = "'{func}' is not a valid function name for transform(name)" - raise ValueError(msg.format(func=func)) - if func in base.cythonized_kernels: - # cythonized transformation or canned "reduction+broadcast" - return getattr(self, func)(*args, **kwargs) - else: - # If func is a reduction, we need to broadcast the - # result to the whole group. Compute func result - # and deal with possible broadcasting below. - result = getattr(self, func)(*args, **kwargs) - else: + if not isinstance(func, str): return self._transform_general(func, *args, **kwargs) + elif func not in base.transform_kernel_whitelist: + msg = f"'{func}' is not a valid function name for transform(name)" + raise ValueError(msg) + elif func in base.cythonized_kernels: + # cythonized transformation or canned "reduction+broadcast" + return getattr(self, func)(*args, **kwargs) + + # If func is a reduction, we need to broadcast the + # result to the whole group. Compute func result + # and deal with possible broadcasting below. + result = getattr(self, func)(*args, **kwargs) + # a reduction transform if not isinstance(result, DataFrame): return self._transform_general(func, *args, **kwargs) @@ -1358,9 +1359,9 @@ def transform(self, func, *args, **kwargs): if not result.columns.equals(obj.columns): return self._transform_general(func, *args, **kwargs) - return self._transform_fast(result, obj, func) + return self._transform_fast(result, func) - def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame: + def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: """ Fast transform path for aggregations """ @@ -1368,6 +1369,8 @@ def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFra # try casting data to original dtype cast = self._transform_should_cast(func_nm) + obj = self._obj_with_exclusions + # for each col, reshape to to size of original frame # by take operation ids, _, ngroup = self.grouper.group_info