Skip to content

Commit

Permalink
fixup slow transforms
Browse files (browse the repository at this point in the history)
  • Loading branch information
jreback committed Feb 27, 2017
1 parent cc43503 commit 2f48549
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 24 deletions.
41 changes: 17 additions & 24 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2890,39 +2890,32 @@ def transform(self, func, *args, **kwargs):
lambda: getattr(self, func)(*args, **kwargs))

# reg transform
dtype = self._selected_obj.dtype
result = self._selected_obj.values.copy()

klass = self._selected_obj.__class__
results = []
wrapper = lambda x: func(x, *args, **kwargs)
for i, (name, group) in enumerate(self):
for name, group in self:
object.__setattr__(group, 'name', name)
res = wrapper(group)

if hasattr(res, 'values'):
res = res.values

# may need to astype
try:
common_type = np.common_type(np.array(res), result)
if common_type != result.dtype:
result = result.astype(common_type)
except Exception as exc:
# date math can cause type of result to change
if i == 0 and (is_datetime64_dtype(result.dtype) or
is_timedelta64_dtype(result.dtype)):
try:
dtype = res.dtype
except Exception as exc:
dtype = type(res)
result = np.empty_like(result, dtype)

indexer = self._get_index(name)
result[indexer] = res
s = klass(res, indexer)
results.append(s)

result = _possibly_downcast_to_dtype(result, dtype)
return self._selected_obj.__class__(result,
index=self._selected_obj.index,
name=self._selected_obj.name)
from pandas.tools.concat import concat
result = concat(results).sort_index()

# we will only try to coerce the result type if
# we have a numeric dtype
dtype = self._selected_obj.dtype
if is_numeric_dtype(dtype):
result = _possibly_downcast_to_dtype(result, dtype)

result.name = self._selected_obj.name
result.index = self._selected_obj.index
return result

def _transform_fast(self, func):
"""
Expand Down
1 change: 1 addition & 0 deletions pandas/tests/groupby/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ def test_filter_against_workaround(self):
grouper = s.apply(lambda x: np.round(x, -1))
grouped = s.groupby(grouper)
f = lambda x: x.mean() > 10

old_way = s[grouped.transform(f).astype('bool')]
new_way = grouped.filter(f)
assert_series_equal(new_way.sort_values(), old_way.sort_values())
Expand Down

0 comments on commit 2f48549

Please sign in to comment.