[Datasets] [Out-of-Band Serialization: 1/3] Refactor LazyBlockList. #23821

Merged
27 changes: 16 additions & 11 deletions python/ray/data/dataset.py
@@ -57,6 +57,7 @@
ParquetDatasource,
BlockWritePathProvider,
DefaultBlockWritePathProvider,
ReadTask,
WriteResult,
)
from ray.data.datasource.file_based_datasource import (
@@ -988,26 +989,26 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":

start_time = time.perf_counter()
context = DatasetContext.get_current()
calls: List[Callable[[], ObjectRef[BlockPartition]]] = []
metadata: List[BlockPartitionMetadata] = []
tasks: List[ReadTask] = []
block_partitions: List[ObjectRef[BlockPartition]] = []
block_partitions_meta: List[ObjectRef[BlockPartitionMetadata]] = []

datasets = [self] + list(other)
for ds in datasets:
bl = ds._plan.execute()
if isinstance(bl, LazyBlockList):
calls.extend(bl._calls)
metadata.extend(bl._metadata)
tasks.extend(bl._tasks)
block_partitions.extend(bl._block_partitions)
block_partitions_meta.extend(bl._block_partitions_meta)
else:
calls.extend([None] * bl.initial_num_blocks())
metadata.extend(bl._metadata)
tasks.extend([ReadTask(lambda: None, meta) for meta in bl._metadata])
if context.block_splitting_enabled:
block_partitions.extend(
[ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()]
)
else:
block_partitions.extend(bl.get_blocks())
block_partitions_meta.extend([ray.put(meta) for meta in bl._metadata])

epochs = [ds._get_epoch() for ds in datasets]
max_epoch = max(*epochs)
@@ -1028,7 +1029,8 @@ def union(self, *other: List["Dataset[T]"]) -> "Dataset[T]":
dataset_stats.time_total_s = time.perf_counter() - start_time
return Dataset(
ExecutionPlan(
LazyBlockList(calls, metadata, block_partitions), dataset_stats
LazyBlockList(tasks, block_partitions, block_partitions_meta),
dataset_stats,
),
max_epoch,
self._lazy,
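
For orientation, here is a minimal, self-contained sketch of what the refactored union() branch above does for a non-lazy block list. The helper name _eager_to_lazy_parts is illustrative and not part of this PR, and the three-argument LazyBlockList(tasks, block_partitions, block_partitions_meta) constructor is inferred from the changed call site above.

import ray
from ray.data.datasource import ReadTask

def _eager_to_lazy_parts(bl, block_splitting_enabled):
    # Already-materialized blocks get no-op ReadTasks so they can sit in a
    # LazyBlockList next to genuinely lazy blocks.
    tasks = [ReadTask(lambda: None, meta) for meta in bl._metadata]
    if block_splitting_enabled:
        # With block splitting, each partition is a list of (block, metadata) pairs.
        block_partitions = [
            ray.put([(b, m)]) for b, m in bl.get_blocks_with_metadata()
        ]
    else:
        # Without splitting, the block refs themselves serve as the partitions.
        block_partitions = bl.get_blocks()
    # Metadata now travels out-of-band as object refs, one per partition.
    block_partitions_meta = [ray.put(meta) for meta in bl._metadata]
    return tasks, block_partitions, block_partitions_meta
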
@@ -2548,6 +2550,7 @@ def repeat(self, times: Optional[int] = None) -> "DatasetPipeline[T]":
# to enable fusion with downstream map stages.
ctx = DatasetContext.get_current()
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
self._plan._in_blocks.clear()
blocks, read_stage = self._plan._rewrite_read_stage()
outer_stats = DatasetStats(stages={}, parent=None)
else:
@@ -2666,6 +2669,7 @@ def window(
# to enable fusion with downstream map stages.
ctx = DatasetContext.get_current()
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
self._plan._in_blocks.clear()
blocks, read_stage = self._plan._rewrite_read_stage()
outer_stats = DatasetStats(stages={}, parent=None)
else:
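
The one-line addition in repeat() and window() is identical in both places. A commented copy follows; the comments describe the presumed intent and are editorial, not part of the PR.

ctx = DatasetContext.get_current()
if self._plan._is_read_stage() and ctx.optimize_fuse_read_stages:
    # The read stage will be rewritten and re-executed lazily so it can fuse
    # with downstream map stages; clear any input blocks the plan still holds
    # so they are not kept alive alongside the rewritten stage.
    self._plan._in_blocks.clear()
    blocks, read_stage = self._plan._rewrite_read_stage()
    outer_stats = DatasetStats(stages={}, parent=None)
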
@@ -2749,12 +2753,13 @@ def fully_executed(self) -> "Dataset[T]":
Returns:
A Dataset with all blocks fully materialized in memory.
"""
blocks = self.get_internal_block_refs()
bar = ProgressBar("Force reads", len(blocks))
bar.block_until_complete(blocks)
blocks, metadata = [], []
for b, m in self._plan.execute().get_blocks_with_metadata():
blocks.append(b)
metadata.append(m)
ds = Dataset(
ExecutionPlan(
BlockList(blocks, self._plan.execute().get_metadata()),
BlockList(blocks, metadata),
self._plan.stats(),
dataset_uuid=self._get_uuid(),
),
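A short usage sketch of fully_executed() after this change; ray.data.range is used only as a convenient example source.

import ray

ds = ray.data.range(1000)   # lazy: blocks may not be computed yet
ds = ds.fully_executed()    # forces every block to materialize
# The returned Dataset is backed by a plain BlockList whose metadata was
# collected in the same pass as the block refs, as in the hunk above.
block_refs = ds.get_internal_block_refs()
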
31 changes: 12 additions & 19 deletions python/ray/data/impl/block_list.py
@@ -1,8 +1,5 @@
import math
from typing import List, Iterator, Tuple, Any, Union, Optional, TYPE_CHECKING

if TYPE_CHECKING:
import pyarrow
from typing import List, Iterator, Tuple, Optional

import numpy as np

@@ -26,11 +23,7 @@ def __init__(self, blocks: List[ObjectRef[Block]], metadata: List[BlockMetadata]
self._num_blocks = len(self._blocks)
self._metadata: List[BlockMetadata] = metadata

def set_metadata(self, i: int, metadata: BlockMetadata) -> None:
"""Set the metadata for a given block."""
self._metadata[i] = metadata

def get_metadata(self) -> List[BlockMetadata]:
def get_metadata(self, fetch_if_missing: bool = False) -> List[BlockMetadata]:
"""Get the metadata for all blocks."""
return self._metadata.copy()

@@ -183,22 +176,22 @@ def executed_num_blocks(self) -> int:
"""
return len(self.get_blocks())

def ensure_schema_for_first_block(self) -> Optional[Union["pyarrow.Schema", type]]:
"""Ensure that the schema is set for the first block.
def ensure_metadata_for_first_block(self) -> BlockMetadata:
"""Ensure that the metadata is fetched and set for the first block.

Returns None if the block list is empty.
"""
get_schema = cached_remote_fn(_get_schema)
get_metadata = cached_remote_fn(_get_metadata)
try:
block = next(self.iter_blocks())
block, metadata = next(self.iter_blocks_with_metadata())
except (StopIteration, ValueError):
# Dataset is empty (no blocks) or was manually cleared.
return None
schema = ray.get(get_schema.remote(block))
# Set the schema.
self._metadata[0].schema = schema
return schema
input_files = metadata.input_files
metadata = ray.get(get_metadata.remote(block, input_files))
self._metadata[0] = metadata
return metadata


def _get_schema(block: Block) -> Any:
return BlockAccessor.for_block(block).schema()
def _get_metadata(block: Block, input_files: Optional[List[str]]) -> BlockMetadata:
return BlockAccessor.for_block(block).get_metadata(input_files=input_files)
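
A hedged sketch of how a caller that previously used ensure_schema_for_first_block() could now read the first block's schema; the wrapper function is illustrative and assumes a BlockList from ray.data.impl.block_list.

def _first_block_schema(block_list):
    # ensure_metadata_for_first_block() returns None when the block list is
    # empty or was manually cleared, so guard before reading the schema.
    metadata = block_list.ensure_metadata_for_first_block()
    return metadata.schema if metadata is not None else None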