diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 1037d66e8ecfb..f6603ca17182d 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -2,6 +2,7 @@ syntax = "proto3"; package stream_plan; +import "catalog.proto"; import "common.proto"; import "data.proto"; import "expr.proto"; @@ -56,19 +57,21 @@ message MaterializeNode { message SimpleAggNode { repeated expr.AggCall agg_calls = 1; repeated uint32 distribution_keys = 2; - repeated uint32 table_ids = 3; + repeated catalog.Table internal_tables = 3; + map column_mapping = 4; // Whether to optimize for append only stream. // It is true when the input is append-only - bool is_append_only = 4; + bool is_append_only = 5; } message HashAggNode { repeated uint32 distribution_keys = 1; repeated expr.AggCall agg_calls = 2; - repeated uint32 table_ids = 3; + repeated catalog.Table internal_tables = 3; + map column_mapping = 4; // Whether to optimize for append only stream. // It is true when the input is append-only - bool is_append_only = 4; + bool is_append_only = 5; } message TopNNode { @@ -90,10 +93,10 @@ message HashJoinNode { // on-the-fly within the plan. // TODO: remove this in the future when we have a separate DeltaHashJoin node. bool is_delta_join = 5; - // Used for internal table states. Id of the left table. - uint32 left_table_id = 6; - // Used for internal table states. Id of the right table. - uint32 right_table_id = 7; + // Used for internal table states. + catalog.Table left_table = 6; + // Used for internal table states. + catalog.Table right_table = 7; repeated uint32 dist_key_l = 8; repeated uint32 dist_key_r = 9; // It is true when the input is append-only diff --git a/src/common/src/catalog/mod.rs b/src/common/src/catalog/mod.rs index b0dc7463f65d0..f00a1350bee24 100644 --- a/src/common/src/catalog/mod.rs +++ b/src/common/src/catalog/mod.rs @@ -45,6 +45,10 @@ impl DatabaseId { pub fn new(database_id: i32) -> Self { DatabaseId { database_id } } + + pub fn placeholder() -> i32 { + i32::MAX - 1 + } } #[derive(Clone, Debug, Default, Hash, PartialOrd, PartialEq, Eq)] @@ -60,6 +64,10 @@ impl SchemaId { schema_id, } } + + pub fn placeholder() -> i32 { + i32::MAX - 1 + } } #[derive(Clone, Copy, Debug, Default, Hash, PartialOrd, PartialEq, Eq)] diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index ad084c1761c8d..4a6d51cd5096f 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -17,9 +17,10 @@ use std::fmt; use fixedbitset::FixedBitSet; use itertools::Itertools; -use risingwave_common::catalog::{Field, Schema}; +use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, OrderedColumnDesc, Schema, TableId}; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::DataType; +use risingwave_common::util::sort_util::OrderType; use risingwave_expr::expr::AggKind; use risingwave_pb::expr::AggCall as ProstAggCall; @@ -27,6 +28,8 @@ use super::{ BatchHashAgg, BatchSimpleAgg, ColPrunable, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown, StreamHashAgg, StreamSimpleAgg, ToBatch, ToStream, }; +use crate::catalog::column_catalog::ColumnCatalog; +use crate::catalog::table_catalog::TableCatalog; use crate::expr::{AggCall, Expr, ExprImpl, ExprRewriter, ExprType, FunctionCall, InputRef}; use crate::optimizer::plan_node::{gen_filter_and_pushdown, LogicalProject}; use crate::optimizer::property::RequiredDist; @@ -120,6 +123,78 @@ pub struct LogicalAgg { input: PlanRef, } +impl LogicalAgg { + pub fn infer_internal_table_catalog(&self) -> (Vec, HashMap) { + let mut table_catalogs = vec![]; + let mut column_mapping = HashMap::new(); + let base = self.input.plan_base(); + let schema = &base.schema; + let fields = schema.fields(); + for agg_call in &self.agg_calls { + let mut internal_pk_indices = vec![]; + let mut columns = vec![]; + let mut order_desc = vec![]; + for &idx in &self.group_keys { + let column_id = columns.len() as i32; + internal_pk_indices.push(column_id as usize); // Currently our column index is same as column id + column_mapping.insert(idx, column_id); + let column_desc = ColumnDesc::from_field_with_column_id(&fields[idx], column_id); + columns.push(ColumnCatalog { + column_desc: column_desc.clone(), + is_hidden: false, + }); + order_desc.push(OrderedColumnDesc { + column_desc, + order: OrderType::Ascending, + }) + } + match agg_call.agg_kind { + AggKind::Min | AggKind::Max | AggKind::StringAgg => { + for input in &agg_call.inputs { + let column_id = columns.len() as i32; + column_mapping.insert(input.index, column_id); + columns.push(ColumnCatalog { + column_desc: ColumnDesc::from_field_with_column_id( + &fields[input.index], + column_id, + ), + is_hidden: false, + }); + } + } + AggKind::Sum + | AggKind::Count + | AggKind::RowCount + | AggKind::Avg + | AggKind::SingleValue + | AggKind::ApproxCountDistinct => { + columns.push(ColumnCatalog { + column_desc: ColumnDesc::unnamed( + ColumnId::new(columns.len() as i32), + agg_call.return_type.clone(), + ), + is_hidden: false, + }); + } + } + table_catalogs.push(TableCatalog { + id: TableId::placeholder(), + associated_source_id: None, + name: String::new(), + columns, + order_desc, + pks: internal_pk_indices, + is_index_on: None, + distribution_keys: base.dist.dist_column_indices().to_vec(), + appendonly: false, + owner: risingwave_common::catalog::DEFAULT_SUPPER_USER.to_string(), + vnode_mapping: None, + }); + } + (table_catalogs, column_mapping) + } +} + /// `ExprHandler` extracts agg calls and references to group columns from select list, in /// preparation for generating a plan like `LogicalProject - LogicalAgg - LogicalProject`. struct ExprHandler { diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs index 2c5cde457c193..0680078308593 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_agg.rs @@ -15,6 +15,7 @@ use std::fmt; use itertools::Itertools; +use risingwave_common::catalog::{DatabaseId, SchemaId}; use risingwave_pb::stream_plan::stream_node::NodeBody as ProstStreamNode; use super::logical_agg::PlanAggCall; @@ -91,7 +92,7 @@ impl_plan_tree_node_for_unary! { StreamHashAgg } impl ToStreamProst for StreamHashAgg { fn to_stream_prost_body(&self) -> ProstStreamNode { use risingwave_pb::stream_plan::*; - + let (internal_tables, column_mapping) = self.logical.infer_internal_table_catalog(); ProstStreamNode::HashAgg(HashAggNode { distribution_keys: self .distribution_keys() @@ -103,7 +104,19 @@ impl ToStreamProst for StreamHashAgg { .iter() .map(PlanAggCall::to_protobuf) .collect_vec(), - table_ids: vec![], + internal_tables: internal_tables + .into_iter() + .map(|table_catalog| { + table_catalog.to_prost( + SchemaId::placeholder() as u32, + DatabaseId::placeholder() as u32, + ) + }) + .collect_vec(), + column_mapping: column_mapping + .into_iter() + .map(|(k, v)| (k as u32, v)) + .collect(), is_append_only: self.input().append_only(), }) } diff --git a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs index 92e8ae83a154b..d0a93ae3d5c6d 100644 --- a/src/frontend/src/optimizer/plan_node/stream_hash_join.rs +++ b/src/frontend/src/optimizer/plan_node/stream_hash_join.rs @@ -15,12 +15,16 @@ use std::fmt; use itertools::Itertools; +use risingwave_common::catalog::{ColumnDesc, DatabaseId, OrderedColumnDesc, SchemaId, TableId}; use risingwave_common::session_config::DELTA_JOIN; +use risingwave_common::util::sort_util::OrderType; use risingwave_pb::plan_common::JoinType; use risingwave_pb::stream_plan::stream_node::NodeBody; use risingwave_pb::stream_plan::HashJoinNode; use super::{LogicalJoin, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, ToStreamProst}; +use crate::catalog::column_catalog::ColumnCatalog; +use crate::catalog::table_catalog::TableCatalog; use crate::expr::Expr; use crate::optimizer::plan_node::EqJoinPredicate; use crate::optimizer::property::Distribution; @@ -214,6 +218,14 @@ impl ToStreamProst for StreamHashJoin { .map(|idx| *idx as u32) .collect_vec(), is_delta_join: self.is_delta, + left_table: Some(infer_internal_table_catalog(self.left()).to_prost( + SchemaId::placeholder() as u32, + DatabaseId::placeholder() as u32, + )), + right_table: Some(infer_internal_table_catalog(self.right()).to_prost( + SchemaId::placeholder() as u32, + DatabaseId::placeholder() as u32, + )), output_indices: self .logical .output_indices() @@ -221,7 +233,40 @@ impl ToStreamProst for StreamHashJoin { .map(|&x| x as u32) .collect(), is_append_only: self.is_append_only, - ..Default::default() }) } } + +fn infer_internal_table_catalog(input: PlanRef) -> TableCatalog { + let base = input.plan_base(); + let schema = &base.schema; + let pk_indices = &base.pk_indices; + let columns = schema + .fields() + .iter() + .map(|field| ColumnCatalog { + column_desc: ColumnDesc::from_field_without_column_id(field), + is_hidden: false, + }) + .collect_vec(); + let mut order_desc = vec![]; + for &idx in pk_indices { + order_desc.push(OrderedColumnDesc { + column_desc: columns[idx].column_desc.clone(), + order: OrderType::Ascending, + }); + } + TableCatalog { + id: TableId::placeholder(), + associated_source_id: None, + name: String::new(), + columns, + order_desc, + pks: pk_indices.clone(), + distribution_keys: base.dist.dist_column_indices().to_vec(), + is_index_on: None, + appendonly: input.append_only(), + owner: risingwave_common::catalog::DEFAULT_SUPPER_USER.to_string(), + vnode_mapping: None, + } +} diff --git a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs index 7de45bf0f808d..a09bef5f4c928 100644 --- a/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs +++ b/src/frontend/src/optimizer/plan_node/stream_simple_agg.rs @@ -15,6 +15,7 @@ use std::fmt; use itertools::Itertools; +use risingwave_common::catalog::{DatabaseId, SchemaId}; use risingwave_pb::stream_plan::stream_node::NodeBody as ProstStreamNode; use super::logical_agg::PlanAggCall; @@ -74,7 +75,7 @@ impl_plan_tree_node_for_unary! { StreamSimpleAgg } impl ToStreamProst for StreamSimpleAgg { fn to_stream_prost_body(&self) -> ProstStreamNode { use risingwave_pb::stream_plan::*; - + let (internal_tables, column_mapping) = self.logical.infer_internal_table_catalog(); // TODO: local or global simple agg? ProstStreamNode::GlobalSimpleAgg(SimpleAggNode { agg_calls: self @@ -89,7 +90,19 @@ impl ToStreamProst for StreamSimpleAgg { .iter() .map(|idx| *idx as u32) .collect_vec(), - table_ids: vec![], + internal_tables: internal_tables + .into_iter() + .map(|table_catalog| { + table_catalog.to_prost( + SchemaId::placeholder() as u32, + DatabaseId::placeholder() as u32, + ) + }) + .collect_vec(), + column_mapping: column_mapping + .into_iter() + .map(|(k, v)| (k as u32, v)) + .collect(), is_append_only: self.input().append_only(), }) } diff --git a/src/frontend/src/stream_fragmenter/mod.rs b/src/frontend/src/stream_fragmenter/mod.rs index 4ebad1650a7ec..9f1ecd08cde17 100644 --- a/src/frontend/src/stream_fragmenter/mod.rs +++ b/src/frontend/src/stream_fragmenter/mod.rs @@ -322,20 +322,24 @@ impl StreamFragmenter { NodeBody::HashJoin(hash_join_node) => { // Allocate local table id. It will be rewrite to global table id after get table id // offset from id generator. - hash_join_node.left_table_id = state.gen_table_id(); - hash_join_node.right_table_id = state.gen_table_id(); + if let Some(left_table) = &mut hash_join_node.left_table { + left_table.id = state.gen_table_id(); + } + if let Some(right_table) = &mut hash_join_node.right_table { + right_table.id = state.gen_table_id(); + } } NodeBody::GlobalSimpleAgg(node) | NodeBody::LocalSimpleAgg(node) => { - for _ in &node.agg_calls { - node.table_ids.push(state.gen_table_id()); + for table in &mut node.internal_tables { + table.id = state.gen_table_id(); } } // Rewrite hash agg. One agg call -> one table id. NodeBody::HashAgg(hash_agg_node) => { - for _ in &hash_agg_node.agg_calls { - hash_agg_node.table_ids.push(state.gen_table_id()); + for table in &mut hash_agg_node.internal_tables { + table.id = state.gen_table_id(); } } @@ -354,10 +358,12 @@ impl StreamFragmenter { #[cfg(test)] mod tests { + use risingwave_pb::catalog::{Table, Table as ProstTable}; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType; use risingwave_pb::expr::agg_call::{Arg, Type}; use risingwave_pb::expr::{AggCall, InputRefExpr}; + use risingwave_pb::plan_common::{ColumnCatalog, ColumnDesc}; use risingwave_pb::stream_plan::*; use super::*; @@ -380,6 +386,36 @@ mod tests { } } + fn make_column(column_type: TypeName, column_id: i32) -> ColumnCatalog { + ColumnCatalog { + column_desc: Some(ColumnDesc { + column_type: Some(DataType { + type_name: column_type as i32, + ..Default::default() + }), + column_id, + ..Default::default() + }), + is_hidden: false, + } + } + + fn make_internal_table(is_agg_value: bool) -> ProstTable { + let mut columns = vec![make_column(TypeName::Int64, 0)]; + if !is_agg_value { + columns.push(make_column(TypeName::Int32, 1)); + } + ProstTable { + id: TableId::placeholder().table_id, + name: String::new(), + columns, + order_column_ids: vec![0], + orders: vec![2], + pk: vec![2], + ..Default::default() + } + } + #[test] fn test_assign_local_table_id_to_stream_node() { // let fragmenter = StreamFragmenter {}; @@ -391,6 +427,14 @@ mod tests { // test HashJoin Type let mut stream_node = StreamNode { node_body: Some(NodeBody::HashJoin(HashJoinNode { + left_table: Some(Table { + id: 0, + ..Default::default() + }), + right_table: Some(Table { + id: 0, + ..Default::default() + }), ..Default::default() })), ..Default::default() @@ -399,9 +443,15 @@ mod tests { if let NodeBody::HashJoin(hash_join_node) = stream_node.node_body.as_ref().unwrap() { expect_table_id += 1; - assert_eq!(expect_table_id, hash_join_node.left_table_id); + assert_eq!( + expect_table_id, + hash_join_node.left_table.as_ref().unwrap().id + ); expect_table_id += 1; - assert_eq!(expect_table_id, hash_join_node.right_table_id); + assert_eq!( + expect_table_id, + hash_join_node.right_table.as_ref().unwrap().id + ); } } @@ -414,6 +464,11 @@ mod tests { make_sum_aggcall(1), make_sum_aggcall(2), ], + internal_tables: vec![ + make_internal_table(true), + make_internal_table(false), + make_internal_table(false), + ], ..Default::default() })), ..Default::default() @@ -425,12 +480,11 @@ mod tests { { assert_eq!( global_simple_agg_node.agg_calls.len(), - global_simple_agg_node.table_ids.len() + global_simple_agg_node.internal_tables.len() ); - - for table_id in &global_simple_agg_node.table_ids { + for table in &global_simple_agg_node.internal_tables { expect_table_id += 1; - assert_eq!(expect_table_id, *table_id); + assert_eq!(expect_table_id, table.id); } } } @@ -445,6 +499,12 @@ mod tests { make_sum_aggcall(2), make_sum_aggcall(3), ], + internal_tables: vec![ + make_internal_table(true), + make_internal_table(false), + make_internal_table(false), + make_internal_table(false), + ], ..Default::default() })), ..Default::default() @@ -452,11 +512,13 @@ mod tests { StreamFragmenter::assign_local_table_id_to_stream_node(&mut state, &mut stream_node); if let NodeBody::HashAgg(hash_agg_node) = stream_node.node_body.as_ref().unwrap() { - assert_eq!(hash_agg_node.agg_calls.len(), hash_agg_node.table_ids.len()); - - for table_id in &hash_agg_node.table_ids { + assert_eq!( + hash_agg_node.agg_calls.len(), + hash_agg_node.internal_tables.len() + ); + for table in &hash_agg_node.internal_tables { expect_table_id += 1; - assert_eq!(expect_table_id, *table_id); + assert_eq!(expect_table_id, table.id); } } } diff --git a/src/frontend/src/stream_fragmenter/rewrite/delta_join.rs b/src/frontend/src/stream_fragmenter/rewrite/delta_join.rs index 6878e9920fdbb..a83320e4537a7 100644 --- a/src/frontend/src/stream_fragmenter/rewrite/delta_join.rs +++ b/src/frontend/src/stream_fragmenter/rewrite/delta_join.rs @@ -453,13 +453,13 @@ impl StreamFragmenter { state, &exchange_i0a0, hash_join_node.left_key.clone(), - hash_join_node.left_table_id, + hash_join_node.left_table.as_ref().unwrap().id, ); let (arrange_1_info, arrange_1) = self.build_arrange_for_delta_join( state, &exchange_i1a1, hash_join_node.right_key.clone(), - hash_join_node.right_table_id, + hash_join_node.right_table.as_ref().unwrap().id, ); let arrange_0_frag = self.build_and_add_fragment(state, arrange_0)?; @@ -482,8 +482,8 @@ impl StreamFragmenter { join_type: hash_join_node.join_type, left_key: hash_join_node.left_key.clone(), right_key: hash_join_node.right_key.clone(), - left_table_id: hash_join_node.left_table_id, - right_table_id: hash_join_node.right_table_id, + left_table_id: hash_join_node.left_table.as_ref().unwrap().id, + right_table_id: hash_join_node.right_table.as_ref().unwrap().id, condition: hash_join_node.condition.clone(), left_info: Some(arrange_0_info), right_info: Some(arrange_1_info), diff --git a/src/frontend/test_runner/tests/testdata/stream_proto.yaml b/src/frontend/test_runner/tests/testdata/stream_proto.yaml index e2862d8c05c2e..7150d77c1dca3 100644 --- a/src/frontend/test_runner/tests/testdata/stream_proto.yaml +++ b/src/frontend/test_runner/tests/testdata/stream_proto.yaml @@ -461,6 +461,25 @@ returnType: typeName: INT64 isNullable: true + internalTables: + - id: 4294967294 + schemaId: 2147483646 + databaseId: 2147483646 + columns: + - columnDesc: + columnType: + typeName: INT64 + isNullable: true + owner: root + - id: 4294967294 + schemaId: 2147483646 + databaseId: 2147483646 + columns: + - columnDesc: + columnType: + typeName: INT64 + isNullable: true + owner: root fields: - dataType: typeName: INT64 @@ -656,6 +675,55 @@ returnType: typeName: INT64 isNullable: true + internalTables: + - id: 4294967294 + schemaId: 2147483646 + databaseId: 2147483646 + columns: + - columnDesc: + columnType: + typeName: INT32 + isNullable: true + name: v2 + - columnDesc: + columnType: + typeName: INT64 + isNullable: true + columnId: 1 + orderColumnIds: + - 0 + orders: + - ASCENDING + distributionKeys: + - 0 + pk: + - 0 + owner: root + - id: 4294967294 + schemaId: 2147483646 + databaseId: 2147483646 + columns: + - columnDesc: + columnType: + typeName: INT32 + isNullable: true + name: v2 + - columnDesc: + columnType: + typeName: INT64 + isNullable: true + columnId: 1 + orderColumnIds: + - 0 + orders: + - ASCENDING + distributionKeys: + - 0 + pk: + - 0 + owner: root + columnMapping: + 0: 0 pkIndices: - 1 fields: diff --git a/src/meta/src/stream/mod.rs b/src/meta/src/stream/mod.rs index efbf2043c1b78..a7f0d212bd5ff 100644 --- a/src/meta/src/stream/mod.rs +++ b/src/meta/src/stream/mod.rs @@ -51,14 +51,15 @@ pub fn record_table_vnode_mappings( hash_mapping_manager.set_fragment_state_table(fragment_id, node.table_id); } NodeBody::HashAgg(node) => { - let table_ids = node.get_table_ids(); - for table_id in table_ids { - hash_mapping_manager.set_fragment_state_table(fragment_id, *table_id); + for table in &node.internal_tables { + hash_mapping_manager.set_fragment_state_table(fragment_id, table.id); } } NodeBody::HashJoin(node) => { - hash_mapping_manager.set_fragment_state_table(fragment_id, node.left_table_id); - hash_mapping_manager.set_fragment_state_table(fragment_id, node.right_table_id); + hash_mapping_manager + .set_fragment_state_table(fragment_id, node.left_table.as_ref().unwrap().id); + hash_mapping_manager + .set_fragment_state_table(fragment_id, node.right_table.as_ref().unwrap().id); } _ => {} } diff --git a/src/meta/src/stream/stream_graph.rs b/src/meta/src/stream/stream_graph.rs index c2923d7466a59..2fdac72f73a3e 100644 --- a/src/meta/src/stream/stream_graph.rs +++ b/src/meta/src/stream/stream_graph.rs @@ -545,10 +545,16 @@ impl StreamGraphBuilder { NodeBody::HashJoin(node) => { // The operator id must be assigned with table ids. Otherwise it is a logic // error. - let left_table_id = node.left_table_id + table_id_offset; - let right_table_id = left_table_id + 1; - node.left_table_id = left_table_id; - node.right_table_id = right_table_id; + let mut left_table_id: u32 = 0; + let mut right_table_id: u32 = 0; + if let Some(table) = &mut node.left_table { + left_table_id = table.id + table_id_offset; + table.id = left_table_id; + } + if let Some(table) = &mut node.right_table { + right_table_id = left_table_id + 1; + table.id = right_table_id; + } ctx.internal_table_id_set.insert(left_table_id); ctx.internal_table_id_set.insert(right_table_id); } @@ -568,11 +574,11 @@ impl StreamGraphBuilder { } NodeBody::HashAgg(node) => { - assert_eq!(node.table_ids.len(), node.agg_calls.len()); + assert_eq!(node.internal_tables.len(), node.agg_calls.len()); // In-place update the table id. Convert from local to global. - for table_id in &mut node.table_ids { - *table_id += table_id_offset; - ctx.internal_table_id_set.insert(*table_id); + for table in &mut node.internal_tables { + table.id += table_id_offset; + ctx.internal_table_id_set.insert(table.id); } } @@ -582,11 +588,11 @@ impl StreamGraphBuilder { } NodeBody::GlobalSimpleAgg(node) | NodeBody::LocalSimpleAgg(node) => { - assert_eq!(node.table_ids.len(), node.agg_calls.len()); + assert_eq!(node.internal_tables.len(), node.agg_calls.len()); // In-place update the table id. Convert from local to global. - for table_id in &mut node.table_ids { - *table_id += table_id_offset; - ctx.internal_table_id_set.insert(*table_id); + for table in &mut node.internal_tables { + table.id += table_id_offset; + ctx.internal_table_id_set.insert(table.id); } } _ => {} diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index bf1ec06bf0ffd..0d41be8b5fe4e 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -14,9 +14,11 @@ use std::collections::{HashMap, HashSet}; use std::sync::Arc; +use std::vec; -use risingwave_common::catalog::TableId; +use risingwave_common::catalog::{DatabaseId, SchemaId, TableId}; use risingwave_common::error::Result; +use risingwave_pb::catalog::Table as ProstTable; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType; use risingwave_pb::expr::agg_call::{Arg, Type}; @@ -24,7 +26,8 @@ use risingwave_pb::expr::expr_node::RexNode; use risingwave_pb::expr::expr_node::Type::{Add, GreaterThan, InputRef}; use risingwave_pb::expr::{AggCall, ExprNode, FunctionCall, InputRefExpr}; use risingwave_pb::plan_common::{ - ColumnOrder, DatabaseRefId, Field, OrderType, SchemaRefId, TableRefId, + ColumnCatalog, ColumnDesc, ColumnOrder, DatabaseRefId, Field, OrderType, SchemaRefId, + TableRefId, }; use risingwave_pb::stream_plan::source_node::SourceType; use risingwave_pb::stream_plan::stream_node::NodeBody; @@ -99,6 +102,38 @@ fn make_column_order(idx: i32) -> ColumnOrder { } } +fn make_column(column_type: TypeName, column_id: i32) -> ColumnCatalog { + ColumnCatalog { + column_desc: Some(ColumnDesc { + column_type: Some(DataType { + type_name: column_type as i32, + ..Default::default() + }), + column_id, + ..Default::default() + }), + is_hidden: false, + } +} + +fn make_internal_table(is_agg_value: bool) -> ProstTable { + let mut columns = vec![make_column(TypeName::Int64, 0)]; + if !is_agg_value { + columns.push(make_column(TypeName::Int32, 1)); + } + ProstTable { + id: TableId::placeholder().table_id, + schema_id: SchemaId::placeholder() as u32, + database_id: DatabaseId::placeholder() as u32, + name: String::new(), + columns, + order_column_ids: vec![0], + orders: vec![2], + pk: vec![2], + ..Default::default() + } +} + /// [`make_stream_node`] build a plan represent in `StreamNode` for SQL as follow: /// ```sql /// create table t (v1 int, v2 int); @@ -165,7 +200,8 @@ fn make_stream_node() -> StreamNode { node_body: Some(NodeBody::GlobalSimpleAgg(SimpleAggNode { agg_calls: vec![make_sum_aggcall(0), make_sum_aggcall(1)], distribution_keys: Default::default(), - table_ids: vec![], + internal_tables: vec![make_internal_table(true), make_internal_table(false)], + column_mapping: HashMap::new(), is_append_only: false, })), input: vec![filter_node], @@ -197,7 +233,8 @@ fn make_stream_node() -> StreamNode { node_body: Some(NodeBody::GlobalSimpleAgg(SimpleAggNode { agg_calls: vec![make_sum_aggcall(0), make_sum_aggcall(1)], distribution_keys: Default::default(), - table_ids: vec![], + internal_tables: vec![make_internal_table(true), make_internal_table(false)], + column_mapping: HashMap::new(), is_append_only: false, })), fields: vec![], // TODO: fill this later diff --git a/src/stream/src/from_proto/global_simple_agg.rs b/src/stream/src/from_proto/global_simple_agg.rs index 03174b0a9bf12..17d7978649454 100644 --- a/src/stream/src/from_proto/global_simple_agg.rs +++ b/src/stream/src/from_proto/global_simple_agg.rs @@ -38,9 +38,9 @@ impl ExecutorBuilder for SimpleAggExecutorBuilder { // Build vector of keyspace via table ids. // One keyspace for one agg call. let keyspace = node - .get_table_ids() + .internal_tables .iter() - .map(|table_id| Keyspace::table_root(store.clone(), &TableId::new(*table_id))) + .map(|table| Keyspace::table_root(store.clone(), &TableId::new(table.id))) .collect(); let key_indices = node .get_distribution_keys() diff --git a/src/stream/src/from_proto/hash_agg.rs b/src/stream/src/from_proto/hash_agg.rs index 944d70ac1b76f..48f2ae7f7d6b6 100644 --- a/src/stream/src/from_proto/hash_agg.rs +++ b/src/stream/src/from_proto/hash_agg.rs @@ -74,9 +74,9 @@ impl ExecutorBuilder for HashAggExecutorBuilder { // Build vector of keyspace via table ids. // One keyspace for one agg call. let keyspace = node - .get_table_ids() + .internal_tables .iter() - .map(|&table_id| Keyspace::table_root(store.clone(), &TableId::new(table_id))) + .map(|table| Keyspace::table_root(store.clone(), &TableId::new(table.id))) .collect(); let input = params.input.remove(0); let keys = key_indices diff --git a/src/stream/src/from_proto/hash_join.rs b/src/stream/src/from_proto/hash_join.rs index 9cb873080a8c4..e769e99c3e6e0 100644 --- a/src/stream/src/from_proto/hash_join.rs +++ b/src/stream/src/from_proto/hash_join.rs @@ -108,8 +108,8 @@ impl ExecutorBuilder for HashJoinExecutorBuilder { .collect_vec(); let kind = calc_hash_key_kind(&keys); - let left_table_id = TableId::from(node.left_table_id); - let right_table_id = TableId::from(node.right_table_id); + let left_table_id = TableId::from(node.left_table.as_ref().unwrap().id); + let right_table_id = TableId::from(node.right_table.as_ref().unwrap().id); let args = HashJoinExecutorDispatcherArgs { source_l,