From 051938c0b12e8e1ec9bd8f401202ac4adca0b398 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Wed, 24 Jan 2024 08:36:36 -0800 Subject: [PATCH 1/2] feat: Upgrade to `jni-rs` 0.21 This PR upgrades `jni-rs` from 0.19 to 0.21. There are many API changes between the two versions and hence many places need to be updated as well in the repo. Most notably, many method calls in `jni-rs` now requires a `&mut JNIEnv` as parameter, while it used to be `&JNIEnv`. --- core/Cargo.lock | 104 ++++++++- core/Cargo.toml | 4 +- core/src/errors.rs | 169 +++++++------- .../datafusion/expressions/subquery.rs | 204 ++++++++-------- core/src/execution/jni_api.rs | 219 +++++++++--------- core/src/execution/metrics/utils.rs | 32 +-- core/src/jvm_bridge/comet_exec.rs | 102 ++++---- core/src/jvm_bridge/comet_metric_node.rs | 22 +- core/src/jvm_bridge/mod.rs | 24 +- core/src/lib.rs | 8 +- core/src/parquet/mod.rs | 166 +++++++------ core/src/parquet/util/jni.rs | 28 +-- core/src/parquet/util/mod.rs | 2 - 13 files changed, 601 insertions(+), 483 deletions(-) diff --git a/core/Cargo.lock b/core/Cargo.lock index 0585d7ec7..9c40b9153 100644 --- a/core/Cargo.lock +++ b/core/Cargo.lock @@ -1087,7 +1087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1336,7 +1336,7 @@ version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1436,7 +1436,7 @@ checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" dependencies = [ "hermit-abi", "rustix", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -1472,18 +1472,32 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" +[[package]] +name = "java-locator" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90003f2fd9c52f212c21d8520f1128da0080bad6fff16b68fe6e7f2f0c3780c2" +dependencies = [ + "glob", + "lazy_static", +] + [[package]] name = "jni" -version = "0.19.0" +version = "0.21.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6df18c2e3db7e453d3c6ac5b3e9d5182664d28788126d39b91f2d1e22b017ec" +checksum = "1a87aa2bb7d2af34197c04845522473242e1aa17c12f4935d5856491a7fb8c97" dependencies = [ "cesu8", + "cfg-if", "combine", + "java-locator", "jni-sys", + "libloading", "log", "thiserror", "walkdir", + "windows-sys 0.45.0", ] [[package]] @@ -1586,6 +1600,16 @@ version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" +[[package]] +name = "libloading" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +dependencies = [ + "cfg-if", + "winapi", +] + [[package]] name = "libm" version = "0.2.8" @@ -2319,7 +2343,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -2602,7 +2626,7 @@ dependencies = [ "fastrand", "redox_syscall", "rustix", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -3009,6 +3033,15 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -3018,6 +3051,21 @@ dependencies = [ "windows-targets 0.52.0", ] +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3048,6 +3096,12 @@ dependencies = [ "windows_x86_64_msvc 0.52.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3060,6 +3114,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -3072,6 +3132,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -3084,6 +3150,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -3096,6 +3168,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -3108,6 +3186,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -3120,6 +3204,12 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" diff --git a/core/Cargo.toml b/core/Cargo.toml index d27b83366..b4df34d0c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -48,7 +48,7 @@ serde = { version = "1", features = ["derive"] } lazy_static = "1.4.0" prost = "0.12.1" thrift = "0.17" -jni = "0.19" +jni = "0.21" byteorder = "1.4.3" snap = "1.1" brotli = "3.3" @@ -81,7 +81,7 @@ prost-build = "0.9.0" [dev-dependencies] pprof = { version = "0.13.0", features = ["flamegraph"] } criterion = "0.5.1" -jni = { version = "0.19", features = ["invocation"] } +jni = { version = "0.21", features = ["invocation"] } lazy_static = "1.4" assertables = "7" diff --git a/core/src/errors.rs b/core/src/errors.rs index a5f52d377..7188ebd1d 100644 --- a/core/src/errors.rs +++ b/core/src/errors.rs @@ -298,7 +298,7 @@ impl JNIDefault for () { // `RuntimeException` back to the calling Java. Since a return result is required, use `JNIDefault` // to create a reasonable result. This returned default value will be ignored due to the exception. pub fn unwrap_or_throw_default( - env: &JNIEnv, + env: &mut JNIEnv, result: std::result::Result, ) -> T { match result { @@ -314,7 +314,7 @@ pub fn unwrap_or_throw_default( } } -fn throw_exception(env: &JNIEnv, error: &E, backtrace: Option) { +fn throw_exception(env: &mut JNIEnv, error: &E, backtrace: Option) { // If there isn't already an exception? if env.exception_check().is_ok() { // ... then throw new exception @@ -380,37 +380,46 @@ fn flatten(result: Result, E>) -> Result { result.and_then(convert::identity) } -// It is currently undefined behavior to unwind from Rust code into foreign code, so we can wrap -// our JNI functions and turn these panics into a `RuntimeException`. -pub fn try_or_throw(env: JNIEnv, f: F) -> T +// Implements "currying" from `FnOnce(T) -> R` to `FnOnce() -> R`, given +// an instance of T. Curring is not supported in Rust so we have to use this +// custom function to achieve something similar here. +fn curry<'a, T: 'a, F, R>(f: F, t: T) -> impl FnOnce() -> R + 'a where - T: JNIDefault, - F: FnOnce() -> T + UnwindSafe, + F: FnOnce(T) -> R + 'a, { - unwrap_or_throw_default(&env, catch_unwind(f).map_err(CometError::from)) + || f(t) } // This is a duplicate of `try_unwrap_or_throw`, which is used to work around Arrow's lack of // `UnwindSafe` handling. -pub fn try_assert_unwind_safe_or_throw(env: JNIEnv, f: F) -> T +pub fn try_assert_unwind_safe_or_throw(env: &JNIEnv, f: F) -> T where T: JNIDefault, - F: FnOnce() -> Result, + F: FnOnce(JNIEnv) -> Result, { + let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; unwrap_or_throw_default( - &env, - flatten(catch_unwind(std::panic::AssertUnwindSafe(f)).map_err(CometError::from)), + &mut env1, + flatten( + catch_unwind(std::panic::AssertUnwindSafe(curry(f, env2))).map_err(CometError::from), + ), ) } // It is currently undefined behavior to unwind from Rust code into foreign code, so we can wrap // our JNI functions and turn these panics into a `RuntimeException`. -pub fn try_unwrap_or_throw(env: JNIEnv, f: F) -> T +pub fn try_unwrap_or_throw(env: &JNIEnv, f: F) -> T where T: JNIDefault, - F: FnOnce() -> Result + UnwindSafe, + F: FnOnce(JNIEnv) -> Result + UnwindSafe, { - unwrap_or_throw_default(&env, flatten(catch_unwind(f).map_err(CometError::from))) + let mut env1 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + let env2 = unsafe { JNIEnv::from_raw(env.get_raw()).unwrap() }; + unwrap_or_throw_default( + &mut env1, + flatten(catch_unwind(curry(f, env2)).map_err(CometError::from)), + ) } #[cfg(test)] @@ -425,7 +434,7 @@ mod tests { }; use jni::{ - objects::{JClass, JObject, JString, JThrowable}, + objects::{JClass, JIntArray, JString, JThrowable}, sys::{jintArray, jstring}, AttachGuard, InitArgsBuilder, JNIEnv, JNIVersion, JavaVM, }; @@ -482,14 +491,14 @@ mod tests { #[test] pub fn error_from_panic() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); - try_or_throw(env, || { + try_unwrap_or_throw(&env, |_| -> CometResult<()> { panic!("oops!"); }); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("oops!"), ); @@ -500,38 +509,16 @@ mod tests { #[test] pub fn object_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); let clazz = env.find_class("java/lang/Object").unwrap(); let input = env.new_string("World".to_string()).unwrap(); - let actual = Java_Errors_hello(env, clazz, input); - - let actual_string = String::from(env.get_string(actual.into()).unwrap().to_str().unwrap()); - assert_eq!("Hello, World!", actual_string); - } - - // Verify that functions that return an object can handle throwing exceptions. The test - // causes an exception by passing a `null` where a string value is expected. - #[test] - pub fn object_panic_exception() { - let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); - // Class java.lang.object is just a stand-in - let class = env.find_class("java/lang/Object").unwrap(); - let input = JString::from(JObject::null()); - let _actual = Java_Errors_hello(env, class, input); - - assert!(env.exception_check().unwrap()); - let exception = env.exception_occurred().expect("Unable to get exception"); - env.exception_clear().unwrap(); + let actual = Java_Errors_hello(&env, clazz, input); + let actual_s = unsafe { JString::from_raw(actual) }; - assert_exception_message_with_stacktrace( - &env, - exception, - "Couldn't get java string!: NullPtr(\"get_string obj argument\")", - "at Java_Errors_hello(", - ); + let actual_string = String::from(env.get_string(&actual_s).unwrap().to_str().unwrap()); + assert_eq!("Hello, World!", actual_string); } // Verify that functions that return an native time are handled correctly. This is basically @@ -539,13 +526,13 @@ mod tests { #[test] pub fn jlong_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: jlong = 6; let b: jlong = 3; - let actual = Java_Errors_div(env, class, a, b); + let actual = Java_Errors_div(&env, class, a, b); assert_eq!(2, actual); } @@ -555,16 +542,16 @@ mod tests { #[test] pub fn jlong_panic_exception() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: jlong = 6; let b: jlong = 0; - let _actual = Java_Errors_div(env, class, a, b); + let _actual = Java_Errors_div(&env, class, a, b); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("attempt to divide by zero"), ); @@ -575,13 +562,13 @@ mod tests { #[test] pub fn jlong_result_ok() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: JString = env.new_string("9".to_string()).unwrap(); let b: JString = env.new_string("3".to_string()).unwrap(); - let actual = Java_Errors_div_with_parse(env, class, a, b); + let actual = Java_Errors_div_with_parse(&env, class, a, b); assert_eq!(3, actual); } @@ -591,16 +578,16 @@ mod tests { #[test] pub fn jlong_result_err() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let a: JString = env.new_string("NaN".to_string()).unwrap(); let b: JString = env.new_string("3".to_string()).unwrap(); - let _actual = Java_Errors_div_with_parse(env, class, a, b); + let _actual = Java_Errors_div_with_parse(&env, class, a, b); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/NumberFormatException"), Some("invalid digit found in string"), ); @@ -611,17 +598,18 @@ mod tests { #[test] pub fn jint_array_result() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let buf = [2, 4, 6]; let input = env.new_int_array(3).unwrap(); - env.set_int_array_region(input, 0, &buf).unwrap(); - let actual = Java_Errors_array_div(env, class, input, 2); + env.set_int_array_region(&input, 0, &buf).unwrap(); + let actual = Java_Errors_array_div(&env, class, &input, 2); + let actual_s = unsafe { JIntArray::from_raw(actual) }; let mut buf: [i32; 3] = [0; 3]; - env.get_int_array_region(actual, 0, &mut buf).unwrap(); + env.get_int_array_region(&actual_s, 0, &mut buf).unwrap(); assert_eq!([1, 2, 3], buf); } @@ -630,17 +618,17 @@ mod tests { #[test] pub fn jint_array_panic_exception() { let _guard = attach_current_thread(); - let env = jvm().get_env().unwrap(); + let mut env = jvm().get_env().unwrap(); // Class java.lang.object is just a stand-in let class = env.find_class("java/lang/Object").unwrap(); let buf = [2, 4, 6]; let input = env.new_int_array(3).unwrap(); - env.set_int_array_region(input, 0, &buf).unwrap(); - let _actual = Java_Errors_array_div(env, class, input, 0); + env.set_int_array_region(&input, 0, &buf).unwrap(); + let _actual = Java_Errors_array_div(&env, class, &input, 0); assert_pending_java_exception_detailed( - &env, + &mut env, Some("java/lang/RuntimeException"), Some("attempt to divide by zero"), ); @@ -683,13 +671,13 @@ mod tests { // * throwing an exception from `.expect()` #[no_mangle] pub extern "system" fn Java_Errors_hello( - env: JNIEnv, + e: &JNIEnv, _class: JClass, input: JString, ) -> jstring { - try_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let input: String = env - .get_string(input) + .get_string(&input) .expect("Couldn't get java string!") .into(); @@ -697,7 +685,7 @@ mod tests { .new_string(format!("Hello, {}!", input)) .expect("Couldn't create java string!"); - output.into_inner() + Ok(output.into_raw()) }) } @@ -706,24 +694,24 @@ mod tests { // * throwing an exception when dividing by zero #[no_mangle] pub extern "system" fn Java_Errors_div( - env: JNIEnv, + env: &JNIEnv, _class: JClass, a: jlong, b: jlong, ) -> jlong { - try_or_throw(env, || a / b) + try_unwrap_or_throw(env, |_| Ok(a / b)) } #[no_mangle] pub extern "system" fn Java_Errors_div_with_parse( - env: JNIEnv, + e: &JNIEnv, _class: JClass, a: JString, b: JString, ) -> jlong { - try_unwrap_or_throw(env, || { - let a_value: i64 = env.get_string(a)?.to_str()?.parse()?; - let b_value: i64 = env.get_string(b)?.to_str()?.parse()?; + try_unwrap_or_throw(e, |mut env| { + let a_value: i64 = env.get_string(&a)?.to_str()?.parse()?; + let b_value: i64 = env.get_string(&b)?.to_str()?.parse()?; Ok(a_value / b_value) }) } @@ -733,27 +721,27 @@ mod tests { // * throwing an exception when dividing by zero #[no_mangle] pub extern "system" fn Java_Errors_array_div( - env: JNIEnv, + e: &JNIEnv, _class: JClass, - input: jintArray, + input: &JIntArray, divisor: jint, ) -> jintArray { - try_or_throw(env, || { + try_unwrap_or_throw(e, |env| { let mut input_buf: [jint; 3] = [0; 3]; - env.get_int_array_region(input, 0, &mut input_buf).unwrap(); + env.get_int_array_region(input, 0, &mut input_buf)?; let buf = input_buf.map(|v| -> jint { v / divisor }); - let result = env.new_int_array(3).unwrap(); - env.set_int_array_region(result, 0, &buf).unwrap(); - result + let result = env.new_int_array(3)?; + env.set_int_array_region(&result, 0, &buf)?; + Ok(result.into_raw()) }) } // Helper method that asserts there is a pending Java exception which is an `instance_of` // `expected_type` with a message matching `expected_message` and clears it if any. fn assert_pending_java_exception_detailed( - env: &JNIEnv, + env: &mut JNIEnv, expected_type: Option<&str>, expected_message: Option<&str>, ) { @@ -762,7 +750,7 @@ mod tests { env.exception_clear().unwrap(); if let Some(expected_type) = expected_type { - assert_exception_type(env, exception, expected_type); + assert_exception_type(env, &exception, expected_type); } if let Some(expected_message) = expected_message { @@ -771,7 +759,7 @@ mod tests { } // Asserts that exception is an `instance_of` `expected_type` type. - fn assert_exception_type(env: &JNIEnv, exception: JThrowable, expected_type: &str) { + fn assert_exception_type(env: &mut JNIEnv, exception: &JThrowable, expected_type: &str) { if !env.is_instance_of(exception, expected_type).unwrap() { let class: JClass = env.get_object_class(exception).unwrap(); let name = env @@ -779,19 +767,21 @@ mod tests { .unwrap() .l() .unwrap(); - let class_name: String = env.get_string(name.into()).unwrap().into(); + let name_string = name.into(); + let class_name: String = env.get_string(&name_string).unwrap().into(); assert_eq!(class_name.replace('.', "/"), expected_type); }; } // Asserts that exception's message matches `expected_message`. - fn assert_exception_message(env: &JNIEnv, exception: JThrowable, expected_message: &str) { + fn assert_exception_message(env: &mut JNIEnv, exception: JThrowable, expected_message: &str) { let message = env .call_method(exception, "getMessage", "()Ljava/lang/String;", &[]) .unwrap() .l() .unwrap(); - let msg_rust: String = env.get_string(message.into()).unwrap().into(); + let message_string = message.into(); + let msg_rust: String = env.get_string(&message_string).unwrap().into(); println!("{}", msg_rust); // Since panics result in multi-line messages which include the backtrace, just use the // first line. @@ -800,7 +790,7 @@ mod tests { // Asserts that exception's message matches `expected_message`. fn assert_exception_message_with_stacktrace( - env: &JNIEnv, + env: &mut JNIEnv, exception: JThrowable, expected_message: &str, stacktrace_contains: &str, @@ -810,7 +800,8 @@ mod tests { .unwrap() .l() .unwrap(); - let msg_rust: String = env.get_string(message.into()).unwrap().into(); + let message_string = message.into(); + let msg_rust: String = env.get_string(&message_string).unwrap().into(); // Since panics result in multi-line messages which include the backtrace, just use the // first line. assert_starts_with!(msg_rust, expected_message); diff --git a/core/src/execution/datafusion/expressions/subquery.rs b/core/src/execution/datafusion/expressions/subquery.rs index a82fb357c..a4b32ba16 100644 --- a/core/src/execution/datafusion/expressions/subquery.rs +++ b/core/src/execution/datafusion/expressions/subquery.rs @@ -20,7 +20,10 @@ use arrow_schema::{DataType, Schema, TimeUnit}; use datafusion::logical_expr::ColumnarValue; use datafusion_common::{internal_err, DataFusionError, ScalarValue}; use datafusion_physical_expr::PhysicalExpr; -use jni::sys::{jboolean, jbyte, jint, jlong, jshort}; +use jni::{ + objects::JByteArray, + sys::{jboolean, jbyte, jint, jlong, jshort}, +}; use std::{ any::Any, fmt::{Display, Formatter}, @@ -87,109 +90,112 @@ impl PhysicalExpr for Subquery { } fn evaluate(&self, _: &RecordBatch) -> datafusion_common::Result { - let env = JVMClasses::get_env(); - - let is_null = - jni_static_call!(env, comet_exec.is_null(self.exec_context_id, self.id) -> jboolean)?; + let mut env = JVMClasses::get_env(); - if is_null > 0 { - return Ok(ColumnarValue::Scalar(ScalarValue::try_from( - &self.data_type, - )?)); - } + unsafe { + let is_null = jni_static_call!(env, + comet_exec.is_null(self.exec_context_id, self.id) -> jboolean + )?; - match &self.data_type { - DataType::Boolean => { - let r = jni_static_call!(env, - comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) - } - DataType::Int8 => { - let r = jni_static_call!(env, - comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) - } - DataType::Int16 => { - let r = jni_static_call!(env, - comet_exec.get_short(self.exec_context_id, self.id) -> jshort - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) + if is_null > 0 { + return Ok(ColumnarValue::Scalar(ScalarValue::try_from( + &self.data_type, + )?)); } - DataType::Int32 => { - let r = jni_static_call!(env, - comet_exec.get_int(self.exec_context_id, self.id) -> jint - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) - } - DataType::Int64 => { - let r = jni_static_call!(env, - comet_exec.get_long(self.exec_context_id, self.id) -> jlong - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) - } - DataType::Float32 => { - let r = jni_static_call!(env, - comet_exec.get_float(self.exec_context_id, self.id) -> f32 - )?; - Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) - } - DataType::Float64 => { - let r = jni_static_call!(env, - comet_exec.get_double(self.exec_context_id, self.id) -> f64 - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) - } - DataType::Decimal128(p, s) => { - let bytes = jni_static_call!(env, - comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper - )?; - - let slice = env.convert_byte_array((*bytes.get()).into_inner()).unwrap(); - - Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( - Some(bytes_to_i128(&slice)), - *p, - *s, - ))) - } - DataType::Date32 => { - let r = jni_static_call!(env, - comet_exec.get_int(self.exec_context_id, self.id) -> jint - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) - } - DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - let r = jni_static_call!(env, - comet_exec.get_long(self.exec_context_id, self.id) -> jlong - )?; - - Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( - Some(r), - timezone.clone(), - ))) - } - DataType::Utf8 => { - let string = jni_static_call!(env, - comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper - )?; - - let string = env.get_string(*string.get()).unwrap().into(); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) - } - DataType::Binary => { - let bytes = jni_static_call!(env, - comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper - )?; - - let slice = env.convert_byte_array((*bytes.get()).into_inner()).unwrap(); - Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice)))) + match &self.data_type { + DataType::Boolean => { + let r = jni_static_call!(env, + comet_exec.get_bool(self.exec_context_id, self.id) -> jboolean + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some(r > 0)))) + } + DataType::Int8 => { + let r = jni_static_call!(env, + comet_exec.get_byte(self.exec_context_id, self.id) -> jbyte + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int8(Some(r)))) + } + DataType::Int16 => { + let r = jni_static_call!(env, + comet_exec.get_short(self.exec_context_id, self.id) -> jshort + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int16(Some(r)))) + } + DataType::Int32 => { + let r = jni_static_call!(env, + comet_exec.get_int(self.exec_context_id, self.id) -> jint + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(r)))) + } + DataType::Int64 => { + let r = jni_static_call!(env, + comet_exec.get_long(self.exec_context_id, self.id) -> jlong + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some(r)))) + } + DataType::Float32 => { + let r = jni_static_call!(env, + comet_exec.get_float(self.exec_context_id, self.id) -> f32 + )?; + Ok(ColumnarValue::Scalar(ScalarValue::Float32(Some(r)))) + } + DataType::Float64 => { + let r = jni_static_call!(env, + comet_exec.get_double(self.exec_context_id, self.id) -> f64 + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(r)))) + } + DataType::Decimal128(p, s) => { + let bytes = jni_static_call!(env, + comet_exec.get_decimal(self.exec_context_id, self.id) -> BinaryWrapper + )?; + let bytes: &JByteArray = bytes.get().into(); + let slice = env.convert_byte_array(bytes).unwrap(); + + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + Some(bytes_to_i128(&slice)), + *p, + *s, + ))) + } + DataType::Date32 => { + let r = jni_static_call!(env, + comet_exec.get_int(self.exec_context_id, self.id) -> jint + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::Date32(Some(r)))) + } + DataType::Timestamp(TimeUnit::Microsecond, timezone) => { + let r = jni_static_call!(env, + comet_exec.get_long(self.exec_context_id, self.id) -> jlong + )?; + + Ok(ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond( + Some(r), + timezone.clone(), + ))) + } + DataType::Utf8 => { + let string = jni_static_call!(env, + comet_exec.get_string(self.exec_context_id, self.id) -> StringWrapper + )?; + + let string = env.get_string(string.get()).unwrap().into(); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(string)))) + } + DataType::Binary => { + let bytes = jni_static_call!(env, + comet_exec.get_binary(self.exec_context_id, self.id) -> BinaryWrapper + )?; + let bytes: &JByteArray = bytes.get().into(); + let slice = env.convert_byte_array(bytes).unwrap(); + + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(slice)))) + } + _ => internal_err!("Unsupported scalar subquery data type {:?}", self.data_type), } - _ => internal_err!("Unsupported scalar subquery data type {:?}", self.data_type), } } diff --git a/core/src/execution/jni_api.rs b/core/src/execution/jni_api.rs index 9981cece3..831f78838 100644 --- a/core/src/execution/jni_api.rs +++ b/core/src/execution/jni_api.rs @@ -36,7 +36,10 @@ use datafusion_common::DataFusionError; use futures::poll; use jni::{ errors::Result as JNIResult, - objects::{JClass, JMap, JObject, JString, ReleaseMode}, + objects::{ + AutoElements, JBooleanArray, JByteArray, JClass, JIntArray, JLongArray, JMap, JObject, + JObjectArray, JPrimitiveArray, JString, ReleaseMode, + }, sys::{jbyteArray, jint, jlong, jlongArray}, JNIEnv, }; @@ -45,7 +48,7 @@ use std::{collections::HashMap, sync::Arc, task::Poll}; use super::{serde, utils::SparkArrowConvert}; use crate::{ - errors::{try_unwrap_or_throw, CometError}, + errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ datafusion::planner::PhysicalPlanner, metrics::utils::update_comet_metric, serde::to_arrow_datatype, shuffle::row::process_sorted_row_partition, sort::RdxSort, @@ -55,7 +58,7 @@ use crate::{ }; use futures::stream::StreamExt; use jni::{ - objects::{AutoArray, GlobalRef}, + objects::GlobalRef, sys::{jboolean, jbooleanArray, jdouble, jintArray, jobjectArray, jstring}, }; use tokio::runtime::Runtime; @@ -88,21 +91,24 @@ struct ExecutionContext { pub debug_native: bool, } -#[no_mangle] /// Accept serialized query plan and return the address of the native query plan. -pub extern "system" fn Java_org_apache_comet_Native_createPlan( - env: JNIEnv, +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +#[no_mangle] +pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( + e: JNIEnv, _class: JClass, id: jlong, config_object: JObject, serialized_query: jbyteArray, metrics_node: JObject, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { // Init JVM classes - JVMClasses::init(&env); + JVMClasses::init(&mut env); - let bytes = env.convert_byte_array(serialized_query)?; + let array = unsafe { JPrimitiveArray::from_raw(serialized_query) }; + let bytes = env.convert_byte_array(array)?; // Deserialize query plan let spark_plan = serde::deserialize_op(bytes.as_slice())?; @@ -110,13 +116,13 @@ pub extern "system" fn Java_org_apache_comet_Native_createPlan( // Sets up context let mut configs = HashMap::new(); - let config_map = JMap::from_env(&env, config_object)?; - config_map.iter()?.for_each(|config| { - let key: String = env.get_string(JString::from(config.0)).unwrap().into(); - let value: String = env.get_string(JString::from(config.1)).unwrap().into(); - + let config_map = JMap::from_env(&mut env, &config_object)?; + let mut map_iter = config_map.iter(&mut env)?; + while let Some((key, value)) = map_iter.next(&mut env)? { + let key: String = env.get_string(&JString::from(key)).unwrap().into(); + let value: String = env.get_string(&JString::from(value)).unwrap().into(); configs.insert(key, value); - }); + } // Whether we've enabled additional debugging on the native side let debug_native = configs @@ -157,8 +163,8 @@ pub extern "system" fn Java_org_apache_comet_Native_createPlan( /// Parse Comet configs and configure DataFusion session context. fn prepare_datafusion_session_context( conf: &HashMap, -) -> Result { - // Get the batch size from Comet JVM side +) -> CometResult { + // Get the batch size from Boson JVM side let batch_size = conf .get("batch_size") .ok_or(CometError::Internal( @@ -205,10 +211,10 @@ fn prepare_datafusion_session_context( /// Prepares arrow arrays for output. fn prepare_output( + env: &mut JNIEnv, output: Result, - env: JNIEnv, exec_context: &mut ExecutionContext, -) -> Result { +) -> CometResult { let output_batch = output?; let results = output_batch.columns(); let num_rows = output_batch.num_rows(); @@ -226,7 +232,7 @@ fn prepare_output( let return_flag = 1; let long_array = env.new_long_array((results.len() * 2) as i32 + 2)?; - env.set_long_array_region(long_array, 0, &[return_flag, num_rows as jlong])?; + env.set_long_array_region(&long_array, 0, &[return_flag, num_rows as jlong])?; let mut arrays = vec![]; @@ -241,48 +247,61 @@ fn prepare_output( arrays.push((arrow_array, arrow_schema)); } - env.set_long_array_region(long_array, (i * 2) as i32 + 2, &[array, schema])?; + env.set_long_array_region(&long_array, (i * 2) as i32 + 2, &[array, schema])?; i += 1; } // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(env, exec_context)?; // Record the pointer to allocated Arrow Arrays exec_context.ffi_arrays = arrays; - Ok(long_array) + Ok(long_array.into_raw()) } -#[no_mangle] /// Accept serialized query plan and the addresses of Arrow Arrays from Spark, /// then execute the query. Return addresses of arrow vector. -pub extern "system" fn Java_org_apache_comet_Native_executePlan( - env: JNIEnv, +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. +#[no_mangle] +pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( + e: JNIEnv, _class: JClass, exec_context: jlong, addresses_array: jobjectArray, finishes: jbooleanArray, batch_rows: jint, ) -> jlongArray { - try_unwrap_or_throw(env, || { - let addresses_vec = convert_addresses_arrays(&env, addresses_array)?; - let mut all_inputs: Vec> = Vec::with_capacity(addresses_vec.len()); - + try_unwrap_or_throw(&e, |mut env| unsafe { let exec_context = get_execution_context(exec_context); - for addresses in addresses_vec.iter() { + + let addresses = JObjectArray::from_raw(addresses_array); + let num_addresses = env.get_array_length(&addresses)? as usize; + + let mut all_inputs: Vec> = Vec::with_capacity(num_addresses); + + for i in 0..num_addresses { let mut inputs: Vec = vec![]; - let array_num = addresses.size()? as usize; - assert_eq!(array_num % 2, 0, "Arrow Array addresses are invalid!"); + let inner_addresses = env.get_object_array_element(&addresses, i as i32)?.into(); + let inner_address_array: AutoElements = + env.get_array_elements(&inner_addresses, ReleaseMode::NoCopyBack)?; - let num_arrays = array_num / 2; - let array_elements = addresses.as_ptr(); + let num_inner_address = inner_address_array.len(); + assert_eq!( + num_inner_address % 2, + 0, + "Arrow Array addresses are invalid!" + ); + + let num_arrays = num_inner_address / 2; + let array_elements = inner_address_array.as_ptr(); let mut i: usize = 0; while i < num_arrays { - let array_ptr = unsafe { *(array_elements.add(i * 2)) }; - let schema_ptr = unsafe { *(array_elements.add(i * 2 + 1)) }; + let array_ptr = *(array_elements.add(i * 2)); + let schema_ptr = *(array_elements.add(i * 2 + 1)); let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; if exec_context.debug_native { @@ -298,7 +317,8 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( } // Prepares the input batches. - let eofs = env.get_boolean_array_elements(finishes, ReleaseMode::NoCopyBack)?; + let array = JBooleanArray::from_raw(finishes); + let eofs = env.get_array_elements(&array, ReleaseMode::NoCopyBack)?; let eof_flags = eofs.as_ptr(); // Whether reaching the end of input batches. @@ -306,7 +326,7 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( let mut input_batches = all_inputs .into_iter() .enumerate() - .map(|(idx, inputs)| unsafe { + .map(|(idx, inputs)| { let eof = eof_flags.add(idx); if *eof == 1 { @@ -364,25 +384,25 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( match poll_output { Poll::Ready(Some(output)) => { - return prepare_output(output, env, exec_context); + return prepare_output(&mut env, output, exec_context); } Poll::Ready(None) => { // Reaches EOF of output. // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; let long_array = env.new_long_array(1)?; - env.set_long_array_region(long_array, 0, &[-1])?; + env.set_long_array_region(&long_array, 0, &[-1])?; - return Ok(long_array); + return Ok(long_array.into_raw()); } - // After reaching the end of any input, a poll pending means there are more than one - // blocking operators, we don't need go back-forth between JVM/Native. Just - // keeping polling. + // After reaching the end of any input, a poll pending means there are more than + // one blocking operators, we don't need go back-forth + // between JVM/Native. Just keeping polling. Poll::Pending if finished => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; // Output not ready yet continue; @@ -391,7 +411,7 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( // operators. Just returning to keep reading next input. Poll::Pending => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; return return_pending(env); } } @@ -401,19 +421,18 @@ pub extern "system" fn Java_org_apache_comet_Native_executePlan( fn return_pending(env: JNIEnv) -> Result { let long_array = env.new_long_array(1)?; - env.set_long_array_region(long_array, 0, &[0])?; - - Ok(long_array) + env.set_long_array_region(&long_array, 0, &[0])?; + Ok(long_array.into_raw()) } #[no_mangle] /// Peeks into next output if any. pub extern "system" fn Java_org_apache_comet_Native_peekNext( - env: JNIEnv, + e: JNIEnv, _class: JClass, exec_context: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { // Retrieve the query let exec_context = get_execution_context(exec_context); @@ -427,10 +446,10 @@ pub extern "system" fn Java_org_apache_comet_Native_peekNext( let poll_output = exec_context.runtime.block_on(async { poll!(next_item) }); match poll_output { - Poll::Ready(Some(output)) => prepare_output(output, env, exec_context), + Poll::Ready(Some(output)) => prepare_output(&mut env, output, exec_context), _ => { // Update metrics - update_metrics(&env, exec_context)?; + update_metrics(&mut env, exec_context)?; return_pending(env) } } @@ -440,11 +459,11 @@ pub extern "system" fn Java_org_apache_comet_Native_peekNext( #[no_mangle] /// Drop the native query plan object and context object. pub extern "system" fn Java_org_apache_comet_Native_releasePlan( - env: JNIEnv, + e: JNIEnv, _class: JClass, exec_context: jlong, ) { - try_unwrap_or_throw(env, || unsafe { + try_unwrap_or_throw(&e, |_| unsafe { let execution_context = get_execution_context(exec_context); let _: Box = Box::from_raw(execution_context); Ok(()) @@ -452,51 +471,32 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( } /// Updates the metrics of the query plan. -fn update_metrics(env: &JNIEnv, exec_context: &ExecutionContext) -> Result<(), CometError> { +fn update_metrics(env: &mut JNIEnv, exec_context: &ExecutionContext) -> CometResult<()> { let native_query = exec_context.root_op.as_ref().unwrap(); let metrics = exec_context.metrics.as_obj(); update_comet_metric(env, metrics, native_query) } -/// Converts a Java array of address arrays to a Rust vector of address arrays. -fn convert_addresses_arrays<'a>( - env: &'a JNIEnv<'a>, - addresses_array: jobjectArray, -) -> JNIResult>> { - let array_len = env.get_array_length(addresses_array)?; - let mut res: Vec> = Vec::new(); - - for i in 0..array_len { - let array: AutoArray = env.get_array_elements( - env.get_object_array_element(addresses_array, i)? - .into_inner() as jlongArray, - ReleaseMode::NoCopyBack, - )?; - res.push(array); - } - - Ok(res) -} - fn convert_datatype_arrays( - env: &'_ JNIEnv<'_>, + env: &'_ mut JNIEnv<'_>, serialized_datatypes: jobjectArray, ) -> JNIResult> { - let array_len = env.get_array_length(serialized_datatypes)?; - let mut res: Vec = Vec::new(); - - for i in 0..array_len { - let array = env - .get_object_array_element(serialized_datatypes, i)? - .into_inner() as jbyteArray; + unsafe { + let obj_array = JObjectArray::from_raw(serialized_datatypes); + let array_len = env.get_array_length(&obj_array)?; + let mut res: Vec = Vec::new(); + + for i in 0..array_len { + let inner_array = env.get_object_array_element(&obj_array, i)?; + let inner_array: JByteArray = inner_array.into(); + let bytes = env.convert_byte_array(inner_array)?; + let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap(); + let arrow_dt = to_arrow_datatype(&data_type); + res.push(arrow_dt); + } - let bytes = env.convert_byte_array(array)?; - let data_type = serde::deserialize_data_type(bytes.as_slice()).unwrap(); - let arrow_dt = to_arrow_datatype(&data_type); - res.push(arrow_dt); + Ok(res) } - - Ok(res) } fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { @@ -507,10 +507,12 @@ fn get_execution_context<'a>(id: i64) -> &'a mut ExecutionContext { } } +/// Used by Boson shuffle external sorter to write sorted records to disk. +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -/// Used by Comet shuffle external sorter to write sorted records to disk. -pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( + e: JNIEnv, _class: JClass, row_addresses: jlongArray, row_sizes: jintArray, @@ -521,18 +523,23 @@ pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( checksum_algo: jint, current_checksum: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { - let row_num = env.get_array_length(row_addresses)? as usize; + try_unwrap_or_throw(&e, |mut env| unsafe { + let data_types = convert_datatype_arrays(&mut env, serialized_datatypes)?; - let data_types = convert_datatype_arrays(&env, serialized_datatypes)?; + let row_address_array = JLongArray::from_raw(row_addresses); + let row_num = env.get_array_length(&row_address_array)? as usize; + let row_addresses = env.get_array_elements(&row_address_array, ReleaseMode::NoCopyBack)?; - let row_addresses = env.get_long_array_elements(row_addresses, ReleaseMode::NoCopyBack)?; - let row_sizes = env.get_int_array_elements(row_sizes, ReleaseMode::NoCopyBack)?; + let row_size_array = JIntArray::from_raw(row_sizes); + let row_sizes = env.get_array_elements(&row_size_array, ReleaseMode::NoCopyBack)?; let row_addresses_ptr = row_addresses.as_ptr(); let row_sizes_ptr = row_sizes.as_ptr(); - let output_path: String = env.get_string(JString::from(file_path)).unwrap().into(); + let output_path: String = env + .get_string(&JString::from_raw(file_path)) + .unwrap() + .into(); let checksum_enabled = checksum_enabled == 1; let current_checksum = if current_checksum == i64::MIN { @@ -563,21 +570,21 @@ pub extern "system" fn Java_org_apache_comet_Native_writeSortedFileNative( }; let long_array = env.new_long_array(2)?; - env.set_long_array_region(long_array, 0, &[written_bytes, checksum])?; + env.set_long_array_region(&long_array, 0, &[written_bytes, checksum])?; - Ok(long_array) + Ok(long_array.into_raw()) }) } #[no_mangle] -/// Used by Comet shuffle external sorter to sort in-memory row partition ids. +/// Used by Boson shuffle external sorter to sort in-memory row partition ids. pub extern "system" fn Java_org_apache_comet_Native_sortRowPartitionsNative( - env: JNIEnv, + e: JNIEnv, _class: JClass, address: jlong, size: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |_| { // SAFETY: JVM unsafe memory allocation is aligned with long. let array = unsafe { std::slice::from_raw_parts_mut(address as *mut i64, size as usize) }; array.rdxsort(); diff --git a/core/src/execution/metrics/utils.rs b/core/src/execution/metrics/utils.rs index eb36a5562..6990aa54f 100644 --- a/core/src/execution/metrics/utils.rs +++ b/core/src/execution/metrics/utils.rs @@ -27,8 +27,8 @@ use std::sync::Arc; /// update the metrics of all the children nodes. The metrics are pulled from the /// DataFusion execution plan and pushed to the Java side through JNI. pub fn update_comet_metric( - env: &JNIEnv, - metric_node: JObject, + env: &mut JNIEnv, + metric_node: &JObject, execution_plan: &Arc, ) -> Result<(), CometError> { update_metrics( @@ -43,27 +43,31 @@ pub fn update_comet_metric( .collect::>(), )?; - for (i, child_plan) in execution_plan.children().iter().enumerate() { - let child_metric_node: JObject = jni_call!(env, - comet_metric_node(metric_node).get_child_node(i as i32) -> JObject - )?; - if child_metric_node.is_null() { - continue; + unsafe { + for (i, child_plan) in execution_plan.children().iter().enumerate() { + let child_metric_node: JObject = jni_call!(env, + comet_metric_node(metric_node).get_child_node(i as i32) -> JObject + )?; + if child_metric_node.is_null() { + continue; + } + update_comet_metric(env, &child_metric_node, child_plan)?; } - update_comet_metric(env, child_metric_node, child_plan)?; } Ok(()) } #[inline] fn update_metrics( - env: &JNIEnv, - metric_node: JObject, + env: &mut JNIEnv, + metric_node: &JObject, metric_values: &[(&str, i64)], ) -> Result<(), CometError> { - for &(name, value) in metric_values { - let jname = jni_new_string!(env, &name)?; - jni_call!(env, comet_metric_node(metric_node).add(jname, value) -> ())?; + unsafe { + for &(name, value) in metric_values { + let jname = jni_new_string!(env, &name)?; + jni_call!(env, comet_metric_node(metric_node).add(&jname, value) -> ())?; + } } Ok(()) } diff --git a/core/src/jvm_bridge/comet_exec.rs b/core/src/jvm_bridge/comet_exec.rs index e28fc080f..6b6652eb4 100644 --- a/core/src/jvm_bridge/comet_exec.rs +++ b/core/src/jvm_bridge/comet_exec.rs @@ -18,7 +18,7 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JStaticMethodID}, - signature::{JavaType, Primitive}, + signature::{Primitive, ReturnType}, JNIEnv, }; @@ -27,75 +27,83 @@ use super::get_global_jclass; /// A struct that holds all the JNI methods and fields for JVM CometExec object. pub struct CometExec<'a> { pub class: JClass<'a>, - pub method_get_bool: JStaticMethodID<'a>, - pub method_get_bool_ret: JavaType, - pub method_get_byte: JStaticMethodID<'a>, - pub method_get_byte_ret: JavaType, - pub method_get_short: JStaticMethodID<'a>, - pub method_get_short_ret: JavaType, - pub method_get_int: JStaticMethodID<'a>, - pub method_get_int_ret: JavaType, - pub method_get_long: JStaticMethodID<'a>, - pub method_get_long_ret: JavaType, - pub method_get_float: JStaticMethodID<'a>, - pub method_get_float_ret: JavaType, - pub method_get_double: JStaticMethodID<'a>, - pub method_get_double_ret: JavaType, - pub method_get_decimal: JStaticMethodID<'a>, - pub method_get_decimal_ret: JavaType, - pub method_get_string: JStaticMethodID<'a>, - pub method_get_string_ret: JavaType, - pub method_get_binary: JStaticMethodID<'a>, - pub method_get_binary_ret: JavaType, - pub method_is_null: JStaticMethodID<'a>, - pub method_is_null_ret: JavaType, + pub method_get_bool: JStaticMethodID, + pub method_get_bool_ret: ReturnType, + pub method_get_byte: JStaticMethodID, + pub method_get_byte_ret: ReturnType, + pub method_get_short: JStaticMethodID, + pub method_get_short_ret: ReturnType, + pub method_get_int: JStaticMethodID, + pub method_get_int_ret: ReturnType, + pub method_get_long: JStaticMethodID, + pub method_get_long_ret: ReturnType, + pub method_get_float: JStaticMethodID, + pub method_get_float_ret: ReturnType, + pub method_get_double: JStaticMethodID, + pub method_get_double_ret: ReturnType, + pub method_get_decimal: JStaticMethodID, + pub method_get_decimal_ret: ReturnType, + pub method_get_string: JStaticMethodID, + pub method_get_string_ret: ReturnType, + pub method_get_binary: JStaticMethodID, + pub method_get_binary_ret: ReturnType, + pub method_is_null: JStaticMethodID, + pub method_is_null_ret: ReturnType, } impl<'a> CometExec<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometScalarSubquery"; - pub fn new(env: &JNIEnv<'a>) -> JniResult> { + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { // Get the global class reference let class = get_global_jclass(env, Self::JVM_CLASS)?; Ok(CometExec { - class, method_get_bool: env - .get_static_method_id(class, "getBoolean", "(JJ)Z") + .get_static_method_id(Self::JVM_CLASS, "getBoolean", "(JJ)Z") + .unwrap(), + method_get_bool_ret: ReturnType::Primitive(Primitive::Boolean), + method_get_byte: env + .get_static_method_id(Self::JVM_CLASS, "getByte", "(JJ)B") .unwrap(), - method_get_bool_ret: JavaType::Primitive(Primitive::Boolean), - method_get_byte: env.get_static_method_id(class, "getByte", "(JJ)B").unwrap(), - method_get_byte_ret: JavaType::Primitive(Primitive::Byte), + method_get_byte_ret: ReturnType::Primitive(Primitive::Byte), method_get_short: env - .get_static_method_id(class, "getShort", "(JJ)S") + .get_static_method_id(Self::JVM_CLASS, "getShort", "(JJ)S") + .unwrap(), + method_get_short_ret: ReturnType::Primitive(Primitive::Short), + method_get_int: env + .get_static_method_id(Self::JVM_CLASS, "getInt", "(JJ)I") .unwrap(), - method_get_short_ret: JavaType::Primitive(Primitive::Short), - method_get_int: env.get_static_method_id(class, "getInt", "(JJ)I").unwrap(), - method_get_int_ret: JavaType::Primitive(Primitive::Int), - method_get_long: env.get_static_method_id(class, "getLong", "(JJ)J").unwrap(), - method_get_long_ret: JavaType::Primitive(Primitive::Long), + method_get_int_ret: ReturnType::Primitive(Primitive::Int), + method_get_long: env + .get_static_method_id(Self::JVM_CLASS, "getLong", "(JJ)J") + .unwrap(), + method_get_long_ret: ReturnType::Primitive(Primitive::Long), method_get_float: env - .get_static_method_id(class, "getFloat", "(JJ)F") + .get_static_method_id(Self::JVM_CLASS, "getFloat", "(JJ)F") .unwrap(), - method_get_float_ret: JavaType::Primitive(Primitive::Float), + method_get_float_ret: ReturnType::Primitive(Primitive::Float), method_get_double: env - .get_static_method_id(class, "getDouble", "(JJ)D") + .get_static_method_id(Self::JVM_CLASS, "getDouble", "(JJ)D") .unwrap(), - method_get_double_ret: JavaType::Primitive(Primitive::Double), + method_get_double_ret: ReturnType::Primitive(Primitive::Double), method_get_decimal: env - .get_static_method_id(class, "getDecimal", "(JJ)[B") + .get_static_method_id(Self::JVM_CLASS, "getDecimal", "(JJ)[B") .unwrap(), - method_get_decimal_ret: JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))), + method_get_decimal_ret: ReturnType::Array, method_get_string: env - .get_static_method_id(class, "getString", "(JJ)Ljava/lang/String;") + .get_static_method_id(Self::JVM_CLASS, "getString", "(JJ)Ljava/lang/String;") .unwrap(), - method_get_string_ret: JavaType::Object("java/lang/String".to_owned()), + method_get_string_ret: ReturnType::Object, method_get_binary: env - .get_static_method_id(class, "getBinary", "(JJ)[B") + .get_static_method_id(Self::JVM_CLASS, "getBinary", "(JJ)[B") + .unwrap(), + method_get_binary_ret: ReturnType::Array, + method_is_null: env + .get_static_method_id(Self::JVM_CLASS, "isNull", "(JJ)Z") .unwrap(), - method_get_binary_ret: JavaType::Array(Box::new(JavaType::Primitive(Primitive::Byte))), - method_is_null: env.get_static_method_id(class, "isNull", "(JJ)Z").unwrap(), - method_is_null_ret: JavaType::Primitive(Primitive::Boolean), + method_is_null_ret: ReturnType::Primitive(Primitive::Boolean), + class, }) } } diff --git a/core/src/jvm_bridge/comet_metric_node.rs b/core/src/jvm_bridge/comet_metric_node.rs index 1d4928a09..d0176f427 100644 --- a/core/src/jvm_bridge/comet_metric_node.rs +++ b/core/src/jvm_bridge/comet_metric_node.rs @@ -18,7 +18,7 @@ use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, - signature::{JavaType, Primitive}, + signature::{Primitive, ReturnType}, JNIEnv, }; @@ -27,33 +27,33 @@ use super::get_global_jclass; /// A struct that holds all the JNI methods and fields for JVM CometMetricNode class. pub struct CometMetricNode<'a> { pub class: JClass<'a>, - pub method_get_child_node: JMethodID<'a>, - pub method_get_child_node_ret: JavaType, - pub method_add: JMethodID<'a>, - pub method_add_ret: JavaType, + pub method_get_child_node: JMethodID, + pub method_get_child_node_ret: ReturnType, + pub method_add: JMethodID, + pub method_add_ret: ReturnType, } impl<'a> CometMetricNode<'a> { pub const JVM_CLASS: &'static str = "org/apache/spark/sql/comet/CometMetricNode"; - pub fn new(env: &JNIEnv<'a>) -> JniResult> { + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { // Get the global class reference let class = get_global_jclass(env, Self::JVM_CLASS)?; Ok(CometMetricNode { - class, method_get_child_node: env .get_method_id( - class, + Self::JVM_CLASS, "getChildNode", format!("(I)L{:};", Self::JVM_CLASS).as_str(), ) .unwrap(), - method_get_child_node_ret: JavaType::Object(Self::JVM_CLASS.to_owned()), + method_get_child_node_ret: ReturnType::Object, method_add: env - .get_method_id(class, "add", "(Ljava/lang/String;J)V") + .get_method_id(Self::JVM_CLASS, "add", "(Ljava/lang/String;J)V") .unwrap(), - method_add_ret: JavaType::Primitive(Primitive::Void), + method_add_ret: ReturnType::Primitive(Primitive::Void), + class, }) } } diff --git a/core/src/jvm_bridge/mod.rs b/core/src/jvm_bridge/mod.rs index 6f162a0ea..331e7768d 100644 --- a/core/src/jvm_bridge/mod.rs +++ b/core/src/jvm_bridge/mod.rs @@ -19,7 +19,7 @@ use jni::{ errors::{Error, Result as JniResult}, - objects::{JClass, JObject, JString, JValue}, + objects::{JClass, JObject, JString, JValueGen, JValueOwned}, AttachGuard, JNIEnv, }; use once_cell::sync::OnceCell; @@ -38,7 +38,7 @@ macro_rules! jni_map_error { /// Macro for converting Rust types to JNI types. macro_rules! jvalues { ($($args:expr,)* $(,)?) => {{ - &[$(jni::objects::JValue::from($args)),*] as &[jni::objects::JValue] + &[$(jni::objects::JValue::from($args).as_jni()),*] as &[jni::sys::jvalue] }} } @@ -75,7 +75,7 @@ macro_rules! jni_static_call { $crate::jvm_bridge::jni_map_error!( $env, $env.call_static_method_unchecked( - paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, + &paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}, paste::paste! {$crate::jvm_bridge::JVMClasses::get().[<$clsname>].[]}.clone(), $crate::jvm_bridge::jvalues!($($args,)*) @@ -114,23 +114,23 @@ impl<'a> BinaryWrapper<'a> { } } -impl<'a> TryFrom> for StringWrapper<'a> { +impl<'a> TryFrom> for StringWrapper<'a> { type Error = Error; - fn try_from(value: JValue<'a>) -> Result, Error> { + fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValue::Object(b) => Ok(StringWrapper::new(JString::from(b))), + JValueGen::Object(b) => Ok(StringWrapper::new(JString::from(b))), _ => Err(Error::WrongJValueType("object", value.type_name())), } } } -impl<'a> TryFrom> for BinaryWrapper<'a> { +impl<'a> TryFrom> for BinaryWrapper<'a> { type Error = Error; - fn try_from(value: JValue<'a>) -> Result, Error> { + fn try_from(value: JValueOwned<'a>) -> Result, Error> { match value { - JValue::Object(b) => Ok(BinaryWrapper::new(b)), + JValueGen::Object(b) => Ok(BinaryWrapper::new(b)), _ => Err(Error::WrongJValueType("object", value.type_name())), } } @@ -151,7 +151,7 @@ pub(crate) use jni_static_call; pub(crate) use jvalues; /// Gets a global reference to a Java class. -pub fn get_global_jclass(env: &JNIEnv<'_>, cls: &str) -> JniResult> { +pub fn get_global_jclass(env: &mut JNIEnv, cls: &str) -> JniResult> { let local_jclass = env.find_class(cls)?; let global = env.new_global_ref::(local_jclass.into())?; @@ -186,11 +186,11 @@ static JVM_CLASSES: OnceCell = OnceCell::new(); impl JVMClasses<'_> { /// Creates a new JVMClasses struct. - pub fn init(env: &JNIEnv) { + pub fn init(env: &mut JNIEnv) { JVM_CLASSES.get_or_init(|| { // A hack to make the `JNIEnv` static. It is not safe but we don't really use the // `JNIEnv` except for creating the global references of the classes. - let env = unsafe { std::mem::transmute::<_, &'static JNIEnv>(env) }; + let env = unsafe { std::mem::transmute::<_, &'static mut JNIEnv>(env) }; JVMClasses { comet_metric_node: CometMetricNode::new(env).unwrap(), diff --git a/core/src/lib.rs b/core/src/lib.rs index c85263f4f..d10478885 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -45,7 +45,7 @@ use once_cell::sync::OnceCell; pub use data_type::*; -use crate::errors::{try_unwrap_or_throw, CometError, CometResult}; +use errors::{try_unwrap_or_throw, CometError, CometResult}; #[macro_use] mod errors; @@ -64,15 +64,15 @@ static JAVA_VM: OnceCell = OnceCell::new(); #[no_mangle] pub extern "system" fn Java_org_apache_comet_NativeBase_init( - env: JNIEnv, + e: JNIEnv, _: JClass, log_conf_path: JString, ) { // Initialize the error handling to capture panic backtraces errors::init(); - try_unwrap_or_throw(env, || { - let path: String = env.get_string(log_conf_path)?.into(); + try_unwrap_or_throw(&e, |mut env| { + let path: String = env.get_string(&log_conf_path)?.into(); // empty path means there is no custom log4rs config file provided, so fallback to use // the default configuration diff --git a/core/src/parquet/mod.rs b/core/src/parquet/mod.rs index b1a7b939c..f4c417790 100644 --- a/core/src/parquet/mod.rs +++ b/core/src/parquet/mod.rs @@ -41,7 +41,7 @@ use jni::{ use crate::execution::utils::SparkArrowConvert; use arrow::buffer::{Buffer, MutableBuffer}; -use jni::objects::ReleaseMode; +use jni::objects::{JBooleanArray, JLongArray, JPrimitiveArray, ReleaseMode}; use read::ColumnReader; use util::jni::{convert_column_descriptor, convert_encoding}; @@ -58,7 +58,7 @@ struct Context { #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, primitive_type: jint, logical_type: jint, @@ -78,9 +78,9 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( use_decimal_128: jboolean, use_legacy_date_timestamp: jboolean, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let desc = convert_column_descriptor( - &env, + &mut env, primitive_type, logical_type, max_dl, @@ -111,66 +111,74 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, page_data: jbyteArray, encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(encoding); // copy the input on-heap buffer to native - let page_len = env.get_array_length(page_data)?; + let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) }; + let page_len = env.get_array_length(&page_data_array)?; let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize); - env.get_byte_array_region(page_data, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&page_data_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_dictionary_page(page_value_count as usize, buffer.into(), encoding); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV1( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, page_data: jbyteArray, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(value_encoding); // copy the input on-heap buffer to native - let page_len = env.get_array_length(page_data)?; + let page_data_array = unsafe { JPrimitiveArray::from_raw(page_data) }; + let page_len = env.get_array_length(&page_data_array)?; let mut buffer = MutableBuffer::from_len_zeroed(page_len as usize); - env.get_byte_array_region(page_data, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&page_data_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_page_v1(page_value_count as usize, buffer.into(), encoding); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, buffer: jobject, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let ctx = get_context(handle)?; let reader = &mut ctx.column_reader; @@ -178,19 +186,20 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( let encoding = convert_encoding(value_encoding); // Get slices from Java DirectByteBuffer - let jbuffer = JByteBuffer::from(buffer); + let jbuffer = unsafe { JByteBuffer::from_raw(buffer) }; // Convert the page to global reference so it won't get GC'd by Java. Also free the last // page if there is any. - ctx.last_data_page = Some(env.new_global_ref(jbuffer)?); + ctx.last_data_page = Some(env.new_global_ref(&jbuffer)?); - let buf_slice = env.get_direct_buffer_address(jbuffer)?; + let buf_slice = env.get_direct_buffer_address(&jbuffer)?; + let buf_capacity = env.get_direct_buffer_capacity(&jbuffer)?; unsafe { - let page_ptr = NonNull::new_unchecked(buf_slice.as_ptr() as *mut u8); + let page_ptr = NonNull::new_unchecked(buf_slice); let buffer = Buffer::from_custom_allocation( page_ptr, - buf_slice.len(), + buf_capacity, Arc::new(FFI_ArrowArray::empty()), ); reader.set_page_v1(page_value_count as usize, buffer, encoding); @@ -199,9 +208,11 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageBufferV1( }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( + e: JNIEnv, _jclass: JClass, handle: jlong, page_value_count: jint, @@ -210,24 +221,27 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPageV2( value_data: jbyteArray, value_encoding: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; // convert value encoding ordinal to the native encoding definition let encoding = convert_encoding(value_encoding); // copy the input on-heap buffer to native - let dl_len = env.get_array_length(def_level_data)?; + let def_level_array = unsafe { JPrimitiveArray::from_raw(def_level_data) }; + let dl_len = env.get_array_length(&def_level_array)?; let mut dl_buffer = MutableBuffer::from_len_zeroed(dl_len as usize); - env.get_byte_array_region(def_level_data, 0, from_u8_slice(dl_buffer.as_slice_mut()))?; + env.get_byte_array_region(&def_level_array, 0, from_u8_slice(dl_buffer.as_slice_mut()))?; - let rl_len = env.get_array_length(rep_level_data)?; + let rep_level_array = unsafe { JPrimitiveArray::from_raw(rep_level_data) }; + let rl_len = env.get_array_length(&rep_level_array)?; let mut rl_buffer = MutableBuffer::from_len_zeroed(rl_len as usize); - env.get_byte_array_region(rep_level_data, 0, from_u8_slice(rl_buffer.as_slice_mut()))?; + env.get_byte_array_region(&rep_level_array, 0, from_u8_slice(rl_buffer.as_slice_mut()))?; - let v_len = env.get_array_length(value_data)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value_data) }; + let v_len = env.get_array_length(&value_array)?; let mut v_buffer = MutableBuffer::from_len_zeroed(v_len as usize); - env.get_byte_array_region(value_data, 0, from_u8_slice(v_buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(v_buffer.as_slice_mut()))?; reader.set_page_v2( page_value_count as usize, @@ -246,7 +260,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setNull( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_null(); Ok(()) @@ -260,7 +274,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setBoolean( handle: jlong, value: jboolean, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_boolean(value != 0); Ok(()) @@ -274,7 +288,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setByte( handle: jlong, value: jbyte, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -288,7 +302,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setShort( handle: jlong, value: jshort, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -302,7 +316,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setInt( handle: jlong, value: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -316,7 +330,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setLong( handle: jlong, value: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -330,7 +344,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setFloat( handle: jlong, value: jfloat, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) @@ -344,44 +358,50 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setDouble( handle: jlong, value: jdouble, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_fixed::(value); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setBinary( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setBinary( + e: JNIEnv, _jclass: JClass, handle: jlong, value: jbyteArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(value)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value) }; + let len = env.get_array_length(&value_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_byte_array_region(value, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_binary(buffer); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setDecimal( + e: JNIEnv, _jclass: JClass, handle: jlong, value: jbyteArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(value)?; + let value_array = unsafe { JPrimitiveArray::from_raw(value) }; + let len = env.get_array_length(&value_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_byte_array_region(value, 0, from_u8_slice(buffer.as_slice_mut()))?; + env.get_byte_array_region(&value_array, 0, from_u8_slice(buffer.as_slice_mut()))?; reader.set_decimal_flba(buffer); Ok(()) }) @@ -395,26 +415,29 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setPosition( value: jlong, size: jint, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.set_position(value, size as usize); Ok(()) }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( + e: JNIEnv, _jclass: JClass, handle: jlong, offset: jlong, batch_size: jint, indices: jlongArray, ) -> jlong { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |mut env| { let reader = get_reader(handle)?; - let indices = env.get_long_array_elements(indices, ReleaseMode::NoCopyBack)?; - let len = indices.size()? as usize; + let indice_array = unsafe { JLongArray::from_raw(indices) }; + let indices = unsafe { env.get_array_elements(&indice_array, ReleaseMode::NoCopyBack)? }; + let len = indices.len(); // paris alternately contains start index and length of continuous indices let pairs = unsafe { core::slice::from_raw_parts_mut(indices.as_ptr(), len) }; let mut skipped = 0; @@ -437,19 +460,22 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_setIndices( }) } +/// # Safety +/// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setIsDeleted( - env: JNIEnv, +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setIsDeleted( + e: JNIEnv, _jclass: JClass, handle: jlong, is_deleted: jbooleanArray, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; - let len = env.get_array_length(is_deleted)?; + let is_deleted_array = unsafe { JBooleanArray::from_raw(is_deleted) }; + let len = env.get_array_length(&is_deleted_array)?; let mut buffer = MutableBuffer::from_len_zeroed(len as usize); - env.get_boolean_array_region(is_deleted, 0, buffer.as_slice_mut())?; + env.get_boolean_array_region(&is_deleted_array, 0, buffer.as_slice_mut())?; reader.set_is_deleted(buffer); Ok(()) }) @@ -461,7 +487,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_resetBatch( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; reader.reset_batch(); Ok(()) @@ -470,20 +496,20 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_resetBatch( #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_readBatch( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, handle: jlong, batch_size: jint, null_pad_size: jint, ) -> jintArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let reader = get_reader(handle)?; let (num_values, num_nulls) = reader.read_batch(batch_size as usize, null_pad_size as usize); let res = env.new_int_array(2)?; let buf: [i32; 2] = [num_values as i32, num_nulls as i32]; - env.set_int_array_region(res, 0, &buf)?; - Ok(res) + env.set_int_array_region(&res, 0, &buf)?; + Ok(res.into_raw()) }) } @@ -495,7 +521,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch( batch_size: jint, discard: jboolean, ) -> jint { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { let reader = get_reader(handle)?; Ok(reader.skip_batch(batch_size as usize, discard == 0) as jint) }) @@ -503,11 +529,11 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_skipBatch( #[no_mangle] pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch( - env: JNIEnv, + e: JNIEnv, _jclass: JClass, handle: jlong, ) -> jlongArray { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&e, |env| { let ctx = get_context(handle)?; let reader = &mut ctx.column_reader; let data = reader.current_batch(); @@ -520,9 +546,9 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_currentBatch( let res = env.new_long_array(2)?; let buf: [i64; 2] = [array, schema]; - env.set_long_array_region(res, 0, &buf) + env.set_long_array_region(&res, 0, &buf) .expect("set long array region failed"); - Ok(res) + Ok(res.into_raw()) } }) } @@ -547,7 +573,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_closeColumnReader( _jclass: JClass, handle: jlong, ) { - try_unwrap_or_throw(env, || { + try_unwrap_or_throw(&env, |_| { unsafe { let ctx = handle as *mut Context; let _ = Box::from_raw(ctx); diff --git a/core/src/parquet/util/jni.rs b/core/src/parquet/util/jni.rs index 000eeee0b..225abfc03 100644 --- a/core/src/parquet/util/jni.rs +++ b/core/src/parquet/util/jni.rs @@ -19,8 +19,8 @@ use std::sync::Arc; use jni::{ errors::Result as JNIResult, - objects::{JMethodID, JString}, - sys::{jboolean, jint, jobjectArray, jstring}, + objects::{JObjectArray, JString}, + sys::{jboolean, jint, jobjectArray}, JNIEnv, }; @@ -33,7 +33,7 @@ use parquet::{ /// Convert primitives from Spark side into a `ColumnDescriptor`. #[allow(clippy::too_many_arguments)] pub fn convert_column_descriptor( - env: &JNIEnv, + env: &mut JNIEnv, physical_type_id: jint, logical_type_id: jint, max_dl: jint, @@ -114,12 +114,13 @@ impl TypePromotionInfo { } } -fn convert_column_path(env: &JNIEnv, path: jobjectArray) -> JNIResult { - let array_len = env.get_array_length(path)?; +fn convert_column_path(env: &mut JNIEnv, path: jobjectArray) -> JNIResult { + let path_array = unsafe { JObjectArray::from_raw(path) }; + let array_len = env.get_array_length(&path_array)?; let mut res: Vec = Vec::new(); for i in 0..array_len { - let p: JString = (env.get_object_array_element(path, i)?.into_inner() as jstring).into(); - res.push(env.get_string(p)?.into()); + let p: JString = env.get_object_array_element(&path_array, i)?.into(); + res.push(env.get_string(&p)?.into()); } Ok(ColumnPath::new(res)) } @@ -184,16 +185,3 @@ fn fix_type_length(t: &PhysicalType, type_length: i32) -> i32 { _ => type_length, } } - -fn get_method_id<'a>(env: &'a JNIEnv, class: &'a str, method: &str, sig: &str) -> JMethodID<'a> { - // first verify the class exists - let _ = env - .find_class(class) - .unwrap_or_else(|_| panic!("Class '{}' not found", class)); - env.get_method_id(class, method, sig).unwrap_or_else(|_| { - panic!( - "Method '{}' with signature '{}' of class '{}' not found", - method, sig, class - ) - }) -} diff --git a/core/src/parquet/util/mod.rs b/core/src/parquet/util/mod.rs index 6a8c731d4..7a37b786d 100644 --- a/core/src/parquet/util/mod.rs +++ b/core/src/parquet/util/mod.rs @@ -22,7 +22,5 @@ pub mod memory; mod buffer; pub use buffer::*; -mod jni_buffer; -pub use jni_buffer::*; pub mod test_common; From 8795709b213374ea21f80cacf1c489dbcd13d752 Mon Sep 17 00:00:00 2001 From: Chao Sun Date: Mon, 19 Feb 2024 21:06:38 -0800 Subject: [PATCH 2/2] add unsafe --- core/src/parquet/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/parquet/mod.rs b/core/src/parquet/mod.rs index f4c417790..4f87d15de 100644 --- a/core/src/parquet/mod.rs +++ b/core/src/parquet/mod.rs @@ -114,7 +114,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_initColumnReader( /// # Safety /// This function is inheritly unsafe since it deals with raw pointers passed from JNI. #[no_mangle] -pub extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage( +pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_setDictionaryPage( e: JNIEnv, _jclass: JClass, handle: jlong,