Skip to content

Commit

Permalink
feat: add timestamp & decimal type support to java bindings (#8740)
Browse files Browse the repository at this point in the history
  • Loading branch information
adevday authored Mar 30, 2023
1 parent 5d93d99 commit 6d65352
Show file tree
Hide file tree
Showing 13 changed files with 185 additions and 23 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,24 @@ static ValueGetter[] buildValueGetter(TableSchema tableSchema) {
return row.getString(index);
};
break;
case TIMESTAMP:
ret[i] =
row -> {
if (row.isNull(index)) {
return null;
}
return row.getTimestamp(index);
};
break;
case DECIMAL:
ret[i] =
row -> {
if (row.isNull(index)) {
return null;
}
return row.getDecimal(index);
};
break;
default:
throw io.grpc.Status.INVALID_ARGUMENT
.withDescription("unsupported type " + typeName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,18 @@ public static void validateRow(BaseRow row) {
row.getString(6),
((Short) rowIndex).toString().repeat((rowIndex % 10) + 1)));
}
if (row.isNull(7) != (rowIndex % 5 == 0)) {

if (row.getTimestamp(7).getTime() != rowIndex * 1000) {
throw new RuntimeException(
String.format("invalid Timestamp value: %s %s", row.getTimestamp(7), rowIndex));
}

if (row.getDecimal(8).intValue() != rowIndex) {
throw new RuntimeException(
String.format("invalid decimal value: %s %s", row.getDecimal(8), rowIndex));
}

if (row.isNull(9) != (rowIndex % 5 == 0)) {
throw new RuntimeException(
String.format(
"invalid isNull value: %s %s", row.isNull(7), (rowIndex % 5 == 0)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ public String getString(int index) {
return Binding.rowGetStringValue(pointer, index);
}

public java.sql.Timestamp getTimestamp(int index) {
return Binding.rowGetTimestampValue(pointer, index);
}

public java.math.BigDecimal getDecimal(int index) {
return Binding.rowGetDecimalValue(pointer, index);
}

@Override
public void close() {
if (!isClosed) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ public class Binding {

static native String rowGetStringValue(long pointer, int index);

static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index);

static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index);

// Since the underlying rust does not have garbage collection, we will have to manually call
// close on the row to release the row instance pointed by the pointer.
static native void rowClose(long pointer);
Expand Down
14 changes: 14 additions & 0 deletions src/common/src/row/owned_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,20 @@ impl OwnedRow {
_ => unreachable!("type is not utf8 at index: {}", idx),
}
}

pub fn get_datetime(&self, idx: usize) -> &Timestamp {
match self[idx].as_ref().unwrap() {
ScalarImpl::Timestamp(dt) => dt,
_ => unreachable!("type is not NaiveDateTime at index: {}", idx),
}
}

pub fn get_decimal(&self, idx: usize) -> &Decimal {
match self[idx].as_ref().unwrap() {
ScalarImpl::Decimal(d) => d,
_ => unreachable!("type is not NaiveDateTime at index: {}", idx),
}
}
}

impl EstimateSize for OwnedRow {
Expand Down
1 change: 1 addition & 0 deletions src/java_binding/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ bytes = "1"
futures = { version = "0.3", default-features = false, features = ["alloc"] }
itertools = "0.10"
jni = "0.20.0"
once_cell = "1"
prost = "0.11"
risingwave_common = { path = "../common" }
risingwave_hummock_sdk = { path = "../storage/hummock_sdk" }
Expand Down
8 changes: 5 additions & 3 deletions src/java_binding/gen-demo-insert-data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ def gen_row(index):
v4 = float(index)
v5 = float(index)
v6 = index % 3 == 0
v7 = str(index) * ((index % 10) + 1)
v7 = '\'' + str(index) * ((index % 10) + 1) + '\''
v8 = "to_timestamp(" + str(index) + ")"
v9 = index
may_null = None if index % 5 == 0 else int(index)
row_data = [v1, v2, v3, v4, v5, v6, v7, may_null]
repr = [o.__repr__() if o is not None else 'null' for o in row_data]
row_data = [v1, v2, v3, v4, v5, v6, v7, v8, v9, may_null]
repr = [str(o) if o is not None else 'null' for o in row_data]
return '(' + ', '.join(repr) + ')'


Expand Down
2 changes: 1 addition & 1 deletion src/java_binding/run_demo.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ INSERT_DATA=$(python3 ${RISINGWAVE_ROOT}/src/java_binding/gen-demo-insert-data.p

psql -d ${DB_NAME} -h localhost -p 4566 -U root << EOF
DROP TABLE IF EXISTS ${TABLE_NAME};
CREATE TABLE ${TABLE_NAME} (v1 smallint, v2 int, v3 bigint, v4 float4, v5 float8, v6 bool, v7 varchar, may_null bigint);
CREATE TABLE ${TABLE_NAME} (v1 smallint, v2 int, v3 bigint, v4 float4, v5 float8, v6 bool, v7 varchar, v8 timestamp, v9 decimal, may_null bigint);
INSERT INTO ${TABLE_NAME} values ${INSERT_DATA};
FLUSH;
EOF
Expand Down
10 changes: 8 additions & 2 deletions src/java_binding/src/bin/data-chunk-payload-generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ use std::io::Write;
use prost::Message;
use risingwave_common::array::{Op, StreamChunk};
use risingwave_common::row::OwnedRow;
use risingwave_common::types::{DataType, ScalarImpl, F32, F64};
use risingwave_common::types::{DataType, ScalarImpl, Timestamp, F32, F64};
use risingwave_common::util::chunk_coalesce::DataChunkBuilder;

fn build_row(index: usize) -> OwnedRow {
let mut row_value = Vec::with_capacity(8);
let mut row_value = Vec::with_capacity(10);
row_value.push(Some(ScalarImpl::Int16(index as i16)));
row_value.push(Some(ScalarImpl::Int32(index as i32)));
row_value.push(Some(ScalarImpl::Int64(index as i64)));
Expand All @@ -31,6 +31,10 @@ fn build_row(index: usize) -> OwnedRow {
row_value.push(Some(ScalarImpl::Utf8(
format!("{}", index).repeat((index % 10) + 1).into(),
)));
row_value.push(Some(ScalarImpl::Timestamp(
Timestamp::from_timestamp_uncheck(index as _, 0),
)));
row_value.push(Some(ScalarImpl::Decimal(index.into())));
row_value.push(if index % 5 == 0 {
None
} else {
Expand All @@ -50,6 +54,8 @@ fn main() {
DataType::Float64,
DataType::Boolean,
DataType::Varchar,
DataType::Timestamp,
DataType::Decimal,
DataType::Int64,
];
let mut ops = Vec::with_capacity(row_count);
Expand Down
7 changes: 6 additions & 1 deletion src/java_binding/src/hummock_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ fn select_all_vnode_stream(
pub struct HummockJavaBindingIterator {
row_serde: EitherSerde,
stream: SelectAllIterStream,
pub class_cache: Arc<crate::JavaClassMethodCache>,
}

pub struct KeyedRow {
Expand Down Expand Up @@ -139,7 +140,11 @@ impl HummockJavaBindingIterator {
BasicSerde::new(&column_ids, schema.into()).into()
};

Ok(Self { row_serde, stream })
Ok(Self {
row_serde,
stream,
class_cache: Default::default(),
})
}

pub async fn next(&mut self) -> StorageResult<Option<KeyedRow>> {
Expand Down
118 changes: 103 additions & 15 deletions src/java_binding/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::panic::catch_unwind;
use std::slice::from_raw_parts;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};

use hummock_iterator::{HummockJavaBindingIterator, KeyedRow};
use jni::objects::{AutoArray, JClass, JObject, JString, ReleaseMode};
use jni::objects::{AutoArray, GlobalRef, JClass, JMethodID, JObject, JString, ReleaseMode};
use jni::sys::{jboolean, jbyte, jbyteArray, jdouble, jfloat, jint, jlong, jshort};
use jni::JNIEnv;
use once_cell::sync::OnceCell;
use prost::{DecodeError, Message};
use risingwave_common::array::{ArrayError, StreamChunk};
use risingwave_common::hash::VirtualNode;
Expand Down Expand Up @@ -221,22 +222,49 @@ where
}
}

pub enum JavaBindingRow {
pub enum JavaBindingRowInner {
Keyed(KeyedRow),
StreamChunk(StreamChunkRow),
}
#[derive(Default)]
pub struct JavaClassMethodCache {
big_decimal_ctor: OnceCell<(GlobalRef, JMethodID)>,
timestamp_ctor: OnceCell<(GlobalRef, JMethodID)>,
}

pub struct JavaBindingRow {
inner: JavaBindingRowInner,
class_cache: Arc<JavaClassMethodCache>,
}

impl JavaBindingRow {
fn with_stream_chunk(
underlying: StreamChunkRow,
class_cache: Arc<JavaClassMethodCache>,
) -> Self {
Self {
inner: JavaBindingRowInner::StreamChunk(underlying),
class_cache,
}
}

fn with_keyed(underlying: KeyedRow, class_cache: Arc<JavaClassMethodCache>) -> Self {
Self {
inner: JavaBindingRowInner::Keyed(underlying),
class_cache,
}
}

fn as_keyed(&self) -> &KeyedRow {
match &self {
JavaBindingRow::Keyed(r) => r,
match &self.inner {
JavaBindingRowInner::Keyed(r) => r,
_ => unreachable!("can only call as_keyed for KeyedRow"),
}
}

fn as_stream_chunk(&self) -> &StreamChunkRow {
match &self {
JavaBindingRow::StreamChunk(r) => r,
match &self.inner {
JavaBindingRowInner::StreamChunk(r) => r,
_ => unreachable!("can only call as_stream_chunk for StreamChunkRow"),
}
}
Expand All @@ -246,9 +274,9 @@ impl Deref for JavaBindingRow {
type Target = OwnedRow;

fn deref(&self) -> &Self::Target {
match &self {
JavaBindingRow::Keyed(r) => r.row(),
JavaBindingRow::StreamChunk(r) => r.row(),
match &self.inner {
JavaBindingRowInner::Keyed(r) => r.row(),
JavaBindingRowInner::StreamChunk(r) => r.row(),
}
}
}
Expand Down Expand Up @@ -278,9 +306,10 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorN
mut pointer: Pointer<'a, HummockJavaBindingIterator>,
) -> Pointer<'static, JavaBindingRow> {
execute_and_catch(env, move || {
match RUNTIME.block_on(pointer.as_mut().next())? {
let iter = pointer.as_mut();
match RUNTIME.block_on(iter.next())? {
None => Ok(Pointer::null()),
Some(row) => Ok(JavaBindingRow::Keyed(row).into()),
Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()),
}
})
}
Expand Down Expand Up @@ -311,9 +340,14 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkItera
env: EnvParam<'a>,
mut pointer: Pointer<'a, StreamChunkIterator>,
) -> Pointer<'static, JavaBindingRow> {
execute_and_catch(env, move || match pointer.as_mut().next() {
None => Ok(Pointer::null()),
Some(row) => Ok(JavaBindingRow::StreamChunk(row).into()),
execute_and_catch(env, move || {
let iter = pointer.as_mut();
match iter.next() {
None => Ok(Pointer::null()),
Some(row) => {
Ok(JavaBindingRow::with_stream_chunk(row, iter.class_cache.clone()).into())
}
}
})
}

Expand Down Expand Up @@ -425,6 +459,60 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValu
})
}

#[no_mangle]
pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampValue<'a>(
env: EnvParam<'a>,
pointer: Pointer<'a, JavaBindingRow>,
idx: jint,
) -> JObject<'a> {
execute_and_catch(env, move || {
let millis = pointer
.as_ref()
.get_datetime(idx as usize)
.0
.timestamp_millis();
let (ts_class_ref, constructor) = pointer
.as_ref()
.class_cache
.timestamp_ctor
.get_or_try_init(|| {
let cls = env.find_class("java/sql/Timestamp")?;
let init_method = env.get_method_id(cls, "<init>", "(J)V")?;
Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
})?;
let ts_class = JClass::from(ts_class_ref.as_obj());
let date_obj = env.new_object_unchecked(ts_class, *constructor, &[millis.into()])?;

Ok(date_obj)
})
}

#[no_mangle]
pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalValue<'a>(
env: EnvParam<'a>,
pointer: Pointer<'a, JavaBindingRow>,
idx: jint,
) -> JObject<'a> {
execute_and_catch(env, move || {
let value = pointer.as_ref().get_decimal(idx as usize).to_string();
let string_value = env.new_string(value)?;
let (decimal_class_ref, constructor) = pointer
.as_ref()
.class_cache
.big_decimal_ctor
.get_or_try_init(|| {
let cls = env.find_class("java/math/BigDecimal")?;
let init_method = env.get_method_id(cls, "<init>", "(Ljava/lang/String;)V")?;
Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method))
})?;
let decimal_class = JClass::from(decimal_class_ref.as_obj());
let date_obj =
env.new_object_unchecked(decimal_class, *constructor, &[string_value.into()])?;

Ok(date_obj)
})
}

#[no_mangle]
pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>(
_env: EnvParam<'a>,
Expand Down
4 changes: 4 additions & 0 deletions src/java_binding/src/stream_chunk_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::array::StreamChunk;
use risingwave_common::row::{OwnedRow, Row};
Expand All @@ -36,6 +38,7 @@ type StreamChunkRowIterator = impl Iterator<Item = StreamChunkRow> + 'static;

pub struct StreamChunkIterator {
iter: StreamChunkRowIterator,
pub class_cache: Arc<crate::JavaClassMethodCache>,
}

impl StreamChunkIterator {
Expand All @@ -49,6 +52,7 @@ impl StreamChunkIterator {
})
.collect_vec()
.into_iter(),
class_cache: Default::default(),
}
}

Expand Down

0 comments on commit 6d65352

Please sign in to comment.