Skip to content

Commit

Permalink
cache class
Browse files Browse the repository at this point in the history
  • Loading branch information
adevday committed Mar 27, 2023
1 parent 93badd5 commit c86c319
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 21 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/java_binding/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ futures = { version = "0.3", default-features = false, features = ["alloc"] }
itertools = "0.10"
jni = "0.20.0"
prost = "0.11"
once_cell = "1"
risingwave_common = { path = "../common" }
risingwave_hummock_sdk = { path = "../storage/hummock_sdk" }
risingwave_object_store = { path = "../object_store" }
Expand Down
7 changes: 6 additions & 1 deletion src/java_binding/src/hummock_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ fn select_all_vnode_stream(
pub struct HummockJavaBindingIterator {
row_serde: EitherSerde,
stream: SelectAllIterStream,
pub class_cache: Arc<crate::JavaBindingRowCache>,
}

pub struct KeyedRow {
Expand Down Expand Up @@ -137,7 +138,11 @@ impl HummockJavaBindingIterator {
BasicSerde::new(&column_ids, schema.into()).into()
};

Ok(Self { row_serde, stream })
Ok(Self {
row_serde,
stream,
class_cache: Default::default(),
})
}

pub async fn next(&mut self) -> StorageResult<Option<KeyedRow>> {
Expand Down
94 changes: 74 additions & 20 deletions src/java_binding/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ use std::marker::PhantomData;
use std::ops::Deref;
use std::panic::catch_unwind;
use std::slice::from_raw_parts;
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};

use hummock_iterator::{HummockJavaBindingIterator, KeyedRow};
use jni::objects::{AutoArray, JClass, JObject, JString, ReleaseMode};
use jni::objects::{AutoArray, GlobalRef, JClass, JMethodID, JObject, JString, ReleaseMode};
use jni::sys::{jboolean, jbyte, jbyteArray, jdouble, jfloat, jint, jlong, jshort};
use jni::JNIEnv;
use once_cell::sync::OnceCell;
use prost::{DecodeError, Message};
use risingwave_common::array::{ArrayError, StreamChunk};
use risingwave_common::hash::VirtualNode;
Expand Down Expand Up @@ -221,22 +222,46 @@ where
}
}

pub enum JavaBindingRow {
pub enum JavaBindingRowKind {
Keyed(KeyedRow),
StreamChunk(StreamChunkRow),
}
#[derive(Default)]
pub struct JavaBindingRowCache {
big_decimal_class: OnceCell<GlobalRef>,
timestamp_class: OnceCell<GlobalRef>,
}

pub struct JavaBindingRow {
underlying: JavaBindingRowKind,
class_cache: Arc<JavaBindingRowCache>,
}

impl JavaBindingRow {
fn new_stream_chunk(underlying: StreamChunkRow, class_cache: Arc<JavaBindingRowCache>) -> Self {
Self {
underlying: JavaBindingRowKind::StreamChunk(underlying),
class_cache,
}
}

fn new_keyed(underlying: KeyedRow, class_cache: Arc<JavaBindingRowCache>) -> Self {
Self {
underlying: JavaBindingRowKind::Keyed(underlying),
class_cache,
}
}

fn as_keyed(&self) -> &KeyedRow {
match &self {
JavaBindingRow::Keyed(r) => r,
match &self.underlying {
JavaBindingRowKind::Keyed(r) => r,
_ => unreachable!("can only call as_keyed for KeyedRow"),
}
}

fn as_stream_chunk(&self) -> &StreamChunkRow {
match &self {
JavaBindingRow::StreamChunk(r) => r,
match &self.underlying {
JavaBindingRowKind::StreamChunk(r) => r,
_ => unreachable!("can only call as_stream_chunk for StreamChunkRow"),
}
}
Expand All @@ -246,9 +271,9 @@ impl Deref for JavaBindingRow {
type Target = OwnedRow;

fn deref(&self) -> &Self::Target {
match &self {
JavaBindingRow::Keyed(r) => r.row(),
JavaBindingRow::StreamChunk(r) => r.row(),
match &self.underlying {
JavaBindingRowKind::Keyed(r) => r.row(),
JavaBindingRowKind::StreamChunk(r) => r.row(),
}
}
}
Expand Down Expand Up @@ -278,9 +303,10 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_hummockIteratorN
mut pointer: Pointer<'a, HummockJavaBindingIterator>,
) -> Pointer<'static, JavaBindingRow> {
execute_and_catch(env, move || {
match RUNTIME.block_on(pointer.as_mut().next())? {
let iter = pointer.as_mut();
match RUNTIME.block_on(iter.next())? {
None => Ok(Pointer::null()),
Some(row) => Ok(JavaBindingRow::Keyed(row).into()),
Some(row) => Ok(JavaBindingRow::new_keyed(row, iter.class_cache.clone()).into()),
}
})
}
Expand Down Expand Up @@ -311,9 +337,12 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_streamChunkItera
env: EnvParam<'a>,
mut pointer: Pointer<'a, StreamChunkIterator>,
) -> Pointer<'static, JavaBindingRow> {
execute_and_catch(env, move || match pointer.as_mut().next() {
None => Ok(Pointer::null()),
Some(row) => Ok(JavaBindingRow::StreamChunk(row).into()),
execute_and_catch(env, move || {
let iter = pointer.as_mut();
match iter.next() {
None => Ok(Pointer::null()),
Some(row) => Ok(JavaBindingRow::new_stream_chunk(row, iter.class_cache.clone()).into()),
}
})
}

Expand Down Expand Up @@ -431,15 +460,27 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetTimestampV
pointer: Pointer<'a, JavaBindingRow>,
idx: jint,
) -> JObject<'a> {
// since JMethodID is always validate until the belonging class is unload.
static INIT_METHOD: OnceCell<JMethodID> = OnceCell::new();
execute_and_catch(env, move || {
let millis = pointer
.as_ref()
.get_datetime(idx as usize)
.0
.timestamp_millis();
let date_class = env.find_class("java/sql/Timestamp")?;
let constructor = env.get_method_id(date_class, "<init>", "(J)V")?;
let date_obj = env.new_object_unchecked(date_class, constructor, &[millis.into()])?;
let ts_class_ref = pointer
.as_ref()
.class_cache
.timestamp_class
.get_or_try_init(|| {
let cls = env.find_class("java/sql/Timestamp")?;
Ok::<_, jni::errors::Error>(env.new_global_ref(cls)?)
})?;
let ts_class = JClass::from(ts_class_ref.as_obj());
let constructor = INIT_METHOD
.get_or_try_init(|| env.get_method_id(ts_class, "<init>", "(J)V"))?
.clone();
let date_obj = env.new_object_unchecked(ts_class, constructor, &[millis.into()])?;

Ok(date_obj)
})
Expand All @@ -451,11 +492,24 @@ pub extern "system" fn Java_com_risingwave_java_binding_Binding_rowGetDecimalVal
pointer: Pointer<'a, JavaBindingRow>,
idx: jint,
) -> JObject<'a> {
static INIT_METHOD: OnceCell<JMethodID> = OnceCell::new();
execute_and_catch(env, move || {
let value = pointer.as_ref().get_decimal(idx as usize).to_string();
let string_value = env.new_string(value)?;
let decimal_class = env.find_class("java/math/BigDecimal")?;
let constructor = env.get_method_id(decimal_class, "<init>", "(Ljava/lang/String;)V")?;
let ts_class_ref = pointer
.as_ref()
.class_cache
.big_decimal_class
.get_or_try_init(|| {
let cls = env.find_class("java/math/BigDecimal")?;
Ok::<_, jni::errors::Error>(env.new_global_ref(cls)?)
})?;
let decimal_class = JClass::from(ts_class_ref.as_obj());
let constructor = INIT_METHOD
.get_or_try_init(|| {
env.get_method_id(decimal_class, "<init>", "(Ljava/lang/String;)V")
})?
.clone();
let date_obj =
env.new_object_unchecked(decimal_class, constructor, &[string_value.into()])?;

Expand Down
4 changes: 4 additions & 0 deletions src/java_binding/src/stream_chunk_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use itertools::Itertools;
use risingwave_common::array::StreamChunk;
use risingwave_common::row::{OwnedRow, Row};
Expand All @@ -36,6 +38,7 @@ type StreamChunkRowIterator = impl Iterator<Item = StreamChunkRow> + 'static;

pub struct StreamChunkIterator {
iter: StreamChunkRowIterator,
pub class_cache: Arc<crate::JavaBindingRowCache>,
}

impl StreamChunkIterator {
Expand All @@ -49,6 +52,7 @@ impl StreamChunkIterator {
})
.collect_vec()
.into_iter(),
class_cache: Default::default(),
}
}

Expand Down

0 comments on commit c86c319

Please sign in to comment.