From f09dc5887fdcab90b7bcac361e54414fab089fc9 Mon Sep 17 00:00:00 2001
From: yuzelin <33053040+yuzelin@users.noreply.github.com>
Date: Mon, 25 Nov 2024 20:40:20 +0800
Subject: [PATCH] Refactor ReadBuilder#with_projection to accept field names
 for better using (#27)

---
 paimon_python_api/read_builder.py     |  2 +-
 paimon_python_java/pypaimon.py        | 43 ++++++-------
 .../tests/test_write_and_read.py      | 62 +++++++++++++++++++
 paimon_python_java/util/java_utils.py |  9 +++
 4 files changed, 92 insertions(+), 24 deletions(-)

diff --git a/paimon_python_api/read_builder.py b/paimon_python_api/read_builder.py
index ad5e6d6..a031a05 100644
--- a/paimon_python_api/read_builder.py
+++ b/paimon_python_api/read_builder.py
@@ -32,7 +32,7 @@ def with_filter(self, predicate: Predicate):
         """
 
     @abstractmethod
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
        """Push nested projection."""
 
     @abstractmethod
diff --git a/paimon_python_java/pypaimon.py b/paimon_python_java/pypaimon.py
index 16c7a69..b884fa4 100644
--- a/paimon_python_java/pypaimon.py
+++ b/paimon_python_java/pypaimon.py
@@ -61,37 +61,36 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()
 
     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)
 
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
-        self._arrow_schema = arrow_schema
 
     def with_filter(self, predicate: 'Predicate'):
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self
 
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self
 
     def with_limit(self, limit: int) -> 'ReadBuilder':
@@ -104,7 +103,7 @@ def new_scan(self) -> 'TableScan':
 
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ def to_j_split(self):
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
         self._j_table_read = j_table_read
-        self._j_row_type = j_row_type
+        self._j_read_type = j_read_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ def _init(self):
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-            self._j_table_read, self._j_row_type, max_workers)
+            self._j_table_read, self._j_read_type, max_workers)
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
@@ -188,10 +187,8 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_write_builder):
         self._j_batch_write_builder = j_batch_write_builder
-        self._j_row_type = j_row_type
-        self._arrow_schema = arrow_schema
 
     def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
         if static_partition is None:
@@ -201,7 +198,7 @@ def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuild
 
     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, self._j_batch_write_builder.rowType())
 
     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ def new_commit(self) -> 'BatchTableCommit':
 
 class BatchTableWrite(table_write.BatchTableWrite):
 
-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_row_type)
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
diff --git a/paimon_python_java/tests/test_write_and_read.py b/paimon_python_java/tests/test_write_and_read.py
index b468e9f..337b9f5 100644
--- a/paimon_python_java/tests/test_write_and_read.py
+++ b/paimon_python_java/tests/test_write_and_read.py
@@ -445,3 +445,65 @@ def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
         df['f0'] = df['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df.reset_index(drop=True), df.reset_index(drop=True))
+
+    def testProjection(self):
+        pa_schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string()),
+            ('f2', pa.bool_()),
+            ('f3', pa.string())
+        ])
+        schema = Schema(pa_schema)
+        self.catalog.create_table('default.test_projection', schema, False)
+        table = self.catalog.get_table('default.test_projection')
+
+        # prepare data
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': [True, True, False],
+            'f3': ['A', 'B', 'C']
+        }
+        df = pd.DataFrame(data)
+
+        # write and commit data
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        table_write.write_pandas(df)
+        commit_messages = table_write.prepare_commit()
+        table_commit.commit(commit_messages)
+
+        table_write.close()
+        table_commit.close()
+
+        # case 1: read empty
+        read_builder = table.new_read_builder().with_projection([])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result1 = table_read.to_pandas(splits)
+        self.assertTrue(result1.empty)
+
+        # case 2: read fully
+        read_builder = table.new_read_builder().with_projection(['f0', 'f1', 'f2', 'f3'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result2 = table_read.to_pandas(splits)
+        pd.testing.assert_frame_equal(
+            result2.reset_index(drop=True), df.reset_index(drop=True))
+
+        # case 3: read partially
+        read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result3 = table_read.to_pandas(splits)
+        expected_df = pd.DataFrame({
+            'f3': ['A', 'B', 'C'],
+            'f2': [True, True, False]
+        })
+        pd.testing.assert_frame_equal(
+            result3.reset_index(drop=True), expected_df.reset_index(drop=True))
diff --git a/paimon_python_java/util/java_utils.py b/paimon_python_java/util/java_utils.py
index 8c4f276..ce0404a 100644
--- a/paimon_python_java/util/java_utils.py
+++ b/paimon_python_java/util/java_utils.py
@@ -91,3 +91,12 @@ def _to_j_type(name, pa_type):
         return jvm.DataTypes.STRING()
     else:
         raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')
+
+
+def to_arrow_schema(j_row_type):
+    # init arrow schema
+    schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
+    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+    arrow_schema = schema_reader.schema
+    schema_reader.close()
+    return arrow_schema
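
Usage note (editor's sketch, not part of the patch): after this change, ReadBuilder#with_projection takes a list of field names and maps them to column indices internally before calling the Java withProjection(int[]). The sketch below mirrors the new testProjection case; the Catalog bootstrap and warehouse path are assumptions for illustration, and the table default.test_projection with fields f0..f3 is the one created in the test above.

    # Sketch only: assumes a filesystem warehouse at /tmp/paimon_warehouse and that
    # 'default.test_projection' (fields f0, f1, f2, f3) has already been created and populated.
    from paimon_python_java import Catalog

    catalog = Catalog.create({'warehouse': '/tmp/paimon_warehouse'})  # assumed bootstrap
    table = catalog.get_table('default.test_projection')

    # Project by field names; the order of the names determines the output column order.
    read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
    table_scan = read_builder.new_scan()
    table_read = read_builder.new_read()
    splits = table_scan.plan().splits()
    result = table_read.to_pandas(splits)  # DataFrame with columns f3, f2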