diff --git a/python/ray/data/_internal/logical/rules/inherit_batch_format.py b/python/ray/data/_internal/logical/rules/inherit_batch_format.py index 264da93b6c41..38009a1225c9 100644 --- a/python/ray/data/_internal/logical/rules/inherit_batch_format.py +++ b/python/ray/data/_internal/logical/rules/inherit_batch_format.py @@ -3,6 +3,7 @@ from ray.data._internal.logical.interfaces import LogicalOperator, LogicalPlan, Rule from ray.data._internal.logical.operators.all_to_all_operator import Aggregate, Sort +from ray.data._internal.logical.operators.map_operator import MapBatches class InheritBatchFormatRule(Rule): @@ -25,16 +26,17 @@ def _apply(self, op: LogicalOperator): current_op = nodes.pop() if isinstance(current_op, (Sort, Aggregate)): - # traversal up the DAG until we find first operator with batch_format + # traversal up the DAG until we find MapBatches with batch_format # or we reach to source op and do nothing upstream_op = current_op.input_dependencies[0] - while ( - upstream_op.input_dependencies - and getattr(upstream_op, "_batch_format", None) is None - ): + while upstream_op.input_dependencies: + if ( + isinstance(upstream_op, MapBatches) + and upstream_op._batch_format + ): + current_op._batch_format = upstream_op._batch_format + break upstream_op = upstream_op.input_dependencies[0] - if getattr(upstream_op, "_batch_format", None): - current_op._batch_format = upstream_op._batch_format # just return the default op return op diff --git a/python/ray/data/_internal/planner/exchange/sort_task_spec.py b/python/ray/data/_internal/planner/exchange/sort_task_spec.py index 5bef0a9db3f7..299e8793774f 100644 --- a/python/ray/data/_internal/planner/exchange/sort_task_spec.py +++ b/python/ray/data/_internal/planner/exchange/sort_task_spec.py @@ -117,7 +117,7 @@ def __init__( self, boundaries: List[T], sort_key: SortKey, - batch_format: Optional[str], + batch_format: str, ): super().__init__( map_args=[boundaries, sort_key],