diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 9a06156..eeccca3 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -21,8 +21,8 @@ from paimon_python_api import (catalog, table, read_builder, table_scan, split,
                                 table_read, write_builder, table_write, commit_message,
                                 table_commit)
 from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,
-                     RecordBatchStreamReader, BufferReader)
-from typing import List
+                     RecordBatchStreamReader, BufferReader, RecordBatchReader)
+from typing import List, Iterator
 
 
 class Catalog(catalog.Catalog):
@@ -119,42 +119,27 @@ def __init__(self, j_table_read, j_row_type):
 
     def create_reader(self, split: Split):
         self._j_bytes_reader.setSplit(split.to_j_split())
-        return BatchReader(self._j_bytes_reader)
+        batches = []
+        schema = None
+        for arrow_bytes in self._bytes_generator():
+            stream_reader = RecordBatchStreamReader(BufferReader(arrow_bytes))
+            if schema is None:
+                schema = stream_reader.schema
+            batches.extend(batch for batch in stream_reader)
+        return RecordBatchReader.from_batches(schema, batches)
+
+    def _bytes_generator(self) -> Iterator[bytes]:
+        while True:
+            next_bytes = self._j_bytes_reader.next()
+            if next_bytes is None:
+                break
+            else:
+                yield next_bytes
 
     def close(self):
         self._j_bytes_reader.close()
 
 
-class BatchReader(table_read.BatchReader):
-
-    def __init__(self, j_bytes_reader):
-        self._j_bytes_reader = j_bytes_reader
-        self._inited = False
-        self._has_next = True
-        self._next_arrow_reader()
-
-    def next_batch(self):
-        if not self._has_next:
-            return None
-
-        try:
-            return self._current_arrow_reader.read_next_batch()
-        except StopIteration:
-            self._current_arrow_reader.close()
-            self._next_arrow_reader()
-            if not self._has_next:
-                return None
-            else:
-                return self._current_arrow_reader.read_next_batch()
-
-    def _next_arrow_reader(self):
-        byte_array = self._j_bytes_reader.next()
-        if byte_array is None:
-            self._has_next = False
-        else:
-            self._current_arrow_reader = RecordBatchStreamReader(BufferReader(byte_array))
-
-
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
     def __init__(self, j_batch_write_builder, j_row_type):
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index 9163482..acfdba0 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -66,19 +66,14 @@ def testWriteReadAppendTable(self):
         read_builder = table.new_read_builder()
         table_scan = read_builder.new_scan()
         table_read = read_builder.new_read()
-
         splits = table_scan.plan().splits()
-        batches = []
-        for split in splits:
-            batch_reader = table_read.create_reader(split)
-            while True:
-                batch = batch_reader.next_batch()
-                if batch is None:
-                    break
-                else:
-                    batches.append(batch.to_pandas())
-        result = pd.concat(batches)
+        data_frames = [
+            batch.to_pandas()
+            for split in splits
+            for batch in table_read.create_reader(split)
+        ]
+        result = pd.concat(data_frames)
 
         # check data
         pd.testing.assert_frame_equal(result, df)
 
diff --git a/paimon_python_api/table_read.py b/paimon_python_api/table_read.py
index 4555dc0..bb6000d 100644
--- a/paimon_python_api/table_read.py
+++ b/paimon_python_api/table_read.py
@@ -17,26 +17,17 @@
 #################################################################################
 
 from abc import ABC, abstractmethod
-from pyarrow import RecordBatch
+from pyarrow import RecordBatchReader
 from paimon_python_api.split import Split
-from typing import Union
 
 
 class TableRead(ABC):
     """To read data from data splits."""
 
     @abstractmethod
-    def create_reader(self, split: Split) -> 'BatchReader':
+    def create_reader(self, split: Split) -> RecordBatchReader:
         """Return a reader containing batches of pyarrow format."""
 
     @abstractmethod
     def close(self):
         """Close this resource."""
-
-
-class BatchReader(ABC):
-    """Reader to get RecordBatch."""
-
-    @abstractmethod
-    def next_batch(self) -> Union[RecordBatch, None]:
-        """Get next RecordBatch. Return NONE if there is no more data."""
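
Usage note (not part of the patch): after this change, TableRead.create_reader returns a standard pyarrow.RecordBatchReader instead of the custom BatchReader, so callers iterate record batches directly rather than polling next_batch() for None. A minimal sketch of the calling pattern, mirroring the updated test; the helper name read_to_pandas is hypothetical, and the table handle is assumed to come from a Catalog as elsewhere in this repo:

import pandas as pd

def read_to_pandas(table) -> pd.DataFrame:
    # Plan the splits, then read each split through the RecordBatchReader
    # returned by create_reader(); a RecordBatchReader is directly iterable
    # and yields pyarrow.RecordBatch objects.
    read_builder = table.new_read_builder()
    table_read = read_builder.new_read()
    splits = read_builder.new_scan().plan().splits()
    data_frames = [
        batch.to_pandas()
        for split in splits
        for batch in table_read.create_reader(split)
    ]
    table_read.close()
    return pd.concat(data_frames)

Returning the standard pyarrow reader also gives callers its built-in conversions for free, e.g. create_reader(split).read_all() for a pyarrow.Table, which is why the bespoke BatchReader abstraction can be deleted from paimon_python_api.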