Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF: align transform logic flow #29672

Merged
merged 1 commit into from
Nov 19, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 43 additions & 40 deletions pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,35 +394,39 @@ def _aggregate_named(self, func, *args, **kwargs):
def transform(self, func, *args, **kwargs):
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
return self._transform_fast(
lambda: getattr(self, func)(*args, **kwargs), func
)
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_whitelist:
msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels:
# cythonized transform or canned "agg+broadcast"
return getattr(self, func)(*args, **kwargs)

# reg transform
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)
return self._transform_fast(result, func)

def _transform_general(self, func, *args, **kwargs):
"""
Transform with a non-str `func`.
"""
klass = self._selected_obj.__class__

results = []
wrapper = lambda x: func(x, *args, **kwargs)
for name, group in self:
object.__setattr__(group, "name", name)
res = wrapper(group)
res = func(group, *args, **kwargs)

if isinstance(res, (ABCDataFrame, ABCSeries)):
res = res._values

indexer = self._get_index(name)
s = klass(res, indexer)
results.append(s)
ser = klass(res, indexer)
results.append(ser)

# check for empty "results" to avoid concat ValueError
if results:
Expand All @@ -433,7 +437,7 @@ def transform(self, func, *args, **kwargs):
result = Series()

# we will only try to coerce the result type if
# we have a numeric dtype, as these are *always* udfs
# we have a numeric dtype, as these are *always* user-defined funcs
# the cython functions take a different path (and casting)
dtype = self._selected_obj.dtype
if is_numeric_dtype(dtype):
Expand All @@ -443,17 +447,14 @@ def transform(self, func, *args, **kwargs):
result.index = self._selected_obj.index
return result

def _transform_fast(self, func, func_nm) -> Series:
def _transform_fast(self, result, func_nm: str) -> Series:
"""
fast version of transform, only applicable to
builtin/cythonizable functions
"""
if isinstance(func, str):
func = getattr(self, func)

ids, _, ngroup = self.grouper.group_info
cast = self._transform_should_cast(func_nm)
out = algorithms.take_1d(func()._values, ids)
out = algorithms.take_1d(result._values, ids)
if cast:
out = self._try_cast(out, self.obj)
return Series(out, index=self.obj.index, name=self.obj.name)
Expand Down Expand Up @@ -1340,21 +1341,21 @@ def transform(self, func, *args, **kwargs):
# optimized transforms
func = self._get_cython_func(func) or func

if isinstance(func, str):
if not (func in base.transform_kernel_whitelist):
msg = "'{func}' is not a valid function name for transform(name)"
raise ValueError(msg.format(func=func))
if func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)
else:
# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)
else:
if not isinstance(func, str):
return self._transform_general(func, *args, **kwargs)

elif func not in base.transform_kernel_whitelist:
Copy link
Contributor

@jreback jreback Nov 17, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be possible to consolidate this logic into the base class to keep this DRY?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm optimistic about this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have a strong preference for doing this in this PR? I'm trying to keep the scope narrow to avoid conflicting with #29124.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this can be done later.

msg = f"'{func}' is not a valid function name for transform(name)"
raise ValueError(msg)
elif func in base.cythonized_kernels:
# cythonized transformation or canned "reduction+broadcast"
return getattr(self, func)(*args, **kwargs)

# If func is a reduction, we need to broadcast the
# result to the whole group. Compute func result
# and deal with possible broadcasting below.
result = getattr(self, func)(*args, **kwargs)

# a reduction transform
if not isinstance(result, DataFrame):
return self._transform_general(func, *args, **kwargs)
Expand All @@ -1365,16 +1366,18 @@ def transform(self, func, *args, **kwargs):
if not result.columns.equals(obj.columns):
return self._transform_general(func, *args, **kwargs)

return self._transform_fast(result, obj, func)
return self._transform_fast(result, func)

def _transform_fast(self, result: DataFrame, obj: DataFrame, func_nm) -> DataFrame:
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
"""
Fast transform path for aggregations
"""
# if there were groups with no observations (Categorical only?)
# try casting data to original dtype
cast = self._transform_should_cast(func_nm)

obj = self._obj_with_exclusions

# for each col, reshape to size of original frame
# by take operation
ids, _, ngroup = self.grouper.group_info
Expand Down