Commit f09dc58

Refactor ReadBuilder#with_projection to accept field names for better using (#27)

yuzelin authored Nov 25, 2024
1 parent c4bbf32 commit f09dc58
Showing 4 changed files with 92 additions and 24 deletions.
2 changes: 1 addition & 1 deletion paimon_python_api/read_builder.py

@@ -32,7 +32,7 @@ def with_filter(self, predicate: Predicate):
         """
 
     @abstractmethod
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
         """Push nested projection."""
 
     @abstractmethod
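For callers, the interface change looks like this (a minimal before/after sketch; `table` is assumed to be a Paimon table handle with fields f0–f3 as in the new test below, and the old nested-index form is shown as it would plausibly have been used, one index list per top-level field):

    # before (#27): projection as nested integer indices
    read_builder = table.new_read_builder().with_projection([[0], [2]])

    # after (#27): projection as field names; list order determines output column order
    read_builder = table.new_read_builder().with_projection(['f0', 'f2'])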
43 changes: 20 additions & 23 deletions paimon_python_java/pypaimon.py

@@ -61,37 +61,36 @@ class Table(table.Table):
     def __init__(self, j_table, catalog_options: dict):
         self._j_table = j_table
         self._catalog_options = catalog_options
-        # init arrow schema
-        schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_table.rowType())
-        schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
-        self._arrow_schema = schema_reader.schema
-        schema_reader.close()
 
     def new_read_builder(self) -> 'ReadBuilder':
         j_read_builder = get_gateway().jvm.InvocationUtil.getReadBuilder(self._j_table)
-        return ReadBuilder(
-            j_read_builder, self._j_table.rowType(), self._catalog_options, self._arrow_schema)
+        return ReadBuilder(j_read_builder, self._j_table.rowType(), self._catalog_options)
 
     def new_batch_write_builder(self) -> 'BatchWriteBuilder':
         java_utils.check_batch_write(self._j_table)
         j_batch_write_builder = get_gateway().jvm.InvocationUtil.getBatchWriteBuilder(self._j_table)
-        return BatchWriteBuilder(j_batch_write_builder, self._j_table.rowType(), self._arrow_schema)
+        return BatchWriteBuilder(j_batch_write_builder)
 
 
 class ReadBuilder(read_builder.ReadBuilder):
 
-    def __init__(self, j_read_builder, j_row_type, catalog_options: dict, arrow_schema: pa.Schema):
+    def __init__(self, j_read_builder, j_row_type, catalog_options: dict):
         self._j_read_builder = j_read_builder
         self._j_row_type = j_row_type
         self._catalog_options = catalog_options
-        self._arrow_schema = arrow_schema
 
     def with_filter(self, predicate: 'Predicate'):
         self._j_read_builder.withFilter(predicate.to_j_predicate())
         return self
 
-    def with_projection(self, projection: List[List[int]]) -> 'ReadBuilder':
-        self._j_read_builder.withProjection(projection)
+    def with_projection(self, projection: List[str]) -> 'ReadBuilder':
+        field_names = list(map(lambda field: field.name(), self._j_row_type.getFields()))
+        int_projection = list(map(lambda p: field_names.index(p), projection))
+        gateway = get_gateway()
+        int_projection_arr = gateway.new_array(gateway.jvm.int, len(projection))
+        for i in range(len(projection)):
+            int_projection_arr[i] = int_projection[i]
+        self._j_read_builder.withProjection(int_projection_arr)
         return self
 
     def with_limit(self, limit: int) -> 'ReadBuilder':
@@ -104,7 +103,7 @@ def new_scan(self) -> 'TableScan':
 
     def new_read(self) -> 'TableRead':
         j_table_read = self._j_read_builder.newRead().executeFilter()
-        return TableRead(j_table_read, self._j_row_type, self._catalog_options, self._arrow_schema)
+        return TableRead(j_table_read, self._j_read_builder.readType(), self._catalog_options)
 
     def new_predicate_builder(self) -> 'PredicateBuilder':
         return PredicateBuilder(self._j_row_type)
@@ -141,12 +140,12 @@ def to_j_split(self):
 
 class TableRead(table_read.TableRead):
 
-    def __init__(self, j_table_read, j_row_type, catalog_options, arrow_schema):
+    def __init__(self, j_table_read, j_read_type, catalog_options):
         self._j_table_read = j_table_read
-        self._j_row_type = j_row_type
+        self._j_read_type = j_read_type
         self._catalog_options = catalog_options
         self._j_bytes_reader = None
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_read_type)
 
     def to_arrow(self, splits):
         record_batch_reader = self.to_arrow_batch_reader(splits)
@@ -174,7 +173,7 @@ def _init(self):
         if max_workers <= 0:
             raise ValueError("max_workers must be greater than 0")
         self._j_bytes_reader = get_gateway().jvm.InvocationUtil.createParallelBytesReader(
-            self._j_table_read, self._j_row_type, max_workers)
+            self._j_table_read, self._j_read_type, max_workers)
 
     def _batch_generator(self) -> Iterator[pa.RecordBatch]:
         while True:
@@ -188,10 +187,8 @@ def _batch_generator(self) -> Iterator[pa.RecordBatch]:
 
 class BatchWriteBuilder(write_builder.BatchWriteBuilder):
 
-    def __init__(self, j_batch_write_builder, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_write_builder):
         self._j_batch_write_builder = j_batch_write_builder
-        self._j_row_type = j_row_type
-        self._arrow_schema = arrow_schema
 
     def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
         if static_partition is None:
@@ -201,7 +198,7 @@ def overwrite(self, static_partition: Optional[dict] = None) -> 'BatchWriteBuilder':
 
     def new_write(self) -> 'BatchTableWrite':
         j_batch_table_write = self._j_batch_write_builder.newWrite()
-        return BatchTableWrite(j_batch_table_write, self._j_row_type, self._arrow_schema)
+        return BatchTableWrite(j_batch_table_write, self._j_batch_write_builder.rowType())
 
     def new_commit(self) -> 'BatchTableCommit':
         j_batch_table_commit = self._j_batch_write_builder.newCommit()
@@ -210,11 +207,11 @@ def new_commit(self) -> 'BatchTableCommit':
 
 class BatchTableWrite(table_write.BatchTableWrite):
 
-    def __init__(self, j_batch_table_write, j_row_type, arrow_schema: pa.Schema):
+    def __init__(self, j_batch_table_write, j_row_type):
         self._j_batch_table_write = j_batch_table_write
         self._j_bytes_writer = get_gateway().jvm.InvocationUtil.createBytesWriter(
             j_batch_table_write, j_row_type)
-        self._arrow_schema = arrow_schema
+        self._arrow_schema = java_utils.to_arrow_schema(j_row_type)
 
     def write_arrow(self, table):
         for record_batch in table.to_reader():
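Internally, the new with_projection resolves each requested field name to its position in the table row type, then copies those positions into a Java primitive int[] via py4j's new_array, since the Java-side withProjection still expects integer indices. The mapping step can be reproduced in plain Python (a sketch with hypothetical values; in the real code field_names comes from j_row_type.getFields()):

    field_names = ['f0', 'f1', 'f2', 'f3']   # stand-in for the names in the Java row type
    projection = ['f3', 'f2']                # caller-supplied field names
    int_projection = [field_names.index(p) for p in projection]
    assert int_projection == [3, 2]          # the caller's ordering is preserved

A side effect worth noting: field_names.index(p) raises ValueError for an unknown name, so a misspelled column fails in Python before the call crosses into the JVM.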
62 changes: 62 additions & 0 deletions paimon_python_java/tests/test_write_and_read.py

@@ -445,3 +445,65 @@ def _testIgnoreNullableImpl(self, table_name, table_schema, data_schema):
         df['f0'] = df['f0'].astype('int32')
         pd.testing.assert_frame_equal(
             actual_df.reset_index(drop=True), df.reset_index(drop=True))
+
+    def testProjection(self):
+        pa_schema = pa.schema([
+            ('f0', pa.int64()),
+            ('f1', pa.string()),
+            ('f2', pa.bool_()),
+            ('f3', pa.string())
+        ])
+        schema = Schema(pa_schema)
+        self.catalog.create_table('default.test_projection', schema, False)
+        table = self.catalog.get_table('default.test_projection')
+
+        # prepare data
+        data = {
+            'f0': [1, 2, 3],
+            'f1': ['a', 'b', 'c'],
+            'f2': [True, True, False],
+            'f3': ['A', 'B', 'C']
+        }
+        df = pd.DataFrame(data)
+
+        # write and commit data
+        write_builder = table.new_batch_write_builder()
+        table_write = write_builder.new_write()
+        table_commit = write_builder.new_commit()
+
+        table_write.write_pandas(df)
+        commit_messages = table_write.prepare_commit()
+        table_commit.commit(commit_messages)
+
+        table_write.close()
+        table_commit.close()
+
+        # case 1: read empty
+        read_builder = table.new_read_builder().with_projection([])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result1 = table_read.to_pandas(splits)
+        self.assertTrue(result1.empty)
+
+        # case 2: read fully
+        read_builder = table.new_read_builder().with_projection(['f0', 'f1', 'f2', 'f3'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result2 = table_read.to_pandas(splits)
+        pd.testing.assert_frame_equal(
+            result2.reset_index(drop=True), df.reset_index(drop=True))
+
+        # case 3: read partially
+        read_builder = table.new_read_builder().with_projection(['f3', 'f2'])
+        table_scan = read_builder.new_scan()
+        table_read = read_builder.new_read()
+        splits = table_scan.plan().splits()
+        result3 = table_read.to_pandas(splits)
+        expected_df = pd.DataFrame({
+            'f3': ['A', 'B', 'C'],
+            'f2': [True, True, False]
+        })
+        pd.testing.assert_frame_equal(
+            result3.reset_index(drop=True), expected_df.reset_index(drop=True))
9 changes: 9 additions & 0 deletions paimon_python_java/util/java_utils.py

@@ -91,3 +91,12 @@ def _to_j_type(name, pa_type):
         return jvm.DataTypes.STRING()
     else:
         raise ValueError(f'Found unsupported data type {str(pa_type)} for field {name}.')
+
+
+def to_arrow_schema(j_row_type):
+    # init arrow schema
+    schema_bytes = get_gateway().jvm.SchemaUtil.getArrowSchema(j_row_type)
+    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(schema_bytes))
+    arrow_schema = schema_reader.schema
+    schema_reader.close()
+    return arrow_schema
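The new to_arrow_schema helper decodes a schema that the JVM serializes in Arrow IPC stream format: the bytes form a record-batch stream whose header carries the schema, with no batches required. The pyarrow side of that round trip can be reproduced without a JVM (a self-contained sketch; the real bytes come from SchemaUtil.getArrowSchema):

    import pyarrow as pa

    schema = pa.schema([('f0', pa.int64()), ('f1', pa.string())])
    sink = pa.BufferOutputStream()
    writer = pa.ipc.new_stream(sink, schema)  # the IPC header written here carries the schema
    writer.close()
    ipc_bytes = sink.getvalue()               # stand-in for the bytes produced by the JVM

    schema_reader = pa.RecordBatchStreamReader(pa.BufferReader(ipc_bytes))
    assert schema_reader.schema.equals(schema)
    schema_reader.close()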
