Skip to content

Commit

Permalink
REF: align transform logic flow (pandas-dev#29672)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and proost committed Dec 19, 2019
1 parent e94f583 commit 580a6de
Showing 1 changed file with 43 additions and 40 deletions.
83 changes: 43 additions & 40 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs):
def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
return self._transform_fast(
lambda: getattr(self, func)(*args, **kwargs), func
)
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_whitelist:
msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)

# reg transform
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)
return self._transform_fast(result, func)

def _transform_general(self, func, *args, **kwargs):
"""
Transform with a non-str `func`.
"""
klass = self._selected_obj.__class__

results = []
wrapper = lambda x: func(x, *args, **kwargs)
for name, group in self:
object.__setattr__(group, "name", name)
res = wrapper(group)
res = func(group, *args, **kwargs)

if isinstance(res, (ABCDataFrame, ABCSeries)):
res = res._values

indexer = self._get_index(name)
s = klass(res, indexer)
results.append(s)
ser = klass(res, indexer)
results.append(ser)

# check for empty "results" to avoid concat ValueError
if results:
Expand All @@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs):
result = Series()

# we will only try to coerce the result type if
# we have a numeric dtype, as these are *always* udfs
# we have a numeric dtype, as these are *always* user-defined funcs
# the cython take a different path (and casting)
dtype = self._selected_obj.dtype
if is_numeric_dtype(dtype):
Expand All @@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs):
result.index = self._selected_obj.index
return result

def _transform_fast(self, func, func_nm) -> Series:
def _transform_fast(self, result, func_nm: str) -> Series:
"""
fast version of transform, only applicable to
builtin/cythonizable functions
"""
if isinstance(func, str):
func = getattr(self, func)

ids, _, ngroup = self.grouper.group_info
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(func()._values, ids)
out = algorithms.take_1d(result._values, ids)
if cast:
out = self._try_cast(out, self.obj)
return Series(out, index=self.obj.index, name=self.obj.name)
Expand Down Expand Up @@ -1333,21 +1334,21 @@ def transform(self, func, *args, **kwargs):
# optimized transforms
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)
else:
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_whitelist:
msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)

# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)

# a reduction transform
if not isinstance(result, DataFrame):
return self._transform_general(func, *args, **kwargs)
Expand All @@ -1358,16 +1359,18 @@ def transform(self, func, *args, **kwargs):
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)

return self._transform_fast(result, obj, func)
return self._transform_fast(result, func)

def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame:
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
"""
Fast transform path for aggregations
"""
# if there were groups with no observations (Categorical only?)
# try casting data to original dtype
cast = self._transform_should_cast(func_nm)

obj = self._obj_with_exclusions

# for each col, reshape to to size of original frame
# by take operation
ids, _, ngroup = self.grouper.group_info
Expand Down

0 comments on commit 580a6de

Please sign in to comment.