[Data] support batch_format for Sort and Aggregate #48287

Merged
merged 8 commits on Nov 13, 2024
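For context, a hedged sketch of the user-facing behavior this change targets, based on the tests added to test_execution_optimizer.py below (the sample data is illustrative): once an upstream map_groups/map_batches stage runs with batch_format="pandas", the downstream Sort and Aggregate operators inherit that format instead of assuming the default block type.

```python
import ray

ds = ray.data.from_items(
    [
        {"col1": 1, "col2": 2},
        {"col1": 1, "col2": 4},
        {"col1": 5, "col2": 6},
        {"col1": 7, "col2": 8},
    ]
)

# map_groups produces pandas blocks; with this change, the downstream Sort
# inherits batch_format="pandas" rather than assuming the default block type.
df = (
    ds.groupby("col1")
    .map_groups(lambda g: g, batch_format="pandas")
    .sort("col2", descending=True)
    .to_pandas()
)
print(df)
```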
@@ -120,6 +120,7 @@ def __init__(
self,
input_op: LogicalOperator,
sort_key: SortKey,
batch_format: Optional[str] = "default",
):
super().__init__(
"Sort",
@@ -131,6 +132,7 @@ def __init__(
],
)
self._sort_key = sort_key
self._batch_format = batch_format

def aggregate_output_metadata(self) -> BlockMetadata:
assert len(self._input_dependencies) == 1, len(self._input_dependencies)
@@ -145,6 +147,7 @@ def __init__(
input_op: LogicalOperator,
key: Optional[str],
aggs: List[AggregateFn],
batch_format: Optional[str] = "default",
):
super().__init__(
"Aggregate",
@@ -157,3 +160,4 @@ def __init__(
)
self._key = key
self._aggs = aggs
self._batch_format = batch_format
2 changes: 2 additions & 0 deletions python/ray/data/_internal/logical/optimizers.py
@@ -6,6 +6,7 @@
PhysicalPlan,
Rule,
)
from ray.data._internal.logical.rules.inherit_batch_format import InheritBatchFormatRule
from ray.data._internal.logical.rules.inherit_target_max_block_size import (
InheritTargetMaxBlockSizeRule,
)
@@ -20,6 +21,7 @@

_LOGICAL_RULES = [
ReorderRandomizeBlocksRule,
InheritBatchFormatRule,
]

_PHYSICAL_RULES = [
42 changes: 42 additions & 0 deletions python/ray/data/_internal/logical/rules/inherit_batch_format.py
@@ -0,0 +1,42 @@
from collections import deque
from typing import Iterable

from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
from ray.data._internal.logical.operators.all_to_all_operator import AbstractAllToAll
from ray.data._internal.logical.operators.map_operator import MapBatches


class InheritBatchFormatRule(Rule):
"""For AbstractAllToAll based operator, apply this rule
to inherit batch_format from upstream operator by traversing
the entire DAG."""

def apply(self, plan: LogicalPlan) -> LogicalPlan:
optimized_dag: LogicalOperator = self._apply(plan.dag)
new_plan = LogicalPlan(dag=optimized_dag, context=plan.context)
return new_plan

def _apply(self, op: LogicalOperator):
# Post-order traversal.
nodes: Iterable[LogicalOperator] = deque()
for node in op.post_order_iter():
nodes.appendleft(node)

while len(nodes) > 0:
Comment on lines +22 to +25

Contributor: Let's combine these 2 loops.

Contributor Author: Looking through the codebase, I was trying to use the same coding style used elsewhere for this; it doesn't seem worth combining. Thanks!

Contributor: We can't just blindly follow the patterns, these need to make sense, right? What's the motivation for first collecting the nodes into a queue, instead of traversing the iterator directly?

current_op = nodes.pop()

if isinstance(current_op, AbstractAllToAll):
# Traverse up the DAG until we find a MapBatches with batch_format,
# or reach the source op and do nothing.
upstream_op = current_op.input_dependencies[0]
while upstream_op.input_dependencies:
if (
isinstance(upstream_op, MapBatches)
and upstream_op._batch_format
):
current_op._batch_format = upstream_op._batch_format
break
upstream_op = upstream_op.input_dependencies[0]

# Operators are mutated in place, so return the original DAG root.
return op
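As a reference for the review thread above, a minimal sketch of the combined-loop variant the reviewer suggests, iterating post_order_iter() directly instead of buffering nodes in a deque; the intent is equivalent behavior, and this is not the merged implementation.

```python
def _apply(self, op: LogicalOperator):
    # Post-order traversal: upstream operators are visited before downstream ones.
    for current_op in op.post_order_iter():
        if isinstance(current_op, AbstractAllToAll):
            # Traverse up the DAG until we find a MapBatches with batch_format,
            # or reach the source op and do nothing.
            upstream_op = current_op.input_dependencies[0]
            while upstream_op.input_dependencies:
                if isinstance(upstream_op, MapBatches) and upstream_op._batch_format:
                    current_op._batch_format = upstream_op._batch_format
                    break
                upstream_op = upstream_op.input_dependencies[0]
    # Operators are mutated in place, so return the original DAG root.
    return op
```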
2 changes: 2 additions & 0 deletions python/ray/data/_internal/planner/aggregate.py
@@ -24,6 +24,7 @@
def generate_aggregate_fn(
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
"""Generate function to aggregate blocks by the specified key column or key
@@ -67,6 +68,7 @@ def fn(
boundaries=boundaries,
key=key,
aggs=aggs,
batch_format=batch_format,
)
if DataContext.get_current().use_push_based_shuffle:
scheduler = PushBasedShuffleTaskScheduler(agg_spec)
@@ -29,10 +29,11 @@ def __init__(
boundaries: List[KeyType],
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
):
super().__init__(
map_args=[boundaries, key, aggs],
reduce_args=[key, aggs],
reduce_args=[key, aggs, batch_format],
)

@staticmethod
@@ -62,11 +63,15 @@ def map(
def reduce(
key: Optional[str],
aggs: List[AggregateFn],
batch_format: str,
*mapper_outputs: List[Block],
partial_reduce: bool = False,
) -> Tuple[Block, BlockMetadata]:
return BlockAccessor.for_block(mapper_outputs[0]).aggregate_combined_blocks(
list(mapper_outputs), key, aggs, finalize=not partial_reduce
normalized_blocks = TableBlockAccessor.normalize_block_types(
mapper_outputs, normalize_type=batch_format
)
return BlockAccessor.for_block(normalized_blocks[0]).aggregate_combined_blocks(
list(normalized_blocks), key, aggs, finalize=not partial_reduce
)

@staticmethod
12 changes: 9 additions & 3 deletions python/ray/data/_internal/planner/exchange/sort_task_spec.py
@@ -6,6 +6,7 @@
from ray.data._internal.planner.exchange.interfaces import ExchangeTaskSpec
from ray.data._internal.progress_bar import ProgressBar
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data._internal.table_block import TableBlockAccessor
from ray.data.block import Block, BlockAccessor, BlockExecStats, BlockMetadata
from ray.types import ObjectRef

@@ -116,10 +117,11 @@ def __init__(
self,
boundaries: List[T],
sort_key: SortKey,
batch_format: str,
):
super().__init__(
map_args=[boundaries, sort_key],
reduce_args=[sort_key],
reduce_args=[sort_key, batch_format],
)

@staticmethod
@@ -138,11 +140,15 @@ def map(
@staticmethod
def reduce(
sort_key: SortKey,
batch_format: str,
*mapper_outputs: List[Block],
partial_reduce: bool = False,
) -> Tuple[Block, BlockMetadata]:
return BlockAccessor.for_block(mapper_outputs[0]).merge_sorted_blocks(
mapper_outputs, sort_key
normalized_blocks = TableBlockAccessor.normalize_block_types(
mapper_outputs, normalize_type=batch_format
)
return BlockAccessor.for_block(normalized_blocks[0]).merge_sorted_blocks(
normalized_blocks, sort_key
)

@staticmethod
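Why both reduce paths normalize first: once an upstream stage emits pandas blocks, the mapper outputs reaching a reduce task can be a mix of Arrow and pandas blocks, while aggregate_combined_blocks and merge_sorted_blocks use a single BlockAccessor chosen from the first block. A hedged sketch of the normalization step, using the normalize_block_types call as it appears in this diff (the sample blocks are illustrative, not from the PR):

```python
import pandas as pd
import pyarrow as pa

from ray.data._internal.table_block import TableBlockAccessor

# Mixed block types, as might arrive at a reduce task (illustrative).
mapper_outputs = [
    pa.table({"col1": [1, 5], "col2": [2, 6]}),
    pd.DataFrame({"col1": [7], "col2": [8]}),
]

# Convert every block to the inherited batch_format so that the BlockAccessor
# picked from the first normalized block can merge all of them.
normalized = TableBlockAccessor.normalize_block_types(
    mapper_outputs, normalize_type="pandas"
)
assert all(isinstance(block, pd.DataFrame) for block in normalized)
```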
9 changes: 7 additions & 2 deletions python/ray/data/_internal/planner/plan_all_to_all_op.py
@@ -71,7 +71,9 @@ def plan_all_to_all_op(
"debug_limit_shuffle_execution_to_num_blocks", None
)
)
fn = generate_sort_fn(op._sort_key, debug_limit_shuffle_execution_to_num_blocks)
fn = generate_sort_fn(
op._sort_key, op._batch_format, debug_limit_shuffle_execution_to_num_blocks
)
target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
elif isinstance(op, Aggregate):
debug_limit_shuffle_execution_to_num_blocks = (
@@ -80,7 +82,10 @@
)
)
fn = generate_aggregate_fn(
op._key, op._aggs, debug_limit_shuffle_execution_to_num_blocks
op._key,
op._aggs,
op._batch_format,
debug_limit_shuffle_execution_to_num_blocks,
)
target_max_block_size = DataContext.get_current().target_shuffle_max_block_size
else:
5 changes: 4 additions & 1 deletion python/ray/data/_internal/planner/sort.py
@@ -20,6 +20,7 @@

def generate_sort_fn(
sort_key: SortKey,
batch_format: str,
_debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
"""Generate function to sort blocks by the specified key column or key function."""
@@ -56,7 +57,9 @@ def fn(
_, ascending = sort_key.to_pandas_sort_args()
if not ascending:
boundaries.reverse()
sort_spec = SortTaskSpec(boundaries=boundaries, sort_key=sort_key)
sort_spec = SortTaskSpec(
boundaries=boundaries, sort_key=sort_key, batch_format=batch_format
)

if DataContext.get_current().use_push_based_shuffle:
scheduler = PushBasedShuffleTaskScheduler(sort_spec)
69 changes: 69 additions & 0 deletions python/ray/data/tests/test_execution_optimizer.py
@@ -1172,6 +1172,75 @@ def test_sort_validate_keys(ray_start_regular_shared):
ds_named.sort(invalid_col_name).take_all()


def test_inherit_batch_format_rule():
from ray.data._internal.logical.rules.inherit_batch_format import (
InheritBatchFormatRule,
)

ctx = DataContext.get_current()

operator1 = get_parquet_read_logical_op()
operator2 = MapBatches(operator1, fn=lambda g: g, batch_format="pandas")
sort_key = SortKey("number", descending=True)
operator3 = Sort(operator2, sort_key)
original_plan = LogicalPlan(dag=operator3, context=ctx)

rule = InheritBatchFormatRule()
optimized_plan = rule.apply(original_plan)
assert optimized_plan.dag._batch_format == "pandas"


def test_batch_format_on_sort(ray_start_regular_shared):
"""Checks that the Sort op can inherit batch_format from upstream ops correctly."""
ds = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 1, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)
df_expected = pd.DataFrame(
{
"col1": [7, 5, 1, 1],
"col2": [8, 6, 4, 2],
}
)
df_actual = (
ds.groupby("col1")
.map_groups(lambda g: g, batch_format="pandas")
.sort("col2", descending=True)
.to_pandas()
)
pd.testing.assert_frame_equal(df_actual, df_expected)


def test_batch_format_on_aggregate(ray_start_regular_shared):
"""Checks that the Aggregate op can inherit batch_format
from upstream ops correctly."""
from ray.data.aggregate import AggregateFn

ds = ray.data.from_items(
[
{"col1": 1, "col2": 2},
{"col1": 1, "col2": 4},
{"col1": 5, "col2": 6},
{"col1": 7, "col2": 8},
]
)
aggregation = AggregateFn(
init=lambda column: 1,
accumulate_row=lambda a, row: a * row["col2"],
merge=lambda a1, a2: a1 * a2,
name="prod",
)
assert (
ds.groupby("col1")
.map_groups(lambda g: g, batch_format="pandas")
.aggregate(aggregation)
) == {"prod": 384}


def test_aggregate_operator(ray_start_regular_shared):
ctx = DataContext.get_current()
