Skip to content

Commit

Permalink
fix reading
Browse files Browse the repository at this point in the history
  • Loading branch information
yuzelin committed Sep 3, 2024
1 parent c505c28 commit 4b9e44b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 24 deletions.
27 changes: 8 additions & 19 deletions java_based_implementation/api_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# limitations under the License.
################################################################################

import itertools

from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.util.java_utils import to_j_catalog_context, check_batch_write
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
Expand Down Expand Up @@ -123,15 +121,15 @@ def __init__(self, j_table_read, j_row_type):

def create_reader(self, split: Split):
self._j_bytes_reader.setSplit(split.to_j_split())
batch_iterator = self._batch_generator()
# to init arrow schema
try:
first_batch = next(batch_iterator)
except StopIteration:
return self._empty_batch_reader()
# get schema
if self._arrow_schema is None:
schema_bytes = self._j_bytes_reader.serializeSchema()
schema_reader = RecordBatchStreamReader(BufferReader(schema_bytes))
self._arrow_schema = schema_reader.schema
schema_reader.close()

batches = itertools.chain((b for b in [first_batch]), batch_iterator)
return RecordBatchReader.from_batches(self._arrow_schema, batches)
batch_iterator = self._batch_generator()
return RecordBatchReader.from_batches(self._arrow_schema, batch_iterator)

def _batch_generator(self) -> Iterator[RecordBatch]:
while True:
Expand All @@ -140,17 +138,8 @@ def _batch_generator(self) -> Iterator[RecordBatch]:
break
else:
stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
if self._arrow_schema is None:
self._arrow_schema = stream_reader.schema
yield from stream_reader

def _empty_batch_reader(self):
    """Return a pyarrow RecordBatchReader over a single empty (zero-column) batch."""
    import pyarrow as pa
    empty_schema = pa.schema([])
    placeholder_batch = pa.RecordBatch.from_arrays([], schema=empty_schema)
    return pa.RecordBatchReader.from_batches(empty_schema, [placeholder_batch])


class BatchWriteBuilder(write_builder.BatchWriteBuilder):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public class BytesReader {

public BytesReader(TableRead tableRead, RowType rowType) {
this.tableRead = tableRead;
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE, true);
}

public void setSplit(Split split) throws IOException {
Expand All @@ -56,6 +56,13 @@ public void setSplit(Split split) throws IOException {
nextRow();
}

/**
 * Serializes the writer's current VectorSchemaRoot to Arrow IPC stream bytes
 * (presumably consumed by the Python side to recover the Arrow schema — confirm
 * against the caller in api_impl.py).
 *
 * @return the IPC-serialized bytes of the VectorSchemaRoot
 */
public byte[] serializeSchema() {
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    ArrowUtils.serializeToIpc(arrowFormatWriter.getVectorSchemaRoot(), buffer);
    return buffer.toByteArray();
}

@Nullable
public byte[] next() throws Exception {
if (nextRow == null) {
Expand All @@ -68,6 +75,7 @@ public byte[] next() throws Exception {
rowCount++;
}

arrowFormatWriter.flush();
VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
vsr.setRowCount(rowCount);
ByteArrayOutputStream out = new ByteArrayOutputStream();
Expand Down
7 changes: 3 additions & 4 deletions java_based_implementation/tests/test_write_and_read.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def testReadEmptyPkTable(self):
for split in splits
for batch in table_read.create_reader(split)
]
result = pd.concat(data_frames)
self.assertEqual(result.shape, (0, 0))
self.assertEqual(len(data_frames), 0)

def testWriteReadAppendTable(self):
create_simple_table(self.warehouse, 'default', 'simple_append_table', False)
Expand Down Expand Up @@ -135,8 +134,8 @@ def testWriteReadAppendTable(self):
]
result = pd.concat(data_frames)

# check data
pd.testing.assert_frame_equal(result, df)
# check data (ignore index)
pd.testing.assert_frame_equal(result.reset_index(drop=True), df.reset_index(drop=True))

def testWriteWrongSchema(self):
create_simple_table(self.warehouse, 'default', 'test_wrong_schema', False)
Expand Down

0 comments on commit 4b9e44b

Please sign in to comment.