From bd9d1563c3942dff859ca5531b9aa34c699f0bdf Mon Sep 17 00:00:00 2001
From: Bohan Zhang
Date: Thu, 16 Mar 2023 17:28:09 +0800
Subject: [PATCH 1/4] fix: load pk from the downstream instead of Risingwave
 (#8457)

Signed-off-by: tabVersion
---
 .github/workflows/intergration_tests.yml      |   2 +-
 .../mysql-sink/mysql_prepare.sql              |   2 +-
 .../risingwave/connector/api/TableSchema.java |   2 +
 .../com/risingwave/connector/JDBCSink.java    | 123 +++++++++++++++---
 4 files changed, 106 insertions(+), 23 deletions(-)

diff --git a/.github/workflows/intergration_tests.yml b/.github/workflows/intergration_tests.yml
index 74b5d09a81191..106783e792d39 100644
--- a/.github/workflows/intergration_tests.yml
+++ b/.github/workflows/intergration_tests.yml
@@ -39,7 +39,7 @@ jobs:
           - schema-registry
           - mysql-cdc
           - postgres-cdc
-          #- mysql-sink
+          - mysql-sink
           - postgres-sink
           - iceberg-sink
         format: ["json", "protobuf"]
diff --git a/integration_tests/mysql-sink/mysql_prepare.sql b/integration_tests/mysql-sink/mysql_prepare.sql
index cac57c699a154..ded9b4cec97cd 100644
--- a/integration_tests/mysql-sink/mysql_prepare.sql
+++ b/integration_tests/mysql-sink/mysql_prepare.sql
@@ -1,4 +1,4 @@
 CREATE TABLE target_count (
-    target_id VARCHAR(128),
+    target_id VARCHAR(128) primary key,
     target_count BIGINT
 );
\ No newline at end of file
diff --git a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/TableSchema.java b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/TableSchema.java
index 053ba1e329920..5d1016c95ea49 100644
--- a/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/TableSchema.java
+++ b/java/connector-node/connector-api/src/main/java/com/risingwave/connector/api/TableSchema.java
@@ -102,6 +102,8 @@ public static TableSchema fromProto(ConnectorServiceProto.TableSchema tableSchem
                         .collect(Collectors.toList()));
     }
 
+    /** @deprecated pk here is from Risingwave, it may not match the pk in the database */
+    @Deprecated
     public List<String> getPrimaryKeys() {
         return primaryKeys;
     }
diff --git a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
index 43f99e5119430..9c249842ca09e 100644
--- a/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
+++ b/java/connector-node/risingwave-sink-jdbc/src/main/java/com/risingwave/connector/JDBCSink.java
@@ -20,7 +20,9 @@
 import com.risingwave.proto.Data;
 import io.grpc.Status;
 import java.sql.*;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.slf4j.Logger;
@@ -30,10 +32,13 @@ public class JDBCSink extends SinkBase {
     public static final String INSERT_TEMPLATE = "INSERT INTO %s (%s) VALUES (%s)";
     private static final String DELETE_TEMPLATE = "DELETE FROM %s WHERE %s";
     private static final String UPDATE_TEMPLATE = "UPDATE %s SET %s WHERE %s";
+    private static final String ERROR_REPORT_TEMPLATE = "Error when exec %s, message %s";
 
     private final String tableName;
     private final Connection conn;
     private final String jdbcUrl;
+    private final List<String> pkColumnNames;
+    public static final String JDBC_COLUMN_NAME_KEY = "COLUMN_NAME";
 
     private String updateDeleteConditionBuffer;
     private Object[] updateDeleteValueBuffer;
@@ -48,9 +53,30 @@ public JDBCSink(String tableName, String jdbcUrl, TableSchema tableSchema) {
         try {
             this.conn = DriverManager.getConnection(jdbcUrl);
             this.conn.setAutoCommit(false);
+            this.pkColumnNames = getPkColumnNames(conn, tableName);
         } catch (SQLException e) {
-            throw Status.INTERNAL.withCause(e).asRuntimeException();
+            throw Status.INTERNAL
+                    .withDescription(
+                            String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                    .asRuntimeException();
+        }
+    }
+
+    private static List<String> getPkColumnNames(Connection conn, String tableName) {
+        List<String> pkColumnNames = new ArrayList<>();
+        try {
+            var pks = conn.getMetaData().getPrimaryKeys(null, null, tableName);
+            while (pks.next()) {
+                pkColumnNames.add(pks.getString(JDBC_COLUMN_NAME_KEY));
+            }
+        } catch (SQLException e) {
+            throw Status.INTERNAL
+                    .withDescription(
+                            String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                    .asRuntimeException();
         }
+        LOG.info("detected pk {}", pkColumnNames);
+        return pkColumnNames;
     }
 
     public JDBCSink(Connection conn, TableSchema tableSchema, String tableName) {
@@ -58,6 +84,7 @@ public JDBCSink(Connection conn, TableSchema tableSchema, String tableName) {
         this.tableName = tableName;
         this.jdbcUrl = null;
         this.conn = conn;
+        this.pkColumnNames = getPkColumnNames(conn, tableName);
     }
 
     private PreparedStatement prepareStatement(SinkRow row) {
@@ -79,35 +106,75 @@ private PreparedStatement prepareStatement(SinkRow row) {
                 }
                 return stmt;
             } catch (SQLException e) {
-                throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
+                throw io.grpc.Status.INTERNAL
+                        .withDescription(
+                                String.format(
+                                        ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                        .asRuntimeException();
             }
         case DELETE:
-            String deleteCondition =
-                    getTableSchema().getPrimaryKeys().stream()
-                            .map(key -> key + " = ?")
-                            .collect(Collectors.joining(" AND "));
+            String deleteCondition;
+            if (this.pkColumnNames.isEmpty()) {
+                deleteCondition =
+                        IntStream.range(0, getTableSchema().getNumColumns())
+                                .mapToObj(
+                                        index ->
+                                                getTableSchema().getColumnNames()[index]
+                                                        + " = ?")
+                                .collect(Collectors.joining(" AND "));
+            } else {
+                deleteCondition =
+                        this.pkColumnNames.stream()
+                                .map(key -> key + " = ?")
+                                .collect(Collectors.joining(" AND "));
+            }
             String deleteStmt = String.format(DELETE_TEMPLATE, tableName, deleteCondition);
             try {
                 int placeholderIdx = 1;
                 PreparedStatement stmt =
                         conn.prepareStatement(deleteStmt, Statement.RETURN_GENERATED_KEYS);
-                for (String primaryKey : getTableSchema().getPrimaryKeys()) {
+                for (String primaryKey : this.pkColumnNames) {
                     Object fromRow = getTableSchema().getFromRow(primaryKey, row);
                     stmt.setObject(placeholderIdx++, fromRow);
                 }
                 return stmt;
             } catch (SQLException e) {
-                throw Status.INTERNAL.withCause(e).asRuntimeException();
+                throw Status.INTERNAL
+                        .withDescription(
+                                String.format(
+                                        ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                        .asRuntimeException();
             }
         case UPDATE_DELETE:
-            updateDeleteConditionBuffer =
-                    getTableSchema().getPrimaryKeys().stream()
-                            .map(key -> key + " = ?")
-                            .collect(Collectors.joining(" AND "));
-            updateDeleteValueBuffer =
-                    getTableSchema().getPrimaryKeys().stream()
-                            .map(key -> getTableSchema().getFromRow(key, row))
-                            .toArray();
+            if (this.pkColumnNames.isEmpty()) {
+                updateDeleteConditionBuffer =
+                        IntStream.range(0, getTableSchema().getNumColumns())
+                                .mapToObj(
+                                        index ->
+                                                getTableSchema().getColumnNames()[index]
+                                                        + " = ?")
+                                .collect(Collectors.joining(" AND "));
+                updateDeleteValueBuffer =
+                        IntStream.range(0, getTableSchema().getNumColumns())
+                                .mapToObj(
+                                        index ->
+                                                getTableSchema()
+                                                        .getFromRow(
+                                                                getTableSchema()
+                                                                        .getColumnNames()[index],
+                                                                row))
+                                .toArray();
+            } else {
+                updateDeleteConditionBuffer =
+                        this.pkColumnNames.stream()
+                                .map(key -> key + " = ?")
+                                .collect(Collectors.joining(" AND "));
+                updateDeleteValueBuffer =
+                        this.pkColumnNames.stream()
+                                .map(key -> getTableSchema().getFromRow(key, row))
+                                .toArray();
+            }
             LOG.debug(
                     "update delete condition: {} on values {}",
                     updateDeleteConditionBuffer,
@@ -144,7 +211,11 @@ private PreparedStatement prepareStatement(SinkRow row) {
                 updateDeleteValueBuffer = null;
                 return stmt;
             } catch (SQLException e) {
-                throw Status.INTERNAL.withCause(e).asRuntimeException();
+                throw Status.INTERNAL
+                        .withDescription(
+                                String.format(
+                                        ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                        .asRuntimeException();
             }
         default:
             throw Status.INVALID_ARGUMENT
@@ -163,10 +234,14 @@ public void write(Iterator<SinkRow> rows) {
             }
             if (stmt != null) {
                 try {
-                    LOG.debug("Executing statement: " + stmt);
+                    LOG.debug("Executing statement: {}", stmt);
                     stmt.executeUpdate();
                 } catch (SQLException e) {
-                    throw Status.INTERNAL.withCause(e).asRuntimeException();
+                    throw Status.INTERNAL
+                            .withDescription(
+                                    String.format(
+                                            ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                            .asRuntimeException();
                 }
             } else {
                 throw Status.INTERNAL
@@ -187,7 +262,10 @@ public void sync() {
         try {
             conn.commit();
         } catch (SQLException e) {
-            throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
+            throw io.grpc.Status.INTERNAL
+                    .withDescription(
+                            String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                    .asRuntimeException();
         }
     }
 
@@ -196,7 +274,10 @@ public void drop() {
         try {
             conn.close();
         } catch (SQLException e) {
-            throw io.grpc.Status.INTERNAL.withCause(e).asRuntimeException();
+            throw io.grpc.Status.INTERNAL
+                    .withDescription(
+                            String.format(ERROR_REPORT_TEMPLATE, e.getSQLState(), e.getMessage()))
+                    .asRuntimeException();
         }
     }
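The substance of this first patch is that the sink now asks the downstream database for its primary key through standard JDBC metadata instead of trusting the (now-deprecated) primary key carried by RisingWave's TableSchema, which may not match what the database actually enforces. A minimal, self-contained sketch of that lookup follows; the class name, JDBC URL, and credentials are illustrative only, but DatabaseMetaData.getPrimaryKeys and its COLUMN_NAME result column are the standard JDBC API the patch relies on:

    import java.sql.*;
    import java.util.ArrayList;
    import java.util.List;

    public class PkProbe {
        // Read the downstream table's primary key columns via JDBC metadata,
        // mirroring what JDBCSink#getPkColumnNames does in the patch above.
        static List<String> primaryKeyColumns(Connection conn, String table) throws SQLException {
            List<String> pk = new ArrayList<>();
            try (ResultSet rs = conn.getMetaData().getPrimaryKeys(null, null, table)) {
                while (rs.next()) {
                    pk.add(rs.getString("COLUMN_NAME"));
                }
            }
            return pk;
        }

        public static void main(String[] args) throws SQLException {
            // Hypothetical local MySQL instance; adjust the URL and credentials.
            String url = "jdbc:mysql://localhost:3306/test?user=root";
            try (Connection conn = DriverManager.getConnection(url)) {
                // With mysql_prepare.sql applied, this prints [target_id].
                System.out.println(primaryKeyColumns(conn, "target_count"));
            }
        }
    }

If the downstream table declares no primary key at all, the list comes back empty and the sink falls back to matching DELETE and UPDATE rows on every column, which is what the isEmpty() branches in prepareStatement implement.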
From 075d50a4938bab239b09fadb5b5e5f1963db6139 Mon Sep 17 00:00:00 2001
From: ZENOTME <43447882+ZENOTME@users.noreply.github.com>
Date: Thu, 16 Mar 2023 17:29:52 +0800
Subject: [PATCH 2/4] feat(frontend): support BATCH_PARALLELISM (#8552)

---
 e2e_test/batch/basic/join.slt.part            | 19 +++++++++++++++
 src/common/src/session_config/mod.rs          | 24 ++++++++++++++++++-
 src/frontend/src/handler/explain.rs           |  1 +
 src/frontend/src/handler/query.rs             |  1 +
 .../src/scheduler/distributed/query.rs        |  1 +
 src/frontend/src/scheduler/plan_fragmenter.rs | 15 ++++++++++++
 .../src/scheduler/worker_node_manager.rs      | 10 ++++++++
 7 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/e2e_test/batch/basic/join.slt.part b/e2e_test/batch/basic/join.slt.part
index 1f7daba0f46c5..feeb793ba9e36 100644
--- a/e2e_test/batch/basic/join.slt.part
+++ b/e2e_test/batch/basic/join.slt.part
@@ -32,6 +32,25 @@ select * from t1 join t2 using(v1) join t3 using(v2);
 ----
 2 1 3 3
 
+statement ok
+set batch_parallelism = 1;
+
+query IIIIII
+select * from t1 join t2 using(v1) join t3 using(v2);
+----
+2 1 3 3
+
+statement ok
+set batch_parallelism = 1000;
+
+query IIIIII
+select * from t1 join t2 using(v1) join t3 using(v2);
+----
+2 1 3 3
+
+statement ok
+set batch_parallelism = 0;
+
 statement ok
 create index i1 on t1(v1) include(v2);
 
diff --git a/src/common/src/session_config/mod.rs b/src/common/src/session_config/mod.rs
index f24818f3020c8..d715de8985091 100644
--- a/src/common/src/session_config/mod.rs
+++ b/src/common/src/session_config/mod.rs
@@ -17,6 +17,7 @@ mod search_path;
 mod transaction_isolation_level;
 mod visibility_mode;
 
+use std::num::NonZeroU64;
 use std::ops::Deref;
 
 use chrono_tz::Tz;
@@ -33,7 +34,7 @@ use crate::util::epoch::Epoch;
 
 // This is a hack, &'static str is not allowed as a const generics argument.
 // TODO: refine this using the adt_const_params feature.
-const CONFIG_KEYS: [&str; 20] = [
+const CONFIG_KEYS: [&str; 21] = [
     "RW_IMPLICIT_FLUSH",
     "CREATE_COMPACTION_GROUP_FOR_MV",
     "QUERY_MODE",
@@ -54,6 +55,7 @@ const CONFIG_KEYS: [&str; 20] = [
     "RW_FORCE_TWO_PHASE_AGG",
     "RW_ENABLE_SHARE_PLAN",
     "INTERVALSTYLE",
+    "BATCH_PARALLELISM",
 ];
 
 // MUST HAVE 1v1 relationship to CONFIG_KEYS. e.g. CONFIG_KEYS[IMPLICIT_FLUSH] =
@@ -78,6 +80,7 @@ const ENABLE_TWO_PHASE_AGG: usize = 16;
 const FORCE_TWO_PHASE_AGG: usize = 17;
 const RW_ENABLE_SHARE_PLAN: usize = 18;
 const INTERVAL_STYLE: usize = 19;
+const BATCH_PARALLELISM: usize = 20;
 
 trait ConfigEntry: Default + for<'a> TryFrom<&'a [&'a str], Error = RwError> {
     fn entry_name() -> &'static str;
@@ -278,6 +281,7 @@ type EnableTwoPhaseAgg = ConfigBool<ENABLE_TWO_PHASE_AGG, true>;
 type ForceTwoPhaseAgg = ConfigBool<FORCE_TWO_PHASE_AGG, false>;
 type EnableSharePlan = ConfigBool<RW_ENABLE_SHARE_PLAN, true>;
 type IntervalStyle = ConfigString<INTERVAL_STYLE>;
+type BatchParallelism = ConfigU64<BATCH_PARALLELISM, 0>;
 
 #[derive(Derivative)]
 #[derivative(Default)]
@@ -354,6 +358,8 @@ pub struct ConfigMap {
     /// see
     interval_style: IntervalStyle,
+
+    batch_parallelism: BatchParallelism,
 }
 
 impl ConfigMap {
@@ -410,6 +416,8 @@ impl ConfigMap {
             self.enable_share_plan = val.as_slice().try_into()?;
         } else if key.eq_ignore_ascii_case(IntervalStyle::entry_name()) {
             self.interval_style = val.as_slice().try_into()?;
+        } else if key.eq_ignore_ascii_case(BatchParallelism::entry_name()) {
+            self.batch_parallelism = val.as_slice().try_into()?;
         } else {
             return Err(ErrorCode::UnrecognizedConfigurationParameter(key.to_string()).into());
         }
@@ -458,6 +466,8 @@ impl ConfigMap {
             Ok(self.enable_share_plan.to_string())
         } else if key.eq_ignore_ascii_case(IntervalStyle::entry_name()) {
             Ok(self.interval_style.to_string())
+        } else if key.eq_ignore_ascii_case(BatchParallelism::entry_name()) {
+            Ok(self.batch_parallelism.to_string())
         } else {
             Err(ErrorCode::UnrecognizedConfigurationParameter(key.to_string()).into())
         }
@@ -560,6 +570,11 @@ impl ConfigMap {
                 setting : self.interval_style.to_string(),
                 description : String::from("It is typically set by an application upon connection to the server.")
             },
+            VariableInfo{
+                name : BatchParallelism::entry_name().to_lowercase(),
+                setting : self.batch_parallelism.to_string(),
+                description: String::from("Sets the parallelism for batch. If 0, use default value.")
+            },
         ]
     }
 
@@ -648,4 +663,11 @@ impl ConfigMap {
     pub fn get_interval_style(&self) -> &str {
         &self.interval_style
     }
+
+    pub fn get_batch_parallelism(&self) -> Option<NonZeroU64> {
+        if self.batch_parallelism.0 != 0 {
+            return Some(NonZeroU64::new(self.batch_parallelism.0).unwrap());
+        }
+        None
+    }
 }
diff --git a/src/frontend/src/handler/explain.rs b/src/frontend/src/handler/explain.rs
index 5375a128b3749..c33e52c86d0e2 100644
--- a/src/frontend/src/handler/explain.rs
+++ b/src/frontend/src/handler/explain.rs
@@ -145,6 +145,7 @@ pub async fn handle_explain(
                 plan_fragmenter = Some(BatchPlanFragmenter::new(
                     session.env().worker_node_manager_ref(),
                     session.env().catalog_reader().clone(),
+                    session.config().get_batch_parallelism(),
                     plan,
                 )?);
             }
diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs
index 04bc203846667..4d437b304515b 100644
--- a/src/frontend/src/handler/query.rs
+++ b/src/frontend/src/handler/query.rs
@@ -162,6 +162,7 @@ pub async fn handle_query(
     let plan_fragmenter = BatchPlanFragmenter::new(
         session.env().worker_node_manager_ref(),
         session.env().catalog_reader().clone(),
+        session.config().get_batch_parallelism(),
         plan,
     )?;
     context.append_notice(&mut notice);
diff --git a/src/frontend/src/scheduler/distributed/query.rs b/src/frontend/src/scheduler/distributed/query.rs
index 4689f2bdd8748..e4aa85c08a7eb 100644
--- a/src/frontend/src/scheduler/distributed/query.rs
+++ b/src/frontend/src/scheduler/distributed/query.rs
@@ -660,6 +660,7 @@ pub(crate) mod tests {
         let fragmenter = BatchPlanFragmenter::new(
             worker_node_manager,
             catalog_reader,
+            None,
             batch_exchange_node.clone(),
         )
         .unwrap();
diff --git a/src/frontend/src/scheduler/plan_fragmenter.rs b/src/frontend/src/scheduler/plan_fragmenter.rs
index 26f27d912d0ac..cb3ef5a02610b 100644
--- a/src/frontend/src/scheduler/plan_fragmenter.rs
+++ b/src/frontend/src/scheduler/plan_fragmenter.rs
@@ -12,8 +12,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::cmp::min;
 use std::collections::{HashMap, HashSet};
 use std::fmt::{Debug, Formatter};
+use std::num::NonZeroU64;
 use std::sync::Arc;
 
 use anyhow::anyhow;
@@ -120,6 +122,12 @@ pub struct BatchPlanFragmenter {
     worker_node_manager: WorkerNodeManagerRef,
     catalog_reader: CatalogReader,
 
+    /// if batch_parallelism is None, it means no limit, we will use the available nodes count as
+    /// parallelism.
+    /// if batch_parallelism is Some(num), we will use the min(num, the available
+    /// nodes count) as parallelism.
+    batch_parallelism: Option<NonZeroU64>,
+
     stage_graph_builder: Option<StageGraphBuilder>,
     stage_graph: Option<StageGraph>,
 }
@@ -136,6 +144,7 @@ impl BatchPlanFragmenter {
     pub fn new(
         worker_node_manager: WorkerNodeManagerRef,
         catalog_reader: CatalogReader,
+        batch_parallelism: Option<NonZeroU64>,
         batch_node: PlanRef,
     ) -> SchedulerResult<Self> {
         let mut plan_fragmenter = Self {
@@ -144,6 +153,7 @@ impl BatchPlanFragmenter {
             next_stage_id: 0,
             worker_node_manager,
             catalog_reader,
+            batch_parallelism,
             stage_graph: None,
         };
         plan_fragmenter.split_into_stage(batch_node)?;
@@ -751,6 +761,11 @@ impl BatchPlanFragmenter {
                 lookup_join_parallelism
             } else if source_info.is_some() {
                 0
+            } else if let Some(num) = self.batch_parallelism {
+                min(
+                    num.get() as usize,
+                    self.worker_node_manager.schedule_unit_count(),
+                )
             } else {
                 self.worker_node_manager.worker_node_count()
             }
diff --git a/src/frontend/src/scheduler/worker_node_manager.rs b/src/frontend/src/scheduler/worker_node_manager.rs
index 913e83678a64d..977a591dd6aec 100644
--- a/src/frontend/src/scheduler/worker_node_manager.rs
+++ b/src/frontend/src/scheduler/worker_node_manager.rs
@@ -105,6 +105,16 @@ impl WorkerNodeManager {
         self.inner.read().unwrap().worker_nodes.len()
     }
 
+    pub fn schedule_unit_count(&self) -> usize {
+        self.inner
+            .read()
+            .unwrap()
+            .worker_nodes
+            .iter()
+            .map(|node| node.parallel_units.len())
+            .sum()
+    }
+
     /// If parallel unit ids is empty, the scheduler may fail to schedule any task and stuck at
     /// schedule next stage. If we do not return error in this case, needs more complex control
     /// logic above. Report in this function makes the schedule root fail reason more clear.
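From the user's side, the new variable is plain Postgres-style session configuration, exactly as the slt test above exercises. A short illustrative session — the table names and the value 4 are placeholders, not from the patch:

    -- Cap each batch stage at 4 tasks; per plan_fragmenter.rs above, the
    -- fragmenter actually schedules min(4, total parallel units across workers).
    SET BATCH_PARALLELISM = 4;
    SELECT * FROM t1 JOIN t2 USING (v1);

    -- 0 restores the default: one task per worker node, with no cap.
    SET BATCH_PARALLELISM = 0;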
From 14bfc62b95d0a5f5dda4a7e3122829d5f3d31f99 Mon Sep 17 00:00:00 2001
From: Shanicky Chen
Date: Thu, 16 Mar 2023 17:42:50 +0800
Subject: [PATCH 3/4] chore: add is_visible in column for connector (#8592)

---
 src/connector/src/source/base.rs              |   1 +
 .../src/source/datagen/source/generator.rs    | 101 +++++++++++------
 .../src/source/datagen/source/reader.rs       |  27 +++--
 src/connector/src/source/manager.rs           |  21 ++++
 src/source/src/connector_source.rs            |   6 +-
 5 files changed, 107 insertions(+), 49 deletions(-)

diff --git a/src/connector/src/source/base.rs b/src/connector/src/source/base.rs
index e2ddd01832c86..14f68f386782a 100644
--- a/src/connector/src/source/base.rs
+++ b/src/connector/src/source/base.rs
@@ -363,6 +363,7 @@ pub type DataType = risingwave_common::types::DataType;
 pub struct Column {
     pub name: String,
     pub data_type: DataType,
+    pub is_visible: bool,
 }
 
 /// Split id resides in every source message, use `Arc` to avoid copying.
diff --git a/src/connector/src/source/datagen/source/generator.rs b/src/connector/src/source/datagen/source/generator.rs
index 3d747921aef02..8c4bf59a7cd05 100644
--- a/src/connector/src/source/datagen/source/generator.rs
+++ b/src/connector/src/source/datagen/source/generator.rs
@@ -26,10 +26,16 @@ use risingwave_common::util::iter_util::ZipEqFast;
 
 use crate::source::{SourceFormat, SourceMessage, SourceMeta, SplitId, StreamChunkWithState};
 
+pub enum FieldDesc {
+    // field is invisible, generate None
+    Invisible,
+    Visible(FieldGeneratorImpl),
+}
+
 pub struct DatagenEventGenerator {
     // fields_map: HashMap<String, FieldGeneratorImpl>,
     field_names: Vec<String>,
-    fields_vec: Vec<FieldGeneratorImpl>,
+    fields_vec: Vec<FieldDesc>,
     source_format: SourceFormat,
     data_types: Vec<DataType>,
     offset: u64,
@@ -46,7 +52,7 @@ pub struct DatagenMeta {
 impl DatagenEventGenerator {
     #[allow(clippy::too_many_arguments)]
     pub fn new(
-        fields_vec: Vec<FieldGeneratorImpl>,
+        fields_vec: Vec<FieldDesc>,
         field_names: Vec<String>,
         source_format: SourceFormat,
         data_types: Vec<DataType>,
@@ -96,16 +102,23 @@ impl DatagenEventGenerator {
                             .iter()
                             .zip_eq_fast(self.fields_vec.iter_mut())
                         {
-                            let value = field_generator.generate_json(self.offset);
-                            if value.is_null() {
-                                reach_end = true;
-                                tracing::info!(
-                                    "datagen split {} stop generate, offset {}",
-                                    self.split_id,
-                                    self.offset
-                                );
-                                break 'outer;
-                            }
+                            let value = match field_generator {
+                                FieldDesc::Invisible => continue,
+                                FieldDesc::Visible(field_generator) => {
+                                    let value = field_generator.generate_json(self.offset);
+                                    if value.is_null() {
+                                        reach_end = true;
+                                        tracing::info!(
+                                            "datagen split {} stop generate, offset {}",
+                                            self.split_id,
+                                            self.offset
+                                        );
+                                        break 'outer;
+                                    }
+                                    value
+                                }
+                            };
+
                             map.insert(name.clone(), value);
                         }
                         Bytes::from(serde_json::Value::from(map).to_string())
@@ -159,16 +172,24 @@ impl DatagenEventGenerator {
         'outer: for _ in 0..num_rows_to_generate {
             let mut row = Vec::with_capacity(self.fields_vec.len());
             for field_generator in &mut self.fields_vec {
-                let datum = field_generator.generate_datum(self.offset);
-                if datum.is_none() {
-                    reach_end = true;
-                    tracing::info!(
-                        "datagen split {} stop generate, offset {}",
-                        self.split_id,
-                        self.offset
-                    );
-                    break 'outer;
-                }
+                let datum = match field_generator {
+                    FieldDesc::Invisible => None,
+                    FieldDesc::Visible(field_generator) => {
+                        let datum = field_generator.generate_datum(self.offset);
+                        if datum.is_none() {
+                            reach_end = true;
+                            tracing::info!(
+                                "datagen split {} stop generate, offset {}",
+                                self.split_id,
+                                self.offset
+                            );
+                            break 'outer;
+                        };
+
+                        datum
+                    }
+                };
+
                 row.push(datum);
             }
 
@@ -214,22 +235,26 @@ mod tests {
 
         let data_types = vec![DataType::Int32, DataType::Float32];
         let fields_vec = vec![
-            FieldGeneratorImpl::with_number_sequence(
-                data_types[0].clone(),
-                Some(start.to_string()),
-                Some(end.to_string()),
-                split_index,
-                split_num,
-            )
-            .unwrap(),
-            FieldGeneratorImpl::with_number_sequence(
-                data_types[1].clone(),
-                Some(start.to_string()),
-                Some(end.to_string()),
-                split_index,
-                split_num,
-            )
-            .unwrap(),
+            FieldDesc::Visible(
+                FieldGeneratorImpl::with_number_sequence(
+                    data_types[0].clone(),
+                    Some(start.to_string()),
+                    Some(end.to_string()),
+                    split_index,
+                    split_num,
+                )
+                .unwrap(),
+            ),
+            FieldDesc::Visible(
+                FieldGeneratorImpl::with_number_sequence(
+                    data_types[1].clone(),
+                    Some(start.to_string()),
+                    Some(end.to_string()),
+                    split_index,
+                    split_num,
+                )
+                .unwrap(),
+            ),
         ];
 
         let generator = DatagenEventGenerator::new(
diff --git a/src/connector/src/source/datagen/source/reader.rs b/src/connector/src/source/datagen/source/reader.rs
index c0780c5604622..a3f17c1ca4b35 100644
--- a/src/connector/src/source/datagen/source/reader.rs
+++ b/src/connector/src/source/datagen/source/reader.rs
@@ -26,7 +26,7 @@ use crate::impl_common_split_reader_logic;
 use crate::parser::{ParserConfig, SpecificParserConfig};
 use crate::source::data_gen_util::spawn_data_generation_stream;
 use crate::source::datagen::source::SEQUENCE_FIELD_KIND;
-use crate::source::datagen::{DatagenProperties, DatagenSplit};
+use crate::source::datagen::{DatagenProperties, DatagenSplit, FieldDesc};
 use crate::source::{
     BoxSourceStream, BoxSourceWithStateStream, Column, DataType, SourceContextRef, SplitId,
     SplitImpl, SplitMetaData, SplitReader,
@@ -106,13 +106,18 @@ impl SplitReader for DatagenSplitReader {
         for column in columns {
             // let name = column.name.clone();
             let data_type = column.data_type.clone();
-            let gen = generator_from_data_type(
-                column.data_type,
-                &fields_option_map,
-                &column.name,
-                split_index,
-                split_num,
-            )?;
+
+            let gen = if column.is_visible {
+                FieldDesc::Visible(generator_from_data_type(
+                    column.data_type,
+                    &fields_option_map,
+                    &column.name,
+                    split_index,
+                    split_num,
+                )?)
+            } else {
+                FieldDesc::Invisible
+            };
             fields_vec.push(gen);
             data_types.push(data_type);
             field_names.push(column.name);
@@ -284,14 +289,17 @@ mod tests {
             Column {
                 name: "random_int".to_string(),
                 data_type: DataType::Int32,
+                is_visible: true,
             },
             Column {
                 name: "random_float".to_string(),
                 data_type: DataType::Float32,
+                is_visible: true,
             },
             Column {
                 name: "sequence_int".to_string(),
                 data_type: DataType::Int32,
+                is_visible: true,
             },
             Column {
                 name: "struct".to_string(),
                 data_type: DataType::Struct(Arc::new(StructType {
                     fields: vec![DataType::Int32],
                     field_names: vec!["random_int".to_string()],
                 })),
+                is_visible: true,
             },
         ];
         let state = vec![SplitImpl::Datagen(DatagenSplit {
@@ -364,10 +373,12 @@ mod tests {
             Column {
                 name: "_".to_string(),
                 data_type: DataType::Int64,
+                is_visible: true,
             },
             Column {
                 name: "random_int".to_string(),
                 data_type: DataType::Int32,
+                is_visible: true,
             },
         ];
         let state = vec![SplitImpl::Datagen(DatagenSplit {
diff --git a/src/connector/src/source/manager.rs b/src/connector/src/source/manager.rs
index 25afb4f55917c..79df615980a4e 100644
--- a/src/connector/src/source/manager.rs
+++ b/src/connector/src/source/manager.rs
@@ -48,6 +48,11 @@ impl SourceColumnDesc {
             is_meta: false,
         }
     }
+
+    #[inline]
+    pub fn is_visible(&self) -> bool {
+        !self.is_row_id && !self.is_meta
+    }
 }
 
 impl From<&ColumnDesc> for SourceColumnDesc {
@@ -75,3 +80,19 @@ impl From<&SourceColumnDesc> for ColumnDesc {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_is_visible() {
+        let mut c = SourceColumnDesc::simple("a", DataType::Int32, ColumnId::new(0));
+        assert!(c.is_visible());
+        c.is_row_id = true;
+        assert!(!c.is_visible());
+        c.is_row_id = false;
+        c.is_meta = true;
+        assert!(!c.is_visible());
+    }
+}
diff --git a/src/source/src/connector_source.rs b/src/source/src/connector_source.rs
index a131e4c9a3928..979e905b0e596 100644
--- a/src/source/src/connector_source.rs
+++ b/src/source/src/connector_source.rs
@@ -101,10 +101,10 @@ impl ConnectorSource {
         let data_gen_columns = Some(
             columns
                 .iter()
-                .cloned()
                 .map(|col| Column {
-                    name: col.name,
-                    data_type: col.data_type,
+                    name: col.name.clone(),
+                    data_type: col.data_type.clone(),
+                    is_visible: col.is_visible(),
                 })
                 .collect_vec(),
         );
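The FieldDesc enum introduced above is the whole mechanism of this patch: a column that is invisible (a row id or meta column, per SourceColumnDesc::is_visible) gets no generator at all and always yields None, while visible columns keep their FieldGeneratorImpl. A standalone sketch of the same pattern, with a toy Counter standing in for the real generator type (all names here are illustrative, not the connector's API):

    // Toy stand-in for FieldGeneratorImpl: just counts upward.
    struct Counter(u64);

    impl Counter {
        fn generate(&mut self) -> Option<u64> {
            self.0 += 1;
            Some(self.0)
        }
    }

    // Mirrors the datagen source's FieldDesc: invisible fields produce
    // None instead of consuming a generator.
    enum FieldDesc {
        Invisible,
        Visible(Counter),
    }

    fn generate_row(fields: &mut [FieldDesc]) -> Vec<Option<u64>> {
        fields
            .iter_mut()
            .map(|f| match f {
                FieldDesc::Invisible => None, // e.g. a hidden _row_id column
                FieldDesc::Visible(g) => g.generate(),
            })
            .collect()
    }

    fn main() {
        let mut fields = vec![FieldDesc::Visible(Counter(0)), FieldDesc::Invisible];
        assert_eq!(generate_row(&mut fields), vec![Some(1), None]);
        assert_eq!(generate_row(&mut fields), vec![Some(2), None]);
    }

Keeping a FieldDesc per column (rather than filtering invisible columns out up front) preserves the positional alignment between generated rows and the full column list, which is why the row-format path pushes an explicit None for invisible fields.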
From 9b89bb0946874ccb5553c9d82ec3e4ac5fb654da Mon Sep 17 00:00:00 2001
From: stonepage <40830455+st1page@users.noreply.github.com>
Date: Thu, 16 Mar 2023 17:47:55 +0800
Subject: [PATCH 4/4] fix(optimizer): projectSet && overAgg should call
 input's predicate push down && prune col (#8588)

---
 .../tests/testdata/over_window_function.yaml  |  2 +-
 .../tests/testdata/predicate_pushdown.yaml    |  2 +-
 .../optimizer/plan_node/logical_over_agg.rs   |  9 ++++++--
 .../plan_node/logical_project_set.rs          | 22 ++++++++++++-------
 4 files changed, 23 insertions(+), 12 deletions(-)

diff --git a/src/frontend/planner_test/tests/testdata/over_window_function.yaml b/src/frontend/planner_test/tests/testdata/over_window_function.yaml
index f830ff4ad45bc..076c80d3eaccb 100644
--- a/src/frontend/planner_test/tests/testdata/over_window_function.yaml
+++ b/src/frontend/planner_test/tests/testdata/over_window_function.yaml
@@ -294,7 +294,7 @@
       └─StreamHashAgg { group_key: [$expr1, $expr2, bid.supplier_id], aggs: [sum(bid.price), count] }
         └─StreamExchange { dist: HashShard($expr1, $expr2, bid.supplier_id) }
           └─StreamProject { exprs: [TumbleStart(bid.bidtime, '00:10:00':Interval) as $expr1, (TumbleStart(bid.bidtime, '00:10:00':Interval) + '00:10:00':Interval) as $expr2, bid.supplier_id, bid.price, bid._row_id] }
-            └─StreamTableScan { table: bid, columns: [bid.bidtime, bid.price, bid.item, bid.supplier_id, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
+            └─StreamTableScan { table: bid, columns: [bid.bidtime, bid.price, bid.supplier_id, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
 - before:
   - create_bid
   sql: |
diff --git a/src/frontend/planner_test/tests/testdata/predicate_pushdown.yaml b/src/frontend/planner_test/tests/testdata/predicate_pushdown.yaml
index 2c7c14c150eb5..585e555870eba 100644
--- a/src/frontend/planner_test/tests/testdata/predicate_pushdown.yaml
+++ b/src/frontend/planner_test/tests/testdata/predicate_pushdown.yaml
@@ -129,7 +129,7 @@
       └─LogicalAgg { group_key: [t.v1, t.v2, t.v3], aggs: [count, count(1:Int32)] }
         └─LogicalProject { exprs: [t.v1, t.v2, t.v3, 1:Int32] }
           └─LogicalScan { table: t, columns: [t.v1, t.v2, t.v3], predicate: (t.v1 = 10:Int32) AND (t.v2 = 20:Int32) AND (t.v3 = 30:Int32) AND (t.v2 > t.v3) }
-- name: filter project set transpose
+- name: filter project set transpose TODO(https://github.com/risingwavelabs/risingwave/issues/8591)
   sql: |
     create table t(v1 int, v2 int, v3 int, arr int[]);
     with cte as (select v1, v2, v3, unnest(arr) as arr_unnested from t) select * from cte where v1=10 AND v2=20 AND v3=30 AND arr_unnested=30;
diff --git a/src/frontend/src/optimizer/plan_node/logical_over_agg.rs b/src/frontend/src/optimizer/plan_node/logical_over_agg.rs
index 043e459eb1ea8..fc4dfc99191c6 100644
--- a/src/frontend/src/optimizer/plan_node/logical_over_agg.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_over_agg.rs
@@ -252,9 +252,14 @@ impl fmt::Display for LogicalOverAgg {
 }
 
 impl ColPrunable for LogicalOverAgg {
-    fn prune_col(&self, required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
+    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
         let mapping = ColIndexMapping::with_remaining_columns(required_cols, self.schema().len());
-        LogicalProject::with_mapping(self.clone().into(), mapping).into()
+        let new_input = {
+            let input = self.input();
+            let required = (0..input.schema().len()).collect_vec();
+            input.prune_col(&required, ctx)
+        };
+        LogicalProject::with_mapping(self.clone_with_input(new_input).into(), mapping).into()
     }
 }
 
diff --git a/src/frontend/src/optimizer/plan_node/logical_project_set.rs b/src/frontend/src/optimizer/plan_node/logical_project_set.rs
index 44a86c36759f5..53c7cfa27e911 100644
--- a/src/frontend/src/optimizer/plan_node/logical_project_set.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_project_set.rs
@@ -14,11 +14,12 @@
 
 use std::fmt;
 
+use itertools::Itertools;
 use risingwave_common::error::Result;
 
 use super::{
-    generic, BatchProjectSet, ColPrunable, ExprRewritable, LogicalFilter, LogicalProject, PlanBase,
-    PlanRef, PlanTreeNodeUnary, PredicatePushdown, StreamProjectSet, ToBatch, ToStream,
+    gen_filter_and_pushdown, generic, BatchProjectSet, ColPrunable, ExprRewritable, LogicalProject,
+    PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown, StreamProjectSet, ToBatch, ToStream,
 };
 use crate::expr::{Expr, ExprImpl, ExprRewriter, FunctionCall, InputRef, TableFunction};
 use crate::optimizer::plan_node::{
@@ -237,10 +238,15 @@ impl fmt::Display for LogicalProjectSet {
 }
 
 impl ColPrunable for LogicalProjectSet {
-    fn prune_col(&self, required_cols: &[usize], _ctx: &mut ColumnPruningContext) -> PlanRef {
-        // TODO: column pruning for ProjectSet
+    fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef {
+        // TODO: column pruning for ProjectSet https://github.com/risingwavelabs/risingwave/issues/8593
         let mapping = ColIndexMapping::with_remaining_columns(required_cols, self.schema().len());
-        LogicalProject::with_mapping(self.clone().into(), mapping).into()
+        let new_input = {
+            let input = self.input();
+            let required = (0..input.schema().len()).collect_vec();
+            input.prune_col(&required, ctx)
+        };
+        LogicalProject::with_mapping(self.clone_with_input(new_input).into(), mapping).into()
     }
 }
 
@@ -264,10 +270,10 @@ impl PredicatePushdown for LogicalProjectSet {
     fn predicate_pushdown(
         &self,
         predicate: Condition,
-        _ctx: &mut PredicatePushdownContext,
+        ctx: &mut PredicatePushdownContext,
     ) -> PlanRef {
-        // TODO: predicate pushdown for ProjectSet
-        LogicalFilter::create(self.clone().into(), predicate)
+        // TODO: predicate pushdown https://github.com/risingwavelabs/risingwave/issues/8591
+        gen_filter_and_pushdown(self, predicate, Condition::true_cond(), ctx)
    }
 }
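Both fixes in this last patch have the same shape: prune_col and predicate_pushdown previously stopped at ProjectSet/OverAgg, wrapping the node in a projection or filter without ever visiting its input, so nothing below those nodes got optimized. The patch keeps the wrapping but forwards the call — it asks the input to prune with all of the input's columns listed as required, so the actual pruning happens in operators further down the tree (projections, scans). A toy illustration of that recursion pattern; the types are illustrative, not the optimizer's real ones:

    // A ProjectSet-like node over a scan. The node cannot prune across its
    // own boundary yet, but it must still forward the call so operators
    // below it can drop unused columns.
    #[derive(Debug)]
    enum Plan {
        Scan { columns: Vec<String>, used: Vec<String> },
        ProjectSet { input: Box<Plan> },
    }

    fn prune_col(plan: Plan) -> Plan {
        match plan {
            // A scan keeps only the columns something above it uses.
            Plan::Scan { columns, used } => Plan::Scan {
                columns: columns.into_iter().filter(|c| used.contains(c)).collect(),
                used,
            },
            // Before the fix this arm returned the node with its input
            // untouched; the fix recurses into the input.
            Plan::ProjectSet { input } => Plan::ProjectSet {
                input: Box::new(prune_col(*input)),
            },
        }
    }

    fn main() {
        let plan = Plan::ProjectSet {
            input: Box::new(Plan::Scan {
                columns: vec!["bidtime".into(), "price".into(), "item".into()],
                used: vec!["bidtime".into(), "price".into()],
            }),
        };
        // `item` is pruned below the ProjectSet boundary, matching the
        // planner-test change above where bid.item disappears from the scan.
        println!("{:?}", prune_col(plan));
    }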