From e2c0178586e532380b0085e6a24a8fc0f6be7aeb Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 23 Jan 2025 11:24:33 +0800 Subject: [PATCH 1/4] Improve performance of update metrics --- native/core/src/execution/jni_api.rs | 15 ++- native/core/src/execution/metrics/utils.rs | 96 +++++++------------ .../core/src/jvm_bridge/comet_metric_node.rs | 4 + native/proto/build.rs | 1 + native/proto/src/lib.rs | 6 ++ native/proto/src/proto/metric.proto | 29 ++++++ .../spark/sql/comet/CometMetricNode.scala | 18 ++++ 7 files changed, 99 insertions(+), 70 deletions(-) create mode 100644 native/proto/src/proto/metric.proto diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index fe29d8da14..1c006c04ad 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -99,8 +99,6 @@ struct ExecutionContext { pub debug_native: bool, /// Whether to write native plans with metrics to stdout pub explain_native: bool, - /// Map of metrics name -> jstring object to cache jni_NewStringUTF calls. - pub metrics_jstrings: HashMap>, /// Memory pool config pub memory_pool_config: MemoryPoolConfig, } @@ -237,7 +235,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( session_ctx: Arc::new(session), debug_native: debug_native == 1, explain_native: explain_native == 1, - metrics_jstrings: HashMap::new(), memory_pool_config, }); @@ -508,9 +505,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let next_item = exec_context.stream.as_mut().unwrap().next(); let poll_output = exec_context.runtime.block_on(async { poll!(next_item) }); - // Update metrics - update_metrics(&mut env, exec_context)?; - match poll_output { Poll::Ready(Some(output)) => { // prepare output for FFI transfer @@ -561,8 +555,12 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( _class: JClass, exec_context: jlong, ) { - try_unwrap_or_throw(&e, |_| unsafe { + try_unwrap_or_throw(&e, |mut env| unsafe { let execution_context = get_execution_context(exec_context); + + // Update metrics + update_metrics(&mut env, execution_context)?; + if execution_context.memory_pool_config.pool_type == MemoryPoolType::FairSpillTaskShared || execution_context.memory_pool_config.pool_type == MemoryPoolType::GreedyTaskShared { @@ -588,8 +586,7 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> { let native_query = exec_context.root_op.as_ref().unwrap(); let metrics = exec_context.metrics.as_obj(); - let metrics_jstrings = &mut exec_context.metrics_jstrings; - update_comet_metric(env, metrics, native_query, metrics_jstrings) + update_comet_metric(env, metrics, native_query) } fn convert_datatype_arrays( diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 0eb4b631dd..cff5a5ad36 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -16,16 +16,16 @@ // under the License. use crate::execution::spark_plan::SparkPlan; -use crate::jvm_bridge::jni_new_global_ref; use crate::{ errors::CometError, - jvm_bridge::{jni_call, jni_new_string}, + jvm_bridge::jni_call, }; use datafusion::physical_plan::metrics::MetricValue; -use jni::objects::{GlobalRef, JString}; +use datafusion_comet_proto::spark_metric::NativeMetricNode; use jni::{objects::JObject, JNIEnv}; use std::collections::HashMap; use std::sync::Arc; +use prost::Message; /// Updates the metrics of a CometMetricNode. This function is called recursively to /// update the metrics of all the children nodes. The metrics are pulled from the @@ -33,11 +33,23 @@ use std::sync::Arc; pub fn update_comet_metric( env: &mut JNIEnv, metric_node: &JObject, - spark_plan: &Arc, - metrics_jstrings: &mut HashMap>, + spark_plan: &Arc ) -> Result<(), CometError> { - // combine all metrics from all native plans for this SparkPlan - let metrics = if spark_plan.additional_native_plans.is_empty() { + unsafe { + let native_metric = to_native_metric_node(spark_plan); + let jbytes = env.byte_array_from_slice(&*native_metric?.encode_to_vec())?; + jni_call!(env, comet_metric_node(metric_node).set_all_from_bytes(&jbytes) -> ())?; + } + Ok(()) +} + +pub fn to_native_metric_node(spark_plan: &Arc) -> Result { + let mut native_metric_node = NativeMetricNode { + metrics: HashMap::new(), + children: Vec::new(), + }; + + let node_metrics = if spark_plan.additional_native_plans.is_empty() { spark_plan.native_plan.metrics() } else { let mut metrics = spark_plan.native_plan.metrics().unwrap_or_default(); @@ -55,60 +67,22 @@ pub fn update_comet_metric( Some(metrics.aggregate_by_name()) }; - update_metrics( - env, - metric_node, - &metrics - .unwrap_or_default() - .iter() - .map(|m| m.value()) - .map(|m| (m.name(), m.as_usize() as i64)) - .collect::>(), - metrics_jstrings, - )?; + // add metrics + node_metrics + .unwrap_or_default() + .iter() + .map(|m| m.value()) + .map(|m| (m.name(), m.as_usize() as i64)) + .for_each(|(name, value)| { + native_metric_node.metrics.insert(name.to_string(), value); + }); - unsafe { - for (i, child_plan) in spark_plan.children().iter().enumerate() { - let child_metric_node: JObject = jni_call!(env, - comet_metric_node(metric_node).get_child_node(i as i32) -> JObject - )?; - if child_metric_node.is_null() { - continue; - } - update_comet_metric(env, &child_metric_node, child_plan, metrics_jstrings)?; - } - } - Ok(()) -} -#[inline] -fn update_metrics( - env: &mut JNIEnv, - metric_node: &JObject, - metric_values: &[(&str, i64)], - metrics_jstrings: &mut HashMap>, -) -> Result<(), CometError> { - unsafe { - for &(name, value) in metric_values { - // Perform a lookup in the jstrings cache. - if let Some(map_global_ref) = metrics_jstrings.get(name) { - // Cache hit. Extract the jstring from the global ref. - let jobject = map_global_ref.as_obj(); - let jstring = JString::from_raw(**jobject); - // Update the metrics using the jstring as a key. - jni_call!(env, comet_metric_node(metric_node).set(&jstring, value) -> ())?; - } else { - // Cache miss. Allocate a new string, promote to global ref, and insert into cache. - let local_jstring = jni_new_string!(env, &name)?; - let global_ref = jni_new_global_ref!(env, local_jstring)?; - let arc_global_ref = Arc::new(global_ref); - metrics_jstrings.insert(name.to_string(), Arc::clone(&arc_global_ref)); - let jobject = arc_global_ref.as_obj(); - let jstring = JString::from_raw(**jobject); - // Update the metrics using the jstring as a key. - jni_call!(env, comet_metric_node(metric_node).set(&jstring, value) -> ())?; - } - } - } - Ok(()) + // add children + spark_plan.children().iter().for_each(|child_plan| { + let child_node = to_native_metric_node(child_plan).unwrap(); + native_metric_node.children.push(child_node); + }); + + Ok(native_metric_node) } diff --git a/native/core/src/jvm_bridge/comet_metric_node.rs b/native/core/src/jvm_bridge/comet_metric_node.rs index 85386d9b0d..89a28fbf0e 100644 --- a/native/core/src/jvm_bridge/comet_metric_node.rs +++ b/native/core/src/jvm_bridge/comet_metric_node.rs @@ -30,6 +30,8 @@ pub struct CometMetricNode<'a> { pub method_get_child_node_ret: ReturnType, pub method_set: JMethodID, pub method_set_ret: ReturnType, + pub method_set_all_from_bytes: JMethodID, + pub method_set_all_from_bytes_ret: ReturnType, } impl<'a> CometMetricNode<'a> { @@ -47,6 +49,8 @@ impl<'a> CometMetricNode<'a> { method_get_child_node_ret: ReturnType::Object, method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?, method_set_ret: ReturnType::Primitive(Primitive::Void), + method_set_all_from_bytes: env.get_method_id(Self::JVM_CLASS, "set_all_from_bytes", "([B)V")?, + method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void), class, }) } diff --git a/native/proto/build.rs b/native/proto/build.rs index e707f0c3b9..ba3d12b382 100644 --- a/native/proto/build.rs +++ b/native/proto/build.rs @@ -30,6 +30,7 @@ fn main() -> Result<()> { prost_build::Config::new().out_dir(out_dir).compile_protos( &[ "src/proto/expr.proto", + "src/proto/metric.proto", "src/proto/partitioning.proto", "src/proto/operator.proto", ], diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index 266bf62dbc..ed24440360 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -36,3 +36,9 @@ pub mod spark_partitioning { pub mod spark_operator { include!(concat!("generated", "/spark.spark_operator.rs")); } + +// Include generated modules from .proto files. +#[allow(missing_docs)] +pub mod spark_metric { + include!(concat!("generated", "/spark.spark_metric.rs")); +} diff --git a/native/proto/src/proto/metric.proto b/native/proto/src/proto/metric.proto new file mode 100644 index 0000000000..f026e505ae --- /dev/null +++ b/native/proto/src/proto/metric.proto @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + + + +syntax = "proto3"; + +package spark.spark_metric; + +option java_package = "org.apache.comet.serde"; + +message NativeMetricNode { + map metrics = 1; + repeated NativeMetricNode children = 2; +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 53370a03b7..41490607d6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -19,11 +19,15 @@ package org.apache.spark.sql.comet +import scala.collection.JavaConverters._ + import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.comet.serde.Metric + /** * A node carrying SQL metrics from SparkPlan, and metrics of its children. Native code will call * [[getChildNode]] and [[set]] to update the metrics. @@ -65,6 +69,20 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM logDebug(s"Non-existing metric: $metricName. Ignored") } } + + private def set_all(metricNode: Metric.NativeMetricNode): Unit = { + metricNode.getMetricsMap.forEach((name, value) => { + set(name, value) + }) + metricNode.getChildrenList.asScala.zip(children).foreach { case (child, childNode) => + childNode.set_all(child) + } + } + + def set_all_from_bytes(bytes: Array[Byte]): Unit = { + val metricNode = Metric.NativeMetricNode.parseFrom(bytes) + set_all(metricNode) + } } object CometMetricNode { From 958476b8f2c27601a5051d4897d164214db63a3c Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 23 Jan 2025 12:08:26 +0800 Subject: [PATCH 2/4] fix style --- native/core/src/execution/metrics/utils.rs | 10 +++------- native/core/src/jvm_bridge/comet_metric_node.rs | 6 +++++- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index cff5a5ad36..2beb5b80c4 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -16,16 +16,13 @@ // under the License. use crate::execution::spark_plan::SparkPlan; -use crate::{ - errors::CometError, - jvm_bridge::jni_call, -}; +use crate::{errors::CometError, jvm_bridge::jni_call}; use datafusion::physical_plan::metrics::MetricValue; use datafusion_comet_proto::spark_metric::NativeMetricNode; use jni::{objects::JObject, JNIEnv}; +use prost::Message; use std::collections::HashMap; use std::sync::Arc; -use prost::Message; /// Updates the metrics of a CometMetricNode. This function is called recursively to /// update the metrics of all the children nodes. The metrics are pulled from the @@ -33,7 +30,7 @@ use prost::Message; pub fn update_comet_metric( env: &mut JNIEnv, metric_node: &JObject, - spark_plan: &Arc + spark_plan: &Arc, ) -> Result<(), CometError> { unsafe { let native_metric = to_native_metric_node(spark_plan); @@ -77,7 +74,6 @@ pub fn to_native_metric_node(spark_plan: &Arc) -> Result CometMetricNode<'a> { method_get_child_node_ret: ReturnType::Object, method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?, method_set_ret: ReturnType::Primitive(Primitive::Void), - method_set_all_from_bytes: env.get_method_id(Self::JVM_CLASS, "set_all_from_bytes", "([B)V")?, + method_set_all_from_bytes: env.get_method_id( + Self::JVM_CLASS, + "set_all_from_bytes", + "([B)V", + )?, method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void), class, }) From 8c5724d62607e5cf3f1b1a8552dfcc404bb29a2b Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 23 Jan 2025 13:23:12 +0800 Subject: [PATCH 3/4] fix --- native/core/src/execution/metrics/utils.rs | 2 +- native/core/src/jvm_bridge/mod.rs | 8 -------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 2beb5b80c4..0836a2b57b 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -34,7 +34,7 @@ pub fn update_comet_metric( ) -> Result<(), CometError> { unsafe { let native_metric = to_native_metric_node(spark_plan); - let jbytes = env.byte_array_from_slice(&*native_metric?.encode_to_vec())?; + let jbytes = env.byte_array_from_slice(&native_metric?.encode_to_vec())?; jni_call!(env, comet_metric_node(metric_node).set_all_from_bytes(&jbytes) -> ())?; } Ok(()) diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index 5fc0a55e3e..b863268945 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -46,13 +46,6 @@ macro_rules! jvalues { }} } -/// Macro for create a new JNI string. -macro_rules! jni_new_string { - ($env:expr, $value:expr) => {{ - $crate::jvm_bridge::jni_map_error!($env, $env.new_string($value)) - }}; -} - /// Macro for calling a JNI method. /// The syntax is: /// jni_call!(env, comet_metric_node(metric_node).add(jname, value) -> ())?; @@ -173,7 +166,6 @@ macro_rules! jni_new_global_ref { pub(crate) use jni_call; pub(crate) use jni_map_error; pub(crate) use jni_new_global_ref; -pub(crate) use jni_new_string; pub(crate) use jni_static_call; pub(crate) use jvalues; From 642c7376c4749fc31b2fdffa67e2919fc47886ae Mon Sep 17 00:00:00 2001 From: wforget <643348094@qq.com> Date: Thu, 23 Jan 2025 14:52:24 +0800 Subject: [PATCH 4/4] fix --- native/core/src/execution/jni_api.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1c006c04ad..6fc5cfedb2 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -584,9 +584,13 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( /// Updates the metrics of the query plan. fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> { - let native_query = exec_context.root_op.as_ref().unwrap(); - let metrics = exec_context.metrics.as_obj(); - update_comet_metric(env, metrics, native_query) + if exec_context.root_op.is_some() { + let native_query = exec_context.root_op.as_ref().unwrap(); + let metrics = exec_context.metrics.as_obj(); + update_comet_metric(env, metrics, native_query) + } else { + Ok(()) + } } fn convert_datatype_arrays(