Skip to content

Commit

Permalink
address the comments
Browse files Browse the repository at this point in the history
Signed-off-by: Xingyu Long <[email protected]>
  • Loading branch information
xingyu-long committed Nov 8, 2024
1 parent e3f14e7 commit 81097ac
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 9 additions & 7 deletions python/ray/data/_internal/logical/rules/inherit_batch_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule
from ray.data._internal.logical.operators.all_to_all_operator import Aggregate, Sort
from ray.data._internal.logical.operators.map_operator import MapBatches


class InheritBatchFormatRule(Rule):
Expand All @@ -25,16 +26,17 @@ def _apply(self, op: LogicalOperator):
current_op = nodes.pop()

if isinstance(current_op, (Sort, Aggregate)):
# traversal up the DAG until we find first operator with batch_format
# traverse up the DAG until we find a MapBatches op with batch_format set,
# or until we reach the source op, in which case we do nothing
upstream_op = current_op.input_dependencies[0]
while (
upstream_op.input_dependencies
and getattr(upstream_op, "_batch_format", None) is None
):
while upstream_op.input_dependencies:
if (
isinstance(upstream_op, MapBatches)
and upstream_op._batch_format
):
current_op._batch_format = upstream_op._batch_format
break
upstream_op = upstream_op.input_dependencies[0]
if getattr(upstream_op, "_batch_format", None):
current_op._batch_format = upstream_op._batch_format

# just return the default op
return op
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
self,
boundaries: List[T],
sort_key: SortKey,
batch_format: Optional[str],
batch_format: str,
):
super().__init__(
map_args=[boundaries, sort_key],
Expand Down

0 comments on commit 81097ac

Please sign in to comment.