diff --git a/Cargo.lock b/Cargo.lock index bfe986a7f656d..8faf6dd0ecc0c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6299,6 +6299,7 @@ dependencies = [ "itertools", "jni", "madsim-tokio", + "once_cell", "prost 0.11.8", "risingwave_common", "risingwave_hummock_sdk", diff --git a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java index b6be0af1d9394..c2d755a9ccc0f 100644 --- a/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java +++ b/java/connector-node/risingwave-connector-service/src/main/java/com/risingwave/connector/deserializer/StreamChunkDeserializer.java @@ -107,6 +107,24 @@ static ValueGetter[] buildValueGetter(TableSchema tableSchema) { return row.getString(index); }; break; + case TIMESTAMP: + ret[i] = + row -> { + if (row.isNull(index)) { + return null; + } + return row.getTimestamp(index); + }; + break; + case DECIMAL: + ret[i] = + row -> { + if (row.isNull(index)) { + return null; + } + return row.getDecimal(index); + }; + break; default: throw io.grpc.Status.INVALID_ARGUMENT .withDescription("unsupported type " + typeName) diff --git a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/Utils.java b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/Utils.java index 193ba4811bdc1..ac1dfbf210ba4 100644 --- a/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/Utils.java +++ b/java/java-binding-integration-test/src/main/java/com/risingwave/java/binding/Utils.java @@ -47,7 +47,18 @@ public static void validateRow(BaseRow row) { row.getString(6), ((Short) rowIndex).toString().repeat((rowIndex % 10) + 1))); } - if (row.isNull(7) != (rowIndex % 5 == 0)) { + + if (row.getTimestamp(7).getTime() != rowIndex * 1000) { + throw new RuntimeException( + String.format("invalid Timestamp value: %s %s", row.getTimestamp(7), rowIndex)); + } + + if (row.getDecimal(8).intValue() != rowIndex) { + throw new RuntimeException( + String.format("invalid decimal value: %s %s", row.getDecimal(8), rowIndex)); + } + + if (row.isNull(9) != (rowIndex % 5 == 0)) { throw new RuntimeException( String.format( "invalid isNull value: %s %s", row.isNull(7), (rowIndex % 5 == 0))); diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java b/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java index 22d55a145deaa..2e493691ef801 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/BaseRow.java @@ -55,6 +55,14 @@ public String getString(int index) { return Binding.rowGetStringValue(pointer, index); } + public java.sql.Timestamp getTimestamp(int index) { + return Binding.rowGetTimestampValue(pointer, index); + } + + public java.math.BigDecimal getDecimal(int index) { + return Binding.rowGetDecimalValue(pointer, index); + } + @Override public void close() { if (!isClosed) { diff --git a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java index f4dec3eecb426..1fcaa659db577 100644 --- a/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java +++ b/java/java-binding/src/main/java/com/risingwave/java/binding/Binding.java @@ -53,6 +53,10 @@ public class Binding { static native String rowGetStringValue(long pointer, int index); + static native java.sql.Timestamp rowGetTimestampValue(long pointer, int index); + + static native java.math.BigDecimal rowGetDecimalValue(long pointer, int index); + // Since the underlying rust does not have garbage collection, we will have to manually call // close on the row to release the row instance pointed by the pointer. static native void rowClose(long pointer); diff --git a/src/common/src/row/owned_row.rs b/src/common/src/row/owned_row.rs index 97373e630d8b1..a52ec5617c394 100644 --- a/src/common/src/row/owned_row.rs +++ b/src/common/src/row/owned_row.rs @@ -144,6 +144,20 @@ impl OwnedRow { _ => unreachable!("type is not utf8 at index: {}", idx), } } + + pub fn get_datetime(&self, idx: usize) -> &Timestamp { + match self[idx].as_ref().unwrap() { + ScalarImpl::Timestamp(dt) => dt, + _ => unreachable!("type is not NaiveDateTime at index: {}", idx), + } + } + + pub fn get_decimal(&self, idx: usize) -> &Decimal { + match self[idx].as_ref().unwrap() { + ScalarImpl::Decimal(d) => d, + _ => unreachable!("type is not NaiveDateTime at index: {}", idx), + } + } } impl EstimateSize for OwnedRow { diff --git a/src/java_binding/Cargo.toml b/src/java_binding/Cargo.toml index 43f1aae283a8b..4e47a0121d93c 100644 --- a/src/java_binding/Cargo.toml +++ b/src/java_binding/Cargo.toml @@ -14,6 +14,7 @@ bytes = "1" futures = { version = "0.3", default-features = false, features = ["alloc"] } itertools = "0.10" jni = "0.20.0" +once_cell = "1" prost = "0.11" risingwave_common = { path = "../common" } risingwave_hummock_sdk = { path = "../storage/hummock_sdk" } diff --git a/src/java_binding/gen-demo-insert-data.py b/src/java_binding/gen-demo-insert-data.py index 56be589763ab2..6ffc79077eb82 100644 --- a/src/java_binding/gen-demo-insert-data.py +++ b/src/java_binding/gen-demo-insert-data.py @@ -7,10 +7,12 @@ def gen_row(index): v4 = float(index) v5 = float(index) v6 = index % 3 == 0 - v7 = str(index) * ((index % 10) + 1) + v7 = '\'' + str(index) * ((index % 10) + 1) + '\'' + v8 = "to_timestamp(" + str(index) + ")" + v9 = index may_null = None if index % 5 == 0 else int(index) - row_data = [v1, v2, v3, v4, v5, v6, v7, may_null] - repr = [o.__repr__() if o is not None else 'null' for o in row_data] + row_data = [v1, v2, v3, v4, v5, v6, v7, v8, v9, may_null] + repr = [str(o) if o is not None else 'null' for o in row_data] return '(' + ', '.join(repr) + ')' diff --git a/src/java_binding/run_demo.sh b/src/java_binding/run_demo.sh index fc49d96fd5678..9c7fa0fd8158f 100644 --- a/src/java_binding/run_demo.sh +++ b/src/java_binding/run_demo.sh @@ -10,7 +10,7 @@ INSERT_DATA=$(python3 ${RISINGWAVE_ROOT}/src/java_binding/gen-demo-insert-data.p psql -d ${DB_NAME} -h localhost -p 4566 -U root << EOF DROP TABLE IF EXISTS ${TABLE_NAME}; -CREATE TABLE ${TABLE_NAME} (v1 smallint, v2 int, v3 bigint, v4 float4, v5 float8, v6 bool, v7 varchar, may_null bigint); +CREATE TABLE ${TABLE_NAME} (v1 smallint, v2 int, v3 bigint, v4 float4, v5 float8, v6 bool, v7 varchar, v8 timestamp, v9 decimal, may_null bigint); INSERT INTO ${TABLE_NAME} values ${INSERT_DATA}; FLUSH; EOF diff --git a/src/java_binding/src/bin/data-chunk-payload-generator.rs b/src/java_binding/src/bin/data-chunk-payload-generator.rs index 179d1baab4e2d..20ab2b65148e3 100644 --- a/src/java_binding/src/bin/data-chunk-payload-generator.rs +++ b/src/java_binding/src/bin/data-chunk-payload-generator.rs @@ -17,11 +17,11 @@ use std::io::Write; use prost::Message; use risingwave_common::array::{Op, StreamChunk}; use risingwave_common::row::OwnedRow; -use risingwave_common::types::{DataType, ScalarImpl, F32, F64}; +use risingwave_common::types::{DataType, ScalarImpl, Timestamp, F32, F64}; use risingwave_common::util::chunk_coalesce::DataChunkBuilder; fn build_row(index: usize) -> OwnedRow { - let mut row_value = Vec::with_capacity(8); + let mut row_value = Vec::with_capacity(10); row_value.push(Some(ScalarImpl::Int16(index as i16))); row_value.push(Some(ScalarImpl::Int32(index as i32))); row_value.push(Some(ScalarImpl::Int64(index as i64))); @@ -31,6 +31,10 @@ fn build_row(index: usize) -> OwnedRow { row_value.push(Some(ScalarImpl::Utf8( format!("{}", index).repeat((index % 10) + 1).into(), ))); + row_value.push(Some(ScalarImpl::Timestamp( + Timestamp::from_timestamp_uncheck(index as _, 0), + ))); + row_value.push(Some(ScalarImpl::Decimal(index.into()))); row_value.push(if index % 5 == 0 { None } else { @@ -50,6 +54,8 @@ fn main() { DataType::Float64, DataType::Boolean, DataType::Varchar, + DataType::Timestamp, + DataType::Decimal, DataType::Int64, ]; let mut ops = Vec::with_capacity(row_count); diff --git a/src/java_binding/src/hummock_iterator.rs b/src/java_binding/src/hummock_iterator.rs index 4f1ef0d47d0ed..cb1dc388606fc 100644 --- a/src/java_binding/src/hummock_iterator.rs +++ b/src/java_binding/src/hummock_iterator.rs @@ -50,6 +50,7 @@ fn select_all_vnode_stream( pub struct HummockJavaBindingIterator { row_serde: EitherSerde, stream: SelectAllIterStream, + pub class_cache: Arc, } pub struct KeyedRow { @@ -139,7 +140,11 @@ impl HummockJavaBindingIterator { BasicSerde::new(&column_ids, schema.into()).into() }; - Ok(Self { row_serde, stream }) + Ok(Self { + row_serde, + stream, + class_cache: Default::default(), + }) } pub async fn next(&mut self) -> StorageResult> { diff --git a/src/java_binding/src/lib.rs b/src/java_binding/src/lib.rs index 77c752fdb231e..955f2f6e17494 100644 --- a/src/java_binding/src/lib.rs +++ b/src/java_binding/src/lib.rs @@ -25,12 +25,13 @@ use std::marker::PhantomData; use std::ops::Deref; use std::panic::catch_unwind; use std::slice::from_raw_parts; -use std::sync::LazyLock; +use std::sync::{Arc, LazyLock}; use hummock_iterator::{HummockJavaBindingIterator, KeyedRow}; -use jni::objects::{AutoArray, JClass, JObject, JString, ReleaseMode}; +use jni::objects::{AutoArray, GlobalRef, JClass, JMethodID, JObject, JString, ReleaseMode}; use jni::sys::{jboolean, jbyte, jbyteArray, jdouble, jfloat, jint, jlong, jshort}; use jni::JNIEnv; +use once_cell::sync::OnceCell; use prost::{DecodeError, Message}; use risingwave_common::array::{ArrayError, StreamChunk}; use risingwave_common::hash::VirtualNode; @@ -221,22 +222,49 @@ where } } -pub enum JavaBindingRow { +pub enum JavaBindingRowInner { Keyed(KeyedRow), StreamChunk(StreamChunkRow), } +#[derive(Default)] +pub struct JavaClassMethodCache { + big_decimal_ctor: OnceCell<(GlobalRef, JMethodID)>, + timestamp_ctor: OnceCell<(GlobalRef, JMethodID)>, +} + +pub struct JavaBindingRow { + inner: JavaBindingRowInner, + class_cache: Arc, +} impl JavaBindingRow { + fn with_stream_chunk( + underlying: StreamChunkRow, + class_cache: Arc, + ) -> Self { + Self { + inner: JavaBindingRowInner::StreamChunk(underlying), + class_cache, + } + } + + fn with_keyed(underlying: KeyedRow, class_cache: Arc) -> Self { + Self { + inner: JavaBindingRowInner::Keyed(underlying), + class_cache, + } + } + fn as_keyed(&self) -> &KeyedRow { - match &self { - JavaBindingRow::Keyed(r) => r, + match &self.inner { + JavaBindingRowInner::Keyed(r) => r, _ => unreachable!("can only call as_keyed for KeyedRow"), } } fn as_stream_chunk(&self) -> &StreamChunkRow { - match &self { - JavaBindingRow::StreamChunk(r) => r, + match &self.inner { + JavaBindingRowInner::StreamChunk(r) => r, _ => unreachable!("can only call as_stream_chunk for StreamChunkRow"), } } @@ -246,9 +274,9 @@ impl Deref for JavaBindingRow { type Target = OwnedRow; fn deref(&self) -> &Self::Target { - match &self { - JavaBindingRow::Keyed(r) => r.row(), - JavaBindingRow::StreamChunk(r) => r.row(), + match &self.inner { + JavaBindingRowInner::Keyed(r) => r.row(), + JavaBindingRowInner::StreamChunk(r) => r.row(), } } } @@ -278,9 +306,10 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorN mut pointer: Pointer<'a, HummockJavaBindingIterator>, ) -> Pointer<'static, JavaBindingRow> { execute_and_catch(env, move || { - match RUNTIME.block_on(pointer.as_mut().next())? { + let iter = pointer.as_mut(); + match RUNTIME.block_on(iter.next())? { None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::Keyed(row).into()), + Some(row) => Ok(JavaBindingRow::with_keyed(row, iter.class_cache.clone()).into()), } }) } @@ -311,9 +340,14 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkItera env: EnvParam<'a>, mut pointer: Pointer<'a, StreamChunkIterator>, ) -> Pointer<'static, JavaBindingRow> { - execute_and_catch(env, move || match pointer.as_mut().next() { - None => Ok(Pointer::null()), - Some(row) => Ok(JavaBindingRow::StreamChunk(row).into()), + execute_and_catch(env, move || { + let iter = pointer.as_mut(); + match iter.next() { + None => Ok(Pointer::null()), + Some(row) => { + Ok(JavaBindingRow::with_stream_chunk(row, iter.class_cache.clone()).into()) + } + } }) } @@ -425,6 +459,60 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetStringValu }) } +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move || { + let millis = pointer + .as_ref() + .get_datetime(idx as usize) + .0 + .timestamp_millis(); + let (ts_class_ref, constructor) = pointer + .as_ref() + .class_cache + .timestamp_ctor + .get_or_try_init(|| { + let cls = env.find_class("java/sql/Timestamp")?; + let init_method = env.get_method_id(cls, "", "(J)V")?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + let ts_class = JClass::from(ts_class_ref.as_obj()); + let date_obj = env.new_object_unchecked(ts_class, *constructor, &[millis.into()])?; + + Ok(date_obj) + }) +} + +#[no_mangle] +pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalValue<'a>( + env: EnvParam<'a>, + pointer: Pointer<'a, JavaBindingRow>, + idx: jint, +) -> JObject<'a> { + execute_and_catch(env, move || { + let value = pointer.as_ref().get_decimal(idx as usize).to_string(); + let string_value = env.new_string(value)?; + let (decimal_class_ref, constructor) = pointer + .as_ref() + .class_cache + .big_decimal_ctor + .get_or_try_init(|| { + let cls = env.find_class("java/math/BigDecimal")?; + let init_method = env.get_method_id(cls, "", "(Ljava/lang/String;)V")?; + Ok::<_, jni::errors::Error>((env.new_global_ref(cls)?, init_method)) + })?; + let decimal_class = JClass::from(decimal_class_ref.as_obj()); + let date_obj = + env.new_object_unchecked(decimal_class, *constructor, &[string_value.into()])?; + + Ok(date_obj) + }) +} + #[no_mangle] pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowClose<'a>( _env: EnvParam<'a>, diff --git a/src/java_binding/src/stream_chunk_iterator.rs b/src/java_binding/src/stream_chunk_iterator.rs index bf6b3e8acc710..d62117a0aa108 100644 --- a/src/java_binding/src/stream_chunk_iterator.rs +++ b/src/java_binding/src/stream_chunk_iterator.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use itertools::Itertools; use risingwave_common::array::StreamChunk; use risingwave_common::row::{OwnedRow, Row}; @@ -36,6 +38,7 @@ type StreamChunkRowIterator = impl Iterator + 'static; pub struct StreamChunkIterator { iter: StreamChunkRowIterator, + pub class_cache: Arc, } impl StreamChunkIterator { @@ -49,6 +52,7 @@ impl StreamChunkIterator { }) .collect_vec() .into_iter(), + class_cache: Default::default(), } }