Skip to content

Commit

Permalink
refactor sort
Browse files Browse the repository at this point in the history
  • Loading branch information
richardliaw committed Nov 12, 2024
1 parent 55cb9c2 commit 47c0589
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 122 deletions.
76 changes: 14 additions & 62 deletions python/ray/data/_internal/arrow_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,82 +615,34 @@ def aggregate_combined_blocks(
the ith given aggregation.
If key is None then the k column is omitted.
"""

stats = BlockExecStats.builder()
keys = sort_key.get_columns()

def key_fn(r):
if sort_key is not None:
return tuple(r[k] for k in keys if k in r)
else:
return (0,)

# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "arrow")
combined = pyarrow.concat_tables(blocks, promote_options="default")

sort_args = sort_key.to_arrow_sort_args()
concat_and_sort = get_concat_and_sort_transform(DataContext.get_current())
combined = concat_and_sort(blocks, sort_key)

# if any of the sort_args are not in the combined schema, remove
# them from the sort_args
sort_args = [arg for arg in sort_args if arg[0] in combined.schema.names]
builder = ArrowBlockBuilder()

if not sort_args:
if combined.num_rows == 0:
return combined, ArrowBlockAccessor(combined).get_metadata(exec_stats=stats.build())

combined = combined.sort_by(sort_args)
iter = ArrowBlockAccessor(combined).iter_rows(public_row_format=False)

next_row = None
builder = ArrowBlockBuilder()

resolved_agg_names = TableBlockAccessor._resolve_aggregate_name_conflicts(
[agg.name for agg in aggs])

while True:
try:
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = keys

def gen():
nonlocal iter
nonlocal next_row
while key_fn(next_row) == next_keys:
yield next_row
try:
next_row = next(iter)
except StopIteration:
next_row = None
break

# Merge.
accumulators = dict(
zip(resolved_agg_names, (agg.init(next_keys) for agg in aggs)))
keys = sort_key.get_columns()

for r in gen():
for name, accumulated in accumulators.items():
accumulators[name] = aggs[i].merge(accumulated, r[name])
def key_fn(r):
if keys:
return tuple(r[k] for k in keys if k in r)
else:
return (0,)

# Build the row.
row = {}
if keys:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key

for agg_name in resolved_agg_names:
if finalize:
row[agg_name] = agg.finalize(accumulators[agg_name])
else:
row[agg_name] = accumulators[agg_name]
ret = TableBlockAccessor._aggregate_combined_blocks(
iter=iter, builder=builder, keys=keys, key_fn=key_fn, aggs=aggs, finalize=finalize
)
return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())

builder.add(row)
except StopIteration:
break

ret = builder.build()
return ret, ArrowBlockAccessor(ret).get_metadata(exec_stats=stats.build())

def block_type(self) -> BlockType:
return BlockType.ARROW
Expand Down
7 changes: 6 additions & 1 deletion python/ray/data/_internal/arrow_ops/transform_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,12 @@ def concat_and_sort(
import pyarrow.compute as pac

ret = concat(blocks)
indices = pac.sort_indices(ret, sort_keys=sort_key.to_arrow_sort_args())
sort_args = sort_key.to_arrow_sort_args()
sort_args = [arg for arg in sort_args if arg[0] in ret.schema.names]
if not sort_args:
return pyarrow.Table.from_pydict({})

indices = pac.sort_indices(ret, sort_keys=sort_args)
return take_table(ret, indices)


Expand Down
54 changes: 5 additions & 49 deletions python/ray/data/_internal/pandas_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,64 +544,20 @@ def key_fn(r):
# Handle blocks of different types.
blocks = TableBlockAccessor.normalize_block_types(blocks, "pandas")

resolved_agg_names = TableBlockAccessor._resolve_aggregate_name_conflicts(
[agg.name for agg in aggs])

iter = heapq.merge(
*[
PandasBlockAccessor(block).iter_rows(public_row_format=False)
for block in blocks
],
key=key_fn,
)
next_row = None

builder = PandasBlockBuilder()
while True:
try:
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = keys

def gen():
nonlocal iter
nonlocal next_row
while key_fn(next_row) == next_keys:
yield next_row
try:
next_row = next(iter)
except StopIteration:
next_row = None
break

# Merge.
accumulators = dict(
zip(resolved_agg_names, (agg.init(next_keys) for agg in aggs)))

for r in gen():
for name, accumulated in accumulators.items():
accumulators[name] = aggs[i].merge(accumulated, r[name])

# Build the row.
row = {}
if keys:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key

for agg, agg_name, accumulator in zip(
aggs, resolved_agg_names, accumulators
):
if finalize:
row[agg_name] = agg.finalize(accumulator)
else:
row[agg_name] = accumulator

builder.add(row)
except StopIteration:
break

ret = builder.build()
return ret, PandasBlockAccessor(ret).get_metadata(exec_stats=stats.build())
ret = TableBlockAccessor._aggregate_combined_blocks(
iter=iter, builder=builder, keys=keys, key_fn=key_fn, aggs=aggs, finalize=finalize
)
return ret, PandasBlockBuilder(ret).get_metadata(exec_stats=stats.build())

def block_type(self) -> BlockType:
return BlockType.PANDAS
82 changes: 72 additions & 10 deletions python/ray/data/_internal/table_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
)
Expand All @@ -22,6 +24,7 @@

if TYPE_CHECKING:
from ray.data._internal.planner.exchange.sort_task_spec import SortKey
from ray.data.aggregate import AggregateFn


T = TypeVar("T")
Expand Down Expand Up @@ -243,16 +246,75 @@ def zip(self, other: "Block") -> "Block":
return self._zip(acc)

@staticmethod
def _resolve_aggregate_name_conflicts(aggregate_names: List[str]) -> List[str]:
count = collections.defaultdict(int)
resolved_agg_names = []
for name in aggregate_names:
# Check for conflicts with existing aggregation name.
if name in count:
name = f"{name}_{count+1}"
count[name] += 1
resolved_agg_names.append(name)
return resolved_agg_names
def _aggregate_combined_blocks(
iter: Iterator[T],
builder: TableBlockBuilder,
keys: List[str],
key_fn: Callable[[T], Tuple],
aggs: Tuple["AggregateFn"],
finalize: bool,
) -> Block:

def _resolve_aggregate_name_conflicts(aggregate_names: List[str]) -> List[str]:
count = collections.defaultdict(int)
resolved_agg_names = []
for name in aggregate_names:
# Check for conflicts with existing aggregation name.
if name in count:
name = f"{name}_{count+1}"
count[name] += 1
resolved_agg_names.append(name)
return resolved_agg_names

resolved_agg_names = _resolve_aggregate_name_conflicts(
[agg.name for agg in aggs])

next_row = None
named_aggs = dict(zip(resolved_agg_names, aggs))
while True:
try:
if next_row is None:
next_row = next(iter)
next_keys = key_fn(next_row)
next_key_names = keys

def gen():
nonlocal iter
nonlocal next_row
while key_fn(next_row) == next_keys:
yield next_row
try:
next_row = next(iter)
except StopIteration:
next_row = None
break

# Merge.
accumulators = dict(
zip(resolved_agg_names, (agg.init(next_keys) for agg in aggs)))

for r in gen():
for name, accumulated in accumulators.items():
accumulators[name] = named_aggs[name].merge(accumulated, r[name])

# Build the row.
row = {}
if keys:
for next_key, next_key_name in zip(next_keys, next_key_names):
row[next_key_name] = next_key

for agg_name in resolved_agg_names:
if finalize:
row[agg_name] = named_aggs[agg_name].finalize(accumulators[agg_name])
else:
row[agg_name] = accumulators[agg_name]

builder.add(row)
except StopIteration:
break

return builder.build()


@staticmethod
def _empty_table() -> Any:
Expand Down

0 comments on commit 47c0589

Please sign in to comment.