Refactor fast groups (#310)
* refactor: clean up fast grouped architecture

* docs(DRAFT): start module level docs for fast grouped ops

* chore: clean up copying dispatchers for fast groups

* fix: use proper dispatch lookup in fast_filter
machow authored Feb 18, 2021
1 parent 030eead commit 6a050ac
Showing 7 changed files with 204 additions and 133 deletions.
Makefile: 2 changes (1 addition & 1 deletion)
@@ -22,7 +22,7 @@ examples/%.ipynb:
     jupytext --sync $@
 
 docs/api_extra/%.rst: siuba/dply/%.py $(AUTODOC_SCRIPT)
-    python3 docs/generate_autodoc.py . $< > $@
+    python3 $(AUTODOC_SCRIPT) . $< > $@
 
 docs-watch: $(AUTODOC_PAGES)
     cd docs && sphinx-autobuild . ./_build/html
docs/developer/index.rst: 1 change (1 addition & 0 deletions)
@@ -7,4 +7,5 @@ Developer docs
    call_trees.Rmd
    sql-translators.ipynb
    pandas-group-ops.Rmd
+   api_pd_groups.rst

siuba/dply/vector.py: 8 changes (4 additions & 4 deletions)
@@ -6,7 +6,7 @@
 from pandas.core.frame import NDFrame
 from pandas import Series
 
-from siuba.experimental.pd_groups.groupby import GroupByAgg, _regroup
+from siuba.experimental.pd_groups.groupby import GroupByAgg, regroup
 from siuba.experimental.pd_groups.translate import method_agg_op
 
 __ALL__ = [
@@ -224,7 +224,7 @@ def _row_number_grouped(g: GroupBy) -> GroupBy:
     for g_key, inds in indices.items():
         out[inds] = np.arange(1, len(inds) + 1, dtype = int)
 
-    return _regroup(out, g)
+    return regroup(g, out)
 
 
 # ntile -----------------------------------------------------------------------
@@ -333,7 +333,7 @@ def lead(x, n = 1, default = None):
 def _lead_grouped(x, n = 1, default = None):
     res = x.shift(-1*n, fill_value = default)
 
-    return _regroup(res, x)
+    return regroup(x, res)
 
 
 # lag -------------------------------------------------------------------------
@@ -374,7 +374,7 @@ def lag(x, n = 1, default = None):
 def _lag_grouped(x, n = 1, default = None):
     res = x.shift(n, fill_value = default)
 
-    return _regroup(res, x)
+    return regroup(x, res)
 
 # n ---------------------------------------------------------------------------
 
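The argument order flips above because the helper was renamed from _regroup(result, grouped) to regroup(grouped, result), taking the grouped object first. A rough sketch of what such a helper does (illustrative, not the library's exact implementation): re-attach the original grouper to freshly computed values, so results like a shifted Series stay usable as grouped data.

    import pandas as pd

    def regroup_sketch(g, values):
        # g: an existing SeriesGroupBy; values: data aligned to g.obj's index.
        # Wrapping values with g's grouper keeps the original group structure.
        return values.groupby(g.grouper)

    ser = pd.Series([1, 2, 3, 4])
    g = ser.groupby([0, 0, 1, 1])
    shifted = g.obj.shift(1, fill_value=0)    # plain Series; grouping is lost
    regrouped = regroup_sketch(g, shifted)    # a SeriesGroupBy again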
siuba/experimental/pd_groups/dialect.py: 78 changes (28 additions & 50 deletions)
@@ -2,6 +2,7 @@
 from siuba.siu import CallTreeLocal, FunctionLookupError
 
 from siuba.experimental.pd_groups.translate import SeriesGroupBy, GroupByAgg, GROUP_METHODS
+from siuba.experimental.pd_groups.groupby import broadcast_agg, is_compatible
 
 
 # TODO: make into CallTreeLocal factory function
@@ -44,7 +45,7 @@
 # Fast group by verbs =========================================================
 
 from siuba.siu import Call
-from siuba.dply.verbs import mutate, filter, summarize, singledispatch2, DataFrameGroupBy, _regroup
+from siuba.dply.verbs import mutate, filter, summarize, singledispatch2, DataFrameGroupBy
 from pandas.core.dtypes.inference import is_scalar
 import warnings
 
@@ -71,19 +72,18 @@ def grouped_eval(__data, expr, require_agg = False):
         #
         grouped_res = call(__data)
 
-        if isinstance(grouped_res, GroupByAgg):
-            # TODO: may want to validate its grouper
+        if isinstance(grouped_res, SeriesGroupBy):
+            if not is_compatible(grouped_res, __data):
+                raise ValueError("Incompatible groupers")
+
+            # TODO: may want to validate result is correct length / index?
+            #       e.g. a SeriesGroupBy could be compatible and not an agg
             if require_agg:
-                # need an agg, got an agg. we are done.
-                if not grouped_res._orig_grouper is __data.grouper:
-                    raise ValueError("Incompatible groupers")
-                return grouped_res
+                return grouped_res.obj
             else:
                 # broadcast from aggregate to original length (like transform)
-                return grouped_res._broadcast_agg_result()
-        elif isinstance(grouped_res, SeriesGroupBy) and not require_agg:
-            # TODO: may want to validate its grouper
-            return grouped_res.obj
+                return broadcast_agg(grouped_res)
+
         else:
             # can happen right now if user selects, e.g., a property of the
             # groupby object, like .dtype, which returns a single value
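The refactor swaps the private GroupByAgg plumbing (_orig_grouper, _broadcast_agg_result) for two helpers imported from pd_groups.groupby. Sketched below with illustrative signatures (the real functions live in siuba/experimental/pd_groups/groupby.py and may differ): is_compatible checks that a result still shares the caller's grouper, and broadcast_agg stretches one-value-per-group aggregates back to one value per original row, like DataFrameGroupBy.transform.

    import pandas as pd

    def is_compatible_sketch(grouped_res, original):
        # compatible when both grouped objects come from the same grouper
        return grouped_res.grouper is original.grouper

    def broadcast_agg_sketch(agg_res, original):
        # agg_res holds one value per group; repeat each value for every row
        # of its group (group_info[0] is pandas' per-row group code, an internal)
        codes = original.grouper.group_info[0]
        return pd.Series(agg_res.obj.values[codes], index=original.obj.index)

    ser = pd.Series([1, 2, 3, 4])
    g = ser.groupby([0, 0, 1, 1])
    agg = g.sum().groupby([0, 1])     # a grouped, one-row-per-group result
    broadcast_agg_sketch(agg, g)      # 3, 3, 7, 7 -- like g.transform("sum")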
@@ -113,7 +113,19 @@ def _transform_args(args):

     return out
 
-@singledispatch2(DataFrameGroupBy)
+def _copy_dispatch(dispatcher, cls, func = None):
+    if func is None:
+        return lambda f: _copy_dispatch(dispatcher, cls, f)
+
+    # Note stripping symbolics may occur twice. Once in the original, and once
+    # in this dispatcher.
+    new_dispatch = singledispatch2(cls, func)
+    new_dispatch.register(object, dispatcher)
+
+    return new_dispatch
+
+
+@_copy_dispatch(mutate, DataFrameGroupBy)
 def fast_mutate(__data, **kwargs):
     """Warning: this function is experimental"""
 
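The new _copy_dispatch helper builds each fast verb as a fresh dispatcher whose fallback is the original, slow verb, replacing the three hand-written *_default registrations deleted below. It registers the fallback on object after the fact because singledispatch2 would otherwise turn unknown types into pipes. The same fallback pattern in plain functools.singledispatch terms (names here are illustrative):

    from functools import singledispatch

    @singledispatch
    def describe(data):
        return "slow generic path"

    def copy_dispatch(dispatcher, cls, func=None):
        # decorator form: @copy_dispatch(describe, list)
        if func is None:
            return lambda f: copy_dispatch(dispatcher, cls, f)
        new = singledispatch(dispatcher)   # unknown types fall back to the original
        new.register(cls, func)            # cls takes the fast implementation
        return new

    @copy_dispatch(describe, list)
    def fast_describe(data):
        return "fast list path"

    fast_describe([1, 2])    # "fast list path"
    fast_describe("text")    # "slow generic path" (falls back to describe)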
@@ -136,18 +148,9 @@ def fast_mutate(__data, **kwargs):
     return out.groupby(groupings)
 
 
-@fast_mutate.register(object)
-def _fast_mutate_default(__data, **kwargs):
-    # TODO: had to register object second, since singledispatch2 sets object dispatch
-    # to be a pipe (e.g. unknown types become a pipe by default)
-    # by default dispatch to regular mutate
-    f = mutate.registry[type(__data)]
-    return f(__data, **kwargs)
-
-
 # Fast filter ----
 
-@singledispatch2(DataFrameGroupBy)
+@_copy_dispatch(filter, DataFrameGroupBy)
 def fast_filter(__data, *args):
     """Warning: this function is experimental"""
 
@@ -165,7 +168,7 @@ def fast_filter(__data, *args):
         res = grouped_eval(__data, expr)
         out.append(res)
 
-    filter_df = filter.registry[__data.obj.__class__]
+    filter_df = filter.dispatch(__data.obj.__class__)
 
     df_result = filter_df(__data.obj, *out)
 
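This is the dispatch-lookup fix named in the commit message: registry[...] only holds classes that were registered explicitly, so it raises a KeyError for subclasses, while dispatch(...) walks the MRO the way a real call would. The difference in plain functools.singledispatch terms:

    from functools import singledispatch

    @singledispatch
    def describe(x):
        return "generic"

    @describe.register(int)
    def _describe_int(x):
        return "int"

    class Flag(int):                   # stand-in for a DataFrame subclass
        pass

    describe.dispatch(Flag)(Flag(1))   # "int": dispatch() walks the MRO
    Flag in describe.registry          # False: registry has exact classes only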
@@ -174,18 +177,9 @@
     return df_result.groupby(group_names)
 
 
-@fast_filter.register(object)
-def _fast_filter_default(__data, *args, **kwargs):
-    # TODO: had to register object second, since singledispatch2 sets object dispatch
-    # to be a pipe (e.g. unknown types become a pipe by default)
-    # by default dispatch to regular mutate
-    f = filter.registry[type(__data)]
-    return f(__data, *args, **kwargs)
-
-
 # Fast summarize ----
 
-@singledispatch2(DataFrameGroupBy)
+@_copy_dispatch(summarize, DataFrameGroupBy)
 def fast_summarize(__data, **kwargs):
     """Warning: this function is experimental"""
 
@@ -205,24 +199,8 @@
         # special case: set scalars directly
         res = grouped_eval(__data, expr, require_agg = True)
 
-        if isinstance(res, GroupByAgg):
-            # TODO: would be faster to check that res has matching grouper, since
-            # here it goes through the work of matching up indexes (which if
-            # the groupers match are identical)
-            out[name] = res.obj
-
-        # otherwise, assign like a scalar
-        else:
-            out[name] = res
+        out[name] = res
 
     return out.reset_index(drop = True)
 
 
-@fast_summarize.register(object)
-def _fast_summarize_default(__data, **kwargs):
-    # TODO: had to register object second, since singledispatch2 sets object dispatch
-    # to be a pipe (e.g. unknown types become a pipe by default)
-    # by default dispatch to regular mutate
-    f = summarize.registry[type(__data)]
-    return f(__data, **kwargs)
-
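With the copied dispatchers in place, each fast verb transparently falls back to its slow counterpart for ungrouped data. A usage sketch (the import path and the fast-path coverage of these particular expressions are assumptions, not guarantees):

    import pandas as pd
    from siuba import _
    from siuba.experimental.pd_groups.dialect import (   # import path assumed
        fast_mutate, fast_filter, fast_summarize,
    )

    df = pd.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 30]})
    gdf = df.groupby("g")

    fast_mutate(gdf, demeaned = _.x - _.x.mean())   # grouped result, stays grouped
    fast_filter(gdf, _.x > _.x.mean())              # rows above their group mean
    fast_summarize(gdf, avg = _.x.mean())           # one row per group
    fast_mutate(df, doubled = _.x * 2)              # ungrouped: falls back to mutate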