Complete table read and write (#6)
yuzelin authored Aug 20, 2024
1 parent 974f8aa commit 73366fb
Showing 12 changed files with 499 additions and 44 deletions.
4 changes: 4 additions & 0 deletions dev/dev-requirements.txt
@@ -21,4 +21,8 @@ setuptools>=18.0
wheel
py4j==0.10.9.7
pyarrow>=5.0.0
pandas>=1.3.0
numpy>=1.22.4
python-dateutil>=2.8.0,<3
pytz>=2018.3
pytest~=7.0
84 changes: 67 additions & 17 deletions java_based_implementation/api_impl.py
@@ -16,12 +16,15 @@
# limitations under the License.
################################################################################

import itertools

from java_based_implementation.java_gateway import get_gateway
from java_based_implementation.util.java_utils import to_j_catalog_context
from paimon_python_api import (catalog, table, read_builder, table_scan, split, table_read,
write_builder, table_write, commit_message, table_commit)
from pyarrow import RecordBatchReader, RecordBatch
from typing import List
from pyarrow import (RecordBatch, BufferOutputStream, RecordBatchStreamWriter,
RecordBatchStreamReader, BufferReader, RecordBatchReader)
from typing import List, Iterator


class Catalog(catalog.Catalog):
@@ -49,18 +52,19 @@ def __init__(self, j_table):
self._j_table = j_table

def new_read_builder(self) -> 'ReadBuilder':
j_read_builder = self._j_table.newReadBuilder()
return ReadBuilder(j_read_builder)
j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
return ReadBuilder(j_read_builder, self._j_table.rowType())

def new_batch_write_builder(self) -> 'BatchWriteBuilder':
j_batch_write_builder = self._j_table.newBatchWriteBuilder()
return BatchWriteBuilder(j_batch_write_builder)
j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType())


class ReadBuilder(read_builder.ReadBuilder):

def __init__(self, j_read_builder):
def __init__(self, j_read_builder, j_row_type):
self._j_read_builder = j_read_builder
self._j_row_type = j_row_type

def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
self._j_read_builder.withProjection(projection)
@@ -75,8 +79,8 @@ def new_scan(self) -> 'TableScan':
return TableScan(j_table_scan)

def new_read(self) -> 'TableRead':
# TODO
pass
j_table_read = self._j_read_builder.newRead()
return TableRead(j_table_read, self._j_row_type)


class TableScan(table_scan.TableScan):
@@ -110,23 +114,56 @@ def to_j_split(self):

class TableRead(table_read.TableRead):

def create_reader(self, split: Split) -> RecordBatchReader:
# TODO
pass
def __init__(self, j_table_read, j_row_type):
self._j_table_read = j_table_read
self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createBytesReader(
j_table_read, j_row_type)
self._arrow_schema = None

def create_reader(self, split: Split):
self._j_bytes_reader.setSplit(split.to_j_split())
batch_iterator = self._batch_generator()
# to init arrow schema
try:
first_batch = next(batch_iterator)
except StopIteration:
return self._empty_batch_reader()

batches = itertools.chain((b for b in [first_batch]), batch_iterator)
return RecordBatchReader.from_batches(self._arrow_schema, batches)

def _batch_generator(self) -> Iterator[RecordBatch]:
while True:
next_bytes = self._j_bytes_reader.next()
if next_bytes is None:
break
else:
stream_reader = RecordBatchStreamReader(BufferReader(next_bytes))
if self._arrow_schema is None:
self._arrow_schema = stream_reader.schema
yield from stream_reader

def _empty_batch_reader(self):
import pyarrow as pa
schema = pa.schema([])
empty_batch = pa.RecordBatch.from_arrays([], schema=schema)
empty_reader = pa.RecordBatchReader.from_batches(schema, [empty_batch])
return empty_reader


class BatchWriteBuilder(write_builder.BatchWriteBuilder):

def __init__(self, j_batch_write_builder):
def __init__(self, j_batch_write_builder, j_row_type):
self._j_batch_write_builder = j_batch_write_builder
self._j_row_type = j_row_type

def with_overwrite(self, static_partition: dict) -> 'BatchWriteBuilder':
self._j_batch_write_builder.withOverwrite(static_partition)
return self

def new_write(self) -> 'BatchTableWrite':
j_batch_table_write = self._j_batch_write_builder.newWrite()
return BatchTableWrite(j_batch_table_write)
return BatchTableWrite(j_batch_table_write, self._j_row_type)

def new_commit(self) -> 'BatchTableCommit':
j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -135,17 +172,27 @@ def new_commit(self) -> 'BatchTableCommit':

class BatchTableWrite(table_write.BatchTableWrite):

def __init__(self, j_batch_table_write):
def __init__(self, j_batch_table_write, j_row_type):
self._j_batch_table_write = j_batch_table_write
self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
j_batch_table_write, j_row_type)

def write(self, record_batch: RecordBatch):
# TODO
pass
stream = BufferOutputStream()
with RecordBatchStreamWriter(stream, record_batch.schema) as writer:
writer.write(record_batch)
writer.close()
arrow_bytes = stream.getvalue().to_pybytes()
self._j_bytes_writer.write(arrow_bytes)

def prepare_commit(self) -> List['CommitMessage']:
j_commit_messages = self._j_batch_table_write.prepareCommit()
return list(map(lambda cm: CommitMessage(cm), j_commit_messages))

def close(self):
self._j_batch_table_write.close()
self._j_bytes_writer.close()


class CommitMessage(commit_message.CommitMessage):

@@ -164,3 +211,6 @@ def __init__(self, j_batch_table_commit):
def commit(self, commit_messages: List[CommitMessage]):
j_commit_messages = list(map(lambda cm: cm.to_j_commit_message(), commit_messages))
self._j_batch_table_commit.commit(j_commit_messages)

def close(self):
self._j_batch_table_commit.close()
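
For context, a rough end-to-end usage sketch of the Python classes added above (not part of the commit): it assumes a Table instance has already been obtained from a Catalog, and that TableScan exposes a plan() with splits(), which is collapsed in this diff; the helper name copy_table is illustrative only.

def copy_table(source_table, target_table):
    # Read side: build splits, then stream Arrow RecordBatches per split.
    read_builder = source_table.new_read_builder()
    scan = read_builder.new_scan()
    table_read = read_builder.new_read()
    batches = []
    for split in scan.plan().splits():      # assumed TableScan API (collapsed above)
        reader = table_read.create_reader(split)
        batches.extend(reader)              # RecordBatchReader is iterable

    # Write side: each batch is serialized to Arrow IPC bytes and handed to Java.
    write_builder = target_table.new_batch_write_builder()
    table_write = write_builder.new_write()
    table_commit = write_builder.new_commit()
    try:
        for batch in batches:
            table_write.write(batch)
        table_commit.commit(table_write.prepare_commit())
    finally:
        table_write.close()
        table_commit.close()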
7 changes: 6 additions & 1 deletion java_based_implementation/java_gateway.py
@@ -101,9 +101,14 @@ def launch_gateway():
return gateway


# TODO: import more
def import_paimon_view(gateway):
java_import(gateway.jvm, "org.apache.paimon.table.*")
java_import(gateway.jvm, "org.apache.paimon.options.Options")
java_import(gateway.jvm, "org.apache.paimon.catalog.*")
java_import(gateway.jvm, "org.apache.paimon.schema.Schema*")
java_import(gateway.jvm, 'org.apache.paimon.types.*')
java_import(gateway.jvm, 'org.apache.paimon.python.InvocationUtil')
java_import(gateway.jvm, "org.apache.paimon.data.*")


class Watchdog(object):
46 changes: 46 additions & 0 deletions java_based_implementation/paimon-python-java-bridge/pom.xml
@@ -33,8 +33,10 @@
<flink.shaded.hadoop.version>2.8.3-10.0</flink.shaded.hadoop.version>
<py4j.version>0.10.9.7</py4j.version>
<slf4j.version>1.7.32</slf4j.version>
<log4j.version>2.17.1</log4j.version>
<spotless.version>2.13.0</spotless.version>
<spotless.delimiter>package</spotless.delimiter>
<arrow.version>14.0.0</arrow.version>
</properties>

<dependencies>
@@ -47,18 +49,48 @@
<version>${paimon.version}</version>
</dependency>

<dependency>
<groupId>org.apache.paimon</groupId>
<artifactId>paimon-arrow</artifactId>
<version>${paimon.version}</version>
</dependency>

<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>

<dependency>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-1.2-api</artifactId>
<version>${log4j.version}</version>
</dependency>

<dependency>
<groupId>org.apache.flink</groupId>
<artifactId>flink-shaded-hadoop-2-uber</artifactId>
<version>${flink.shaded.hadoop.version}</version>
</dependency>

<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
</dependency>

<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-unsafe</artifactId>
<version>${arrow.version}</version>
</dependency>

<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-c-data</artifactId>
<version>${arrow.version}</version>
</dependency>

<!-- Python API dependencies -->

<dependency>
@@ -119,11 +151,25 @@
<artifactSet>
<includes combine.children="append">
<include>org.apache.paimon:paimon-bundle</include>
<include>org.apache.paimon:paimon-arrow</include>
<include>org.apache.arrow:arrow-vector</include>
<include>org.apache.arrow:arrow-memory-core</include>
<include>org.apache.arrow:arrow-memory-unsafe</include>
<include>org.apache.arrow:arrow-c-data</include>
<include>org.apache.arrow:arrow-format</include>
<include>com.google.flatbuffers:flatbuffers-java</include>
<include>org.slf4j:slf4j-api</include>
<include>org.apache.logging.log4j:log4j-1.2-api</include>
<include>org.apache.flink:flink-shaded-hadoop-2-uber</include>
<include>net.sf.py4j:py4j</include>
</includes>
</artifactSet>
<relocations>
<relocation>
<pattern>com.fasterxml.jackson</pattern>
<shadedPattern>org.apache.paimon.shade.jackson2.com.fasterxml.jackson</shadedPattern>
</relocation>
</relocations>
</configuration>
</execution>
</executions>
90 changes: 90 additions & 0 deletions …/org/apache/paimon/python/BytesReader.java (new file)
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.python;

import org.apache.paimon.arrow.ArrowUtils;
import org.apache.paimon.arrow.vector.ArrowFormatWriter;
import org.apache.paimon.data.InternalRow;
import org.apache.paimon.reader.RecordReader;
import org.apache.paimon.reader.RecordReaderIterator;
import org.apache.paimon.table.source.Split;
import org.apache.paimon.table.source.TableRead;
import org.apache.paimon.types.RowType;

import org.apache.arrow.vector.VectorSchemaRoot;

import javax.annotation.Nullable;

import java.io.ByteArrayOutputStream;
import java.io.IOException;

/** Read Arrow bytes from split. */
public class BytesReader {

private static final int DEFAULT_WRITE_BATCH_SIZE = 2048;

private final TableRead tableRead;
private final ArrowFormatWriter arrowFormatWriter;

private RecordReaderIterator<InternalRow> iterator;
private InternalRow nextRow;

public BytesReader(TableRead tableRead, RowType rowType) {
this.tableRead = tableRead;
this.arrowFormatWriter = new ArrowFormatWriter(rowType, DEFAULT_WRITE_BATCH_SIZE);
}

public void setSplit(Split split) throws IOException {
RecordReader<InternalRow> recordReader = tableRead.createReader(split);
iterator = new RecordReaderIterator<InternalRow>(recordReader);
nextRow();
}

@Nullable
public byte[] next() throws Exception {
if (nextRow == null) {
return null;
}

int rowCount = 0;
while (nextRow != null && arrowFormatWriter.write(nextRow)) {
nextRow();
rowCount++;
}

VectorSchemaRoot vsr = arrowFormatWriter.getVectorSchemaRoot();
vsr.setRowCount(rowCount);
ByteArrayOutputStream out = new ByteArrayOutputStream();
ArrowUtils.serializeToIpc(vsr, out);
if (nextRow == null) {
// close resource
arrowFormatWriter.close();
iterator.close();
}
return out.toByteArray();
}

private void nextRow() {
if (iterator.hasNext()) {
nextRow = iterator.next();
} else {
nextRow = null;
}
}
}
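
A minimal pyarrow-only sketch of the IPC handoff BytesReader relies on (not part of the commit): the Java side serializes a VectorSchemaRoot to the Arrow IPC stream format via ArrowUtils.serializeToIpc, and the Python side re-reads those bytes as RecordBatches, mirroring TableRead._batch_generator. The sample data is made up for illustration.

import pyarrow as pa

batch = pa.RecordBatch.from_pydict({"id": [1, 2, 3], "name": ["a", "b", "c"]})

# Stand-in for the bytes produced by ArrowUtils.serializeToIpc(vsr, out) on the Java side.
sink = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(sink, batch.schema) as writer:
    writer.write(batch)
ipc_bytes = sink.getvalue().to_pybytes()

# Mirrors TableRead._batch_generator: one IPC stream is decoded per chunk of rows.
reader = pa.RecordBatchStreamReader(pa.BufferReader(ipc_bytes))
for got in reader:
    assert got.equals(batch)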