From 4b9e44bdeb6ddabfadd522f9756afd51116be4a5 Mon Sep 17 00:00:00 2001
From: yuzelin
Date: Wed, 28 Aug 2024 14:57:10 +0800
Subject: [PATCH] fix reading

Get the Arrow schema from the Java side up front: BytesReader now
exposes serializeSchema(), so the Python TableRead reads the schema
from a schema-only Arrow IPC stream instead of inferring it from the
first record batch. This removes the first-batch peeking and the
_empty_batch_reader() special case for empty tables. Also flush the
ArrowFormatWriter before serializing each batch, and adjust the tests
accordingly.
---
 java_based_implementation/api_impl.py         | 27 ++++++-------------
 .../org/apache/paimon/python/BytesReader.java | 10 ++++++-
 .../tests/test_write_and_read.py              |  7 +++--
 3 files changed, 20 insertions(+), 24 deletions(-)

diff --git a/java_based_implementation/api_impl.py b/java_based_implementation/api_impl.py
index 6d53289..2605cc3 100644
--- a/java_based_implementation/api_impl.py
+++ b/java_based_implementation/api_impl.py
@@ -16,8 +16,6 @@
 # limitations under the License.
 ################################################################################
 
-import itertools
-
 from java_based_implementation.java_gateway import get_gateway
 from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
 from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
@@ -123,15 +121,15 @@ def __init__(self, j_table_read, j_row_type):
 
     def create_reader(self, split: Split):
         self._j_bytes_reader.setSplit(split.to_j_split())
-        batch_iterator = self._batch_generator()
-        # to init arrow schema
-        try:
-            first_batch = next(batch_iterator)
-        except StopIteration:
-            return self._empty_batch_reader()
+        # get schema
+        if self._arrow_schema is None:
+            schema_bytes = self._j_bytes_reader.serializeSchema()
+            schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
+            self._arrow_schema = schema_reader.schema
+            schema_reader.close()
 
-        batches = itertools.chain((b for b in [first_batch]), batch_iterator)
-        return RecordBatchReader.from_batches(self._arrow_schema, batches)
+        batch_iterator = self._batch_generator()
+        return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)
 
     def _batch_generator(self) -> Iterator[RecordBatch]:
         while True:
@@ -140,17 +138,8 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
                 break
             else:
                 stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
-                if self._arrow_schema is None:
-                    self._arrow_schema = stream_reader.schema
                 yield from stream_reader
 
-    def _empty_batch_reader(self):
-        import pyarrow as pa
-        schema = pa.schema([])
-        empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
-        empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
-        return empty_reader
-
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
diff --git a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
index 45be1d5..9272f98 100644
--- a/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
+++ b/java_based_implementation/paimon-python-java-bridge/src/main/java/org/apache/paimon/python/BytesReader.java
@@ -47,7 +47,7 @@ public class BytesReader {
 
     public BytesReader(TableRead tableRead, RowType rowType) {
         this.tableRead = tableRead;
-        this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
+        this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
     }
 
     public void setSplit(Split split) throws IOException {
@@ -56,6 +56,13 @@ public void setSplit(Split split) throws IOException {
         nextRow();
     }
 
+    public byte[] serializeSchema() {
+        VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
+        ByteArrayOutputStream out = new ByteArrayOutputStream();
+        ArrowUtils.serializeToIpc(vsr, out);
+        return out.toByteArray();
+    }
+
     @Nullable
     public byte[] next() throws Exception {
         if (nextRow == null) {
@@ -68,6 +75,7 @@ public byte[] next() throws Exception {
             rowCount++;
         }
 
+        arrowFormatWriter.flush();
         VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
         vsr.setRowCount(rowCount);
         ByteArrayOutputStream out = new ByteArrayOutputStream();
diff --git a/java_based_implementation/tests/test_write_and_read.py b/java_based_implementation/tests/test_write_and_read.py
index 907741d..02ead23 100644
--- a/java_based_implementation/tests/test_write_and_read.py
+++ b/java_based_implementation/tests/test_write_and_read.py
@@ -92,8 +92,7 @@ def testReadEmptyPkTable(self):
             for split in splits
             for batch in table_read.create_reader(split)
         ]
-        result = pd.concat(data_frames)
-        self.assertEqual(result.shape, (0, 0))
+        self.assertEqual(len(data_frames), 0)
 
     def testWriteReadAppendTable(self):
         create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
@@ -135,8 +134,8 @@ def testWriteReadAppendTable(self):
         ]
 
         result = pd.concat(data_frames)
-        # check data
-        pd.testing.assert_frame_equal(result, df)
+        # check data (ignore index)
+        pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))
 
     def testWriteWrongSchema(self):
         create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
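
Note (not part of the patch): below is a minimal pyarrow sketch of the
two behaviors this fix relies on; the schema fields f0/f1 and the
helper empty_batches() are made up for illustration.

    import pyarrow as pa

    # 1. An Arrow IPC stream containing zero record batches still
    #    carries the schema; this is effectively what the new
    #    BytesReader.serializeSchema() returns to the Python side.
    schema = pa.schema([('f0', pa.int32()), ('f1', pa.string())])
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, schema):
        pass  # write no batches
    schema_bytes = sink.getvalue().to_pybytes()

    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
    assert schema_reader.schema.equals(schema)
    schema_reader.close()

    # 2. RecordBatchReader.from_batches() takes the schema up front, so
    #    an empty generator is fine: an empty split now yields a reader
    #    with the correct schema and zero rows, which is why
    #    _empty_batch_reader() could be dropped.
    def empty_batches():
        return
        yield  # unreachable; makes this an (empty) generator

    reader = pa.RecordBatchReader.from_batches(schema, empty_batches())
    table = reader.read_all()
    assert table.schema.equals(schema) and table.num_rows == 0

Fetching the schema once and caching it in self._arrow_schema also
avoids re-parsing it on every split.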