diff --git a/Cargo.lock b/Cargo.lock index d280ebde19d90..f49962e0d235c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3898,9 +3898,9 @@ dependencies = [ [[package]] name = "pg_interval" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47354dbd658c57a5ee1cc97a79937345170234d4c817768de80ea6d2e9f5b98a" +checksum = "fe46640b465e284b048ef065cbed8ef17a622878d310c724578396b4cfd00df2" dependencies = [ "bytes", "chrono", @@ -5239,6 +5239,9 @@ dependencies = [ "risingwave_pb", "risingwave_rpc_client", "risingwave_storage", + "risingwave_test_runner", + "serial_test", + "sync-point", "workspace-hack", ] @@ -5247,6 +5250,7 @@ name = "risingwave_meta" version = "0.2.0-alpha" dependencies = [ "anyhow", + "arc-swap", "assert_matches", "async-stream", "async-trait", @@ -5411,6 +5415,7 @@ name = "risingwave_rt" version = "0.2.0-alpha" dependencies = [ "async-trait", + "async_stack_trace", "console", "console-subscriber", "futures", @@ -5712,27 +5717,12 @@ dependencies = [ "workspace-hack", ] -[[package]] -name = "risingwave_sync_point_test" -version = "0.1.0" -dependencies = [ - "bytes", - "itertools", - "madsim-tokio", - "risingwave_cmd_all", - "risingwave_common", - "risingwave_object_store", - "risingwave_pb", - "risingwave_rpc_client", - "serial_test", - "sync-point", -] - [[package]] name = "risingwave_test_runner" version = "0.2.0-alpha" dependencies = [ "fail", + "sync-point", "workspace-hack", ] @@ -6449,29 +6439,6 @@ dependencies = [ "thiserror", ] -[[package]] -name = "sync_point_unit_test" -version = "0.2.0-alpha" -dependencies = [ - "async-trait", - "bytes", - "fail", - "futures", - "itertools", - "madsim-tokio", - "parking_lot 0.12.1", - "rand 0.8.5", - "risingwave_common", - "risingwave_common_service", - "risingwave_hummock_sdk", - "risingwave_meta", - "risingwave_pb", - "risingwave_rpc_client", - "risingwave_storage", - "serial_test", - "sync-point", -] - [[package]] name = "sync_wrapper" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 3f2906f70f15d..48462d19b8a41 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,14 +26,12 @@ members = [ "src/storage/compactor", "src/storage/hummock_sdk", "src/storage/hummock_test", - "src/storage/hummock_test/sync_point_unit_test", "src/stream", "src/test_runner", "src/tests/regress", "src/tests/simulation", "src/tests/simulation_scale", "src/tests/sqlsmith", - "src/tests/sync_point", "src/tracing", "src/utils/async_stack_trace", "src/utils/local_stats_alloc", @@ -78,6 +76,12 @@ inherits = "dev" incremental = false [profile.ci-dev.package."*"] # external dependencies opt-level = 1 +[profile.ci-dev.package."tokio"] +opt-level = 3 +[profile.ci-dev.package."async_stack_trace"] +opt-level = 3 +[profile.ci-dev.package."indextree"] +opt-level = 3 # The profile used for deterministic simulation tests in CI. # The simulator can only run single-threaded, so optimization is required to make the running time diff --git a/README.md b/README.md index 77b14b4689108..4881a8f1587fd 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,8 @@ [![Build status](https://badge.buildkite.com/9394d2bca0f87e2e97aa78b25f765c92d4207c0b65e7f6648f.svg)](https://buildkite.com/singularity-data/main) [![codecov](https://codecov.io/gh/risingwavelabs/risingwave/branch/main/graph/badge.svg?token=EB44K9K38B)](https://codecov.io/gh/risingwavelabs/risingwave) +**RisingWave is now ready for production! Curious about the use cases? 
More details coming soon!** + RisingWave is a cloud-native streaming database that uses SQL as the interface language. It is designed to reduce the complexity and cost of building real-time applications. RisingWave consumes streaming data, performs continuous queries, and updates results dynamically. As a database system, RisingWave maintains results inside its own storage and allows users to access data efficiently. RisingWave ingests data from sources like Apache Kafka, Apache Pulsar, Amazon Kinesis, Redpanda, and materialized CDC sources. diff --git a/ci/plugins/benchmark/hooks/pre-command b/ci/plugins/benchmark/hooks/pre-command deleted file mode 100755 index e0607e2115d46..0000000000000 --- a/ci/plugins/benchmark/hooks/pre-command +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash - -set -euo pipefail - -export HOST_IP=$(dig +short myip.opendns.com @resolver1.opendns.com) \ No newline at end of file diff --git a/ci/scripts/benchmark.sh b/ci/scripts/benchmark.sh deleted file mode 100755 index 2f4e11f481da7..0000000000000 --- a/ci/scripts/benchmark.sh +++ /dev/null @@ -1,135 +0,0 @@ -#!/bin/bash - -# Exits as soon as any line fails. -set -euo pipefail - -# pollingScript message try_times interval script_string -function pollingScript() { - message=$1 - try_times=$2 - interval=$3 - script_string=$4 - while :; do - echo "polling: $message" - if [ "$try_times" == 0 ]; then - echo "❌ ERROR: polling timeout" - exit 1 - fi - if eval "$script_string"; then - echo "✅ Instance Ready" - break - fi - sleep "$interval" - try_times=$((try_times - 1)) - done -} - -# pollingScript status try_times -function pollingTenantStatus() { - status=$1 - try_times=$2 - interval=10 - pollingScript "tenant status until it is $status" "$try_times" "$interval" \ - "rwc tenant get -name $TENANT_NAME | grep 'Status: $status'" -} - -function polling() { - set +e - try_times=10 - while :; do - if [ $try_times == 0 ]; then - echo "❌ ERROR: Polling Timeout" - exit 1 - fi - psql "$@" -c '\q' - if [ $? 
== 0 ]; then - echo "✅ Endpoint Available" - break - fi - sleep 5 - try_times=$((try_times - 1)) - done - set -euo pipefail -} - -function cleanup { - echo "--- Delete tenant" - rwc tenant delete -name ${TENANT_NAME} -} - -trap cleanup EXIT - -DB_USER=dbuser -DB_PWD=dbpwd - -if [[ -z "${RISINGWAVE_IMAGE_TAG+x}" ]]; then - IMAGE_TAG="latest" -else - IMAGE_TAG="${RISINGWAVE_IMAGE_TAG}" -fi - -if [ -z "${BENCH_SKU+x}" ] || [ "${BENCH_SKU}" == "MultiNodeBench" ]; then - SKU="multinode" - BENCH_SKU="MultiNodeBench" -elif [ "${BENCH_SKU}" == "SingleNodeBench" ]; then - SKU="singlenode" -else - exit 1 -fi - -date=$(date '+%Y%m%d-%H%M%S') -TENANT_NAME="${SKU}-${date}" - -echo "--- Echo Info" -echo "BENCH-SKU: ${BENCH_SKU}" -echo "Tenant-Name: ${TENANT_NAME}" -echo "Host-Ip: ${HOST_IP}" -echo "IMAGE-TAG: ${IMAGE_TAG}" - -echo "--- Download Necessary Tools" -apt-get -y install golang-go librdkafka-dev python3-pip -curl -L https://rwc-cli-internal-release.s3.ap-southeast-1.amazonaws.com/download.sh | bash && mv rwc /usr/local/bin - -echo "--- RWC Config and Login" -rwc config context -accounturl https://rls-apse1-acc.risingwave-cloud.xyz/api/v1 -rwc config -region ap-southeast-1 -rwc config ls -rwc login -account benchmark -password "$BENCH_TOKEN" - -echo "--- RWC Create a Risingwave Instance" -rwc tenant create -name ${TENANT_NAME} -sku ${BENCH_SKU} -imagetag ${IMAGE_TAG} - -echo "--- Wait Risingwave Instance Ready" -pollingTenantStatus Running 30 - -echo "--- Get Risingwave Instance endpoint" -endpoint=$(rwc tenant endpoint -name ${TENANT_NAME}) - -echo "--- Create DB User" -rwc tenant create-user -n ${TENANT_NAME} -u ${DB_USER} -p ${DB_PWD} - -echo "--- Test endpoint" -endpoint=${endpoint//""/"$DB_USER"} -endpoint=${endpoint//""/"$DB_PWD"} -echo ${endpoint} -polling ${endpoint} - -echo "--- Namespace: ${endpoint#*%3D}" - -echo "--- Generate Tpch-Bench Args" -mkdir ~/risingwave-deploy -echo "--frontend-url ${endpoint}" > ~/risingwave-deploy/tpch-bench-args-frontend -echo "--kafka-addr ${HOST_IP}:29092" > ~/risingwave-deploy/tpch-bench-args-kafka -cat ~/risingwave-deploy/tpch-bench-args-frontend -cat ~/risingwave-deploy/tpch-bench-args-kafka - -echo "--- Clone Tpch-Bench Repo" -git clone https://"$GITHUB_TOKEN"@github.com/risingwavelabs/tpch-bench.git - -echo "--- Run Tpch-Bench" -cd tpch-bench/ -./scripts/build.sh -./scripts/launch_risedev_bench.sh - -echo "--- Waiting For Risingwave to Consume Data" -sleep 300 diff --git a/ci/scripts/pre-unit-test.sh b/ci/scripts/pre-unit-test.sh index 6c235ebad0f7e..0381c082eca84 100755 --- a/ci/scripts/pre-unit-test.sh +++ b/ci/scripts/pre-unit-test.sh @@ -6,7 +6,7 @@ set -euo pipefail source ci/scripts/common.env.sh echo "--- Run clippy check" -cargo clippy --all-targets --features failpoints --locked -- -D warnings +cargo clippy --all-targets --features failpoints,sync_point --locked -- -D warnings echo "--- Build documentation" cargo doc --document-private-items --no-deps diff --git a/ci/scripts/run-unit-test.sh b/ci/scripts/run-unit-test.sh index c4f17a1ee330a..1be4b143d58ee 100755 --- a/ci/scripts/run-unit-test.sh +++ b/ci/scripts/run-unit-test.sh @@ -3,9 +3,9 @@ set -euo pipefail echo "+++ Run unit tests with coverage" # use tee to disable progress bar -NEXTEST_PROFILE=ci cargo llvm-cov nextest --lcov --output-path lcov.info --features failpoints 2> >(tee); +NEXTEST_PROFILE=ci cargo llvm-cov nextest --lcov --output-path lcov.info --features failpoints,sync_point 2> >(tee); if [[ "$RUN_SQLSMITH" -eq "1" ]]; then - NEXTEST_PROFILE=ci cargo nextest run 
run_sqlsmith_on_frontend --features "failpoints enable_sqlsmith_unit_test" 2> >(tee); + NEXTEST_PROFILE=ci cargo nextest run run_sqlsmith_on_frontend --features "failpoints sync_point enable_sqlsmith_unit_test" 2> >(tee); fi echo "--- Codecov upload coverage reports" diff --git a/ci/workflows/benchmark.yml b/ci/workflows/benchmark.yml deleted file mode 100644 index 37b9ad3f6da1b..0000000000000 --- a/ci/workflows/benchmark.yml +++ /dev/null @@ -1,22 +0,0 @@ -steps: - - label: "$BENCH_SKU benchmark" - command: "ci/scripts/benchmark.sh" - plugins: - - ./ci/plugins/benchmark - - seek-oss/aws-sm#v2.3.1: - env: - GITHUB_TOKEN: github-token - BENCH_TOKEN: bench-token - - docker-compose#v3.9.0: - run: benchmark-env - config: ci/docker-compose.yml - mount-buildkite-agent: true - environment: - - GITHUB_TOKEN - - BENCH_TOKEN - - HOST_IP - - RISINGWAVE_IMAGE_TAG - - BENCH_SKU - agents: - queue: "release-bench" - timeout_in_minutes: 60 \ No newline at end of file diff --git a/e2e_test/batch/types/decimal.slt.part b/e2e_test/batch/types/decimal.slt.part index cba01c90ec656..50f410148d437 100644 --- a/e2e_test/batch/types/decimal.slt.part +++ b/e2e_test/batch/types/decimal.slt.part @@ -31,3 +31,13 @@ query T values(round(42, 1)); ---- 42 + +query R +select ' +inFInity'::decimal; +---- +Infinity + +query R +select ' -inF '::decimal; +---- +-Infinity diff --git a/e2e_test/streaming/basic.slt b/e2e_test/streaming/basic.slt index 1ea6af6e28260..d25bbcfc01f9b 100644 --- a/e2e_test/streaming/basic.slt +++ b/e2e_test/streaming/basic.slt @@ -84,8 +84,8 @@ query IR rowsort select * from mv4 ---- 1 NaN -1 +Inf -1 -Inf +1 Infinity +1 -Infinity statement ok drop materialized view mv1 diff --git a/e2e_test/streaming/demo/ad_ctr.slt b/e2e_test/streaming/demo/ad_ctr.slt index d325044f4eb66..60913528f2554 100644 --- a/e2e_test/streaming/demo/ad_ctr.slt +++ b/e2e_test/streaming/demo/ad_ctr.slt @@ -107,6 +107,9 @@ SELECT ad_id, ROUND(ctr, 2), window_end FROM ad_ctr_5min; 7 0.33 2022-06-10 12:21:00 8 1 2022-06-10 12:21:00 +statement ok +delete from ad_impression; + statement ok DROP MATERIALIZED VIEW ad_ctr; diff --git a/proto/catalog.proto b/proto/catalog.proto index 3cf8880f5236a..8b71113eac04d 100644 --- a/proto/catalog.proto +++ b/proto/catalog.proto @@ -15,31 +15,31 @@ message ColumnIndex { } message StreamSourceInfo { - map properties = 1; - plan_common.RowFormatType row_format = 2; - string row_schema_location = 3; - ColumnIndex row_id_index = 4; - repeated plan_common.ColumnCatalog columns = 5; - repeated int32 pk_column_ids = 6; + plan_common.RowFormatType row_format = 1; + string row_schema_location = 2; } -message TableSourceInfo { - ColumnIndex row_id_index = 1; - repeated plan_common.ColumnCatalog columns = 2; - repeated int32 pk_column_ids = 3; - map properties = 4; -} +message TableSourceInfo {} message Source { uint32 id = 1; uint32 schema_id = 2; uint32 database_id = 3; string name = 4; + // The column index of row ID. If the primary key is specified by the user, this will be `None`. + ColumnIndex row_id_index = 5; + // Columns of the source. + repeated plan_common.ColumnCatalog columns = 6; + // Column id of the primary key specified by the user. If the user does not specify a primary key, the vector will be empty. + repeated int32 pk_column_ids = 7; + // Properties specified by the user in WITH clause. + map properties = 8; + // Information of either stream source (connector) or table source. 
oneof info { - StreamSourceInfo stream_source = 5; - TableSourceInfo table_source = 6; + StreamSourceInfo stream_source = 9; + TableSourceInfo table_source = 10; } - uint32 owner = 7; + uint32 owner = 11; } message Sink { diff --git a/proto/stream_plan.proto b/proto/stream_plan.proto index 3b2eeaf35c2f0..55893916e1ff4 100644 --- a/proto/stream_plan.proto +++ b/proto/stream_plan.proto @@ -107,14 +107,16 @@ message ActorMapping { repeated uint32 data = 2; } -// todo: StreamSourceNode or TableSourceNode message SourceNode { uint32 source_id = 1; // use source_id to fetch SourceDesc from local source manager - repeated int32 column_ids = 2; - catalog.Table state_table = 3; + catalog.Table state_table = 2; + catalog.ColumnIndex row_id_index = 3; + repeated plan_common.ColumnCatalog columns = 4; + repeated int32 pk_column_ids = 5; + map properties = 6; oneof info { - catalog.StreamSourceInfo stream_source = 4; - catalog.TableSourceInfo table_source = 5; + catalog.StreamSourceInfo stream_source = 7; + catalog.TableSourceInfo table_source = 8; } } @@ -145,6 +147,8 @@ message MaterializeNode { repeated plan_common.ColumnOrder column_orders = 2; // Used for internal table states. catalog.Table table = 3; + // Used to control whether doing sanity check, open it when upstream executor is source executor. + bool ignore_on_conflict = 4; } message AggCallState { @@ -330,6 +334,8 @@ message ArrangeNode { repeated uint32 distribution_key = 2; // Used for internal table states. catalog.Table table = 3; + // Used to control whether doing sanity check, open it when upstream executor is source executor. + bool ignore_on_conflict = 4; } // Special node for shared state. LookupNode will join an arrangement with a stream. @@ -471,6 +477,8 @@ message StreamActor { // Vnodes that the executors in this actor own. // If the fragment is a singleton, this field will not be set and leave a `None`. common.Buffer vnode_bitmap = 8; + // The SQL definition of this materialized view. Used for debugging only. + string mview_definition = 9; } enum FragmentType { diff --git a/risedev.yml b/risedev.yml index 3120698b62f22..ff6b91567b2ad 100644 --- a/risedev.yml +++ b/risedev.yml @@ -275,21 +275,21 @@ risedev: id: compute-node-0 listen-address: "0.0.0.0" address: ${dns-host:rw-compute-0} - enable-async-stack-trace: true + async-stack-trace: verbose enable-tiered-cache: true - use: compute-node id: compute-node-1 listen-address: "0.0.0.0" address: ${dns-host:rw-compute-1} - enable-async-stack-trace: true + async-stack-trace: verbose enable-tiered-cache: true - use: compute-node id: compute-node-2 listen-address: "0.0.0.0" address: ${dns-host:rw-compute-2} - enable-async-stack-trace: true + async-stack-trace: verbose enable-tiered-cache: true - use: frontend @@ -518,8 +518,9 @@ template: # Id of this instance id: compute-node-${port} - # Whether to enable async stack trace for this compute node. - enable-async-stack-trace: false + # Whether to enable async stack trace for this compute node, `off`, `on`, or `verbose`. + # Considering the performance, `verbose` mode only effect under `release` profile with `debug_assertions` off. + async-stack-trace: on # If `enable-tiered-cache` is true, hummock will use data directory as file cache. 
enable-tiered-cache: false diff --git a/src/batch/benches/filter.rs b/src/batch/benches/filter.rs index a963eeb4eb1c6..c8fa1681e8748 100644 --- a/src/batch/benches/filter.rs +++ b/src/batch/benches/filter.rs @@ -17,6 +17,7 @@ pub mod utils; use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; use risingwave_batch::executor::{BoxedExecutor, FilterExecutor}; use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_expr::expr::build_from_prost; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::expr::expr_node::RexNode; @@ -53,7 +54,7 @@ fn create_filter_executor(chunk_size: usize, chunk_num: usize) -> BoxedExecutor ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(2).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(2)).as_ref()), })), }; @@ -76,7 +77,7 @@ fn create_filter_executor(chunk_size: usize, chunk_num: usize) -> BoxedExecutor ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(0).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(0)).as_ref()), })), }; diff --git a/src/batch/benches/hash_join.rs b/src/batch/benches/hash_join.rs index 0adea62029cd7..3538a1d858e8c 100644 --- a/src/batch/benches/hash_join.rs +++ b/src/batch/benches/hash_join.rs @@ -21,6 +21,7 @@ use risingwave_batch::executor::{BoxedExecutor, JoinType}; use risingwave_common::catalog::schema_test_utils::field_n; use risingwave_common::hash; use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_expr::expr::build_from_prost; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::expr::expr_node::RexNode; @@ -59,7 +60,7 @@ fn create_hash_join_executor( ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(123).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(123)).as_ref()), })), }; ExprNode { @@ -89,7 +90,7 @@ fn create_hash_join_executor( ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(456).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(456)).as_ref()), })), }; ExprNode { @@ -142,7 +143,7 @@ fn create_hash_join_executor( ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(100).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(100)).as_ref()), })), }; Some(ExprNode { diff --git a/src/batch/benches/nested_loop_join.rs b/src/batch/benches/nested_loop_join.rs index 50324c6548711..517a2236a78a9 100644 --- a/src/batch/benches/nested_loop_join.rs +++ b/src/batch/benches/nested_loop_join.rs @@ -16,6 +16,7 @@ pub mod utils; use criterion::{criterion_group, criterion_main, Criterion}; use risingwave_batch::executor::{BoxedExecutor, JoinType, NestedLoopJoinExecutor}; use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_expr::expr::build_from_prost; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::expr::expr_node::RexNode; @@ -68,7 +69,7 @@ fn create_nested_loop_join_executor( ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(2).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(2)).as_ref()), 
})), }; @@ -79,7 +80,7 @@ fn create_nested_loop_join_executor( ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::Int64(3).to_protobuf(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int64(3)).as_ref()), })), }; diff --git a/src/batch/src/executor/delete.rs b/src/batch/src/executor/delete.rs index 67677cb4f5dea..f80e567ca279e 100644 --- a/src/batch/src/executor/delete.rs +++ b/src/batch/src/executor/delete.rs @@ -144,8 +144,8 @@ mod tests { use risingwave_common::array::Array; use risingwave_common::catalog::schema_test_utils; use risingwave_common::test_prelude::DataChunkTestExt; - use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; + use risingwave_source::table_test_utils::create_table_source_desc_builder; + use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use super::*; use crate::executor::test_utils::MockExecutor; @@ -171,14 +171,16 @@ mod tests { 9 10", )); - let info = create_table_info(&schema, None, vec![1]); - - // Create the table. - let table_id = TableId::new(0); - let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); - // Create reader - let source_desc = source_builder.build().await?; + let table_id = TableId::new(0); + let source_builder = create_table_source_desc_builder( + &schema, + table_id, + None, + vec![1], + source_manager.clone(), + ); + let source_desc = source_builder.build().await.unwrap(); let source = source_desc.source.as_table().unwrap(); let mut reader = source .stream_reader(vec![0.into(), 1.into()]) diff --git a/src/batch/src/executor/filter.rs b/src/batch/src/executor/filter.rs index e7113ecb011ac..cc5a141600e14 100644 --- a/src/batch/src/executor/filter.rs +++ b/src/batch/src/executor/filter.rs @@ -128,6 +128,7 @@ mod tests { use risingwave_common::catalog::{Field, Schema}; use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::{DataType, Scalar}; + use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_expr::expr::build_from_prost; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::expr::expr_node::Type::InputRef; @@ -241,8 +242,12 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: ScalarImpl::List(ListValue::new(vec![Some(2.to_scalar_value())])) - .to_protobuf(), + body: serialize_datum_to_bytes( + Some(ScalarImpl::List(ListValue::new(vec![Some( + 2.to_scalar_value(), + )]))) + .as_ref(), + ), })), }; let function_call = FunctionCall { diff --git a/src/batch/src/executor/insert.rs b/src/batch/src/executor/insert.rs index fb938941d9887..9f3c7b765b03f 100644 --- a/src/batch/src/executor/insert.rs +++ b/src/batch/src/executor/insert.rs @@ -158,8 +158,8 @@ mod tests { use risingwave_common::catalog::schema_test_utils; use risingwave_common::column_nonnull; use risingwave_common::types::DataType; - use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; + use risingwave_source::table_test_utils::create_table_source_desc_builder; + use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::store::ReadOptions; use risingwave_storage::*; @@ -204,14 +204,17 @@ mod tests { let data_chunk: DataChunk = DataChunk::new(vec![col1, col2, col3], 5); mock_executor.add(data_chunk.clone()); - 
// To match the row_id column in the schema - let info = create_table_info(&schema, Some(3), vec![3]); - // Create the table. let table_id = TableId::new(0); - let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); // Create reader + let source_builder = create_table_source_desc_builder( + &schema, + table_id, + Some(3), + vec![3], + source_manager.clone(), + ); let source_desc = source_builder.build().await?; let source = source_desc.source.as_table().unwrap(); let mut reader = source diff --git a/src/batch/src/executor/row_seq_scan.rs b/src/batch/src/executor/row_seq_scan.rs index 7ab7f16116c8f..9d2e83ee920cc 100644 --- a/src/batch/src/executor/row_seq_scan.rs +++ b/src/batch/src/executor/row_seq_scan.rs @@ -22,7 +22,7 @@ use risingwave_common::array::{DataChunk, Row}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema, TableId, TableOption}; use risingwave_common::error::{Result, RwError}; -use risingwave_common::types::{DataType, Datum, ScalarImpl}; +use risingwave_common::types::{DataType, Datum}; use risingwave_common::util::select_all; use risingwave_common::util::sort_util::OrderType; use risingwave_common::util::value_encoding::deserialize_datum; @@ -91,10 +91,7 @@ impl ScanRange { let bound_ty = pk_types.next().unwrap(); let build_bound = |bound: &scan_range::Bound| -> Bound { - let scalar = - ScalarImpl::from_proto_bytes(&bound.value, &bound_ty.to_protobuf()).unwrap(); - - let datum = Some(scalar); + let datum = deserialize_datum(bound.value.as_slice(), &bound_ty).unwrap(); if bound.inclusive { Bound::Included(datum) } else { diff --git a/src/batch/src/executor/update.rs b/src/batch/src/executor/update.rs index 26bd43f08c8b5..1071b38e0321a 100644 --- a/src/batch/src/executor/update.rs +++ b/src/batch/src/executor/update.rs @@ -199,8 +199,8 @@ mod tests { use risingwave_common::catalog::schema_test_utils; use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_expr::expr::InputRefExpression; - use risingwave_source::table_test_utils::create_table_info; - use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; + use risingwave_source::table_test_utils::create_table_source_desc_builder; + use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use super::*; use crate::executor::test_utils::MockExecutor; @@ -233,11 +233,16 @@ mod tests { ]; // Create the table. 
- let info = create_table_info(&schema, None, vec![1]); let table_id = TableId::new(0); - let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); // Create reader + let source_builder = create_table_source_desc_builder( + &schema, + table_id, + None, + vec![1], + source_manager.clone(), + ); let source_desc = source_builder.build().await?; let source = source_desc.source.as_table().unwrap(); let mut reader = source diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 6ecfbfa66e8e4..f3d4086ab94cc 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -20,11 +20,7 @@ use std::hash::{Hash, Hasher}; use bytes::{Buf, BufMut}; use itertools::EitherOrBoth::{Both, Left, Right}; use itertools::Itertools; -use prost::Message; -use risingwave_pb::data::{ - Array as ProstArray, ArrayType as ProstArrayType, DataType as ProstDataType, ListArrayData, -}; -use risingwave_pb::expr::ListValue as ProstListValue; +use risingwave_pb::data::{Array as ProstArray, ArrayType as ProstArrayType, ListArrayData}; use serde::{Deserializer, Serializer}; use super::{ @@ -34,7 +30,7 @@ use super::{ use crate::buffer::{Bitmap, BitmapBuilder}; use crate::types::{ deserialize_datum_from, display_datum_ref, serialize_datum_ref_into, to_datum_ref, DataType, - Datum, DatumRef, Scalar, ScalarImpl, ScalarRefImpl, ToOwnedDatum, + Datum, DatumRef, Scalar, ScalarRefImpl, }; /// This is a naive implementation of list array. @@ -345,27 +341,6 @@ impl ListValue { &self.values } - pub fn to_protobuf_owned(&self) -> Vec { - self.as_scalar_ref().to_protobuf_owned() - } - - pub fn from_protobuf_bytes(data_type: ProstDataType, b: &Vec) -> ArrayResult { - let list_value: ProstListValue = Message::decode(b.as_slice())?; - let d = &data_type.field_type[0]; - let fields: Vec = list_value - .fields - .iter() - .map(|b| { - if b.is_empty() { - Ok(None) - } else { - Ok(Some(ScalarImpl::from_proto_bytes(b, d)?)) - } - }) - .collect::>>()?; - Ok(ListValue::new(fields)) - } - pub fn deserialize( datatype: &DataType, deserializer: &mut memcomparable::Deserializer, @@ -419,21 +394,6 @@ macro_rules! iter_elems_ref { }; } -macro_rules! iter_elems { - ($self:ident, $it:ident, { $($body:tt)* }) => { - match $self { - ListRef::Indexed { arr, idx } => { - let $it = (arr.offsets[idx]..arr.offsets[idx + 1]).map(|o| arr.value.value_at(o).to_owned_datum()); - $($body)* - } - ListRef::ValueRef { val } => { - let $it = val.values.iter(); - $($body)* - } - } - }; -} - impl<'a> ListRef<'a> { pub fn flatten(&self) -> Vec> { iter_elems_ref!(self, it, { @@ -472,19 +432,6 @@ impl<'a> ListRef<'a> { } } - pub fn to_protobuf_owned(self) -> Vec { - let elems = iter_elems!(self, it, { - it.map(|f| match f { - None => { - vec![] - } - Some(s) => s.to_protobuf(), - }) - .collect_vec() - }); - ProstListValue { fields: elems }.encode_to_vec() - } - pub fn serialize( &self, serializer: &mut memcomparable::Serializer, @@ -870,36 +817,6 @@ mod tests { assert_eq!("{1,NULL}".to_string(), format!("{}", r)); } - #[test] - fn test_to_protobuf_owned() { - use crate::array::*; - let arr = ListArray::from_slices( - &[true, true], - vec![ - Some(array! { I32Array, [Some(1), Some(2)] }.into()), - Some(array! 
{ I32Array, [Some(3), Some(4)] }.into()), - ], - DataType::Int32, - ); - let list_ref = arr.value_at(0).unwrap(); - let output = list_ref.to_protobuf_owned(); - let expect = ListValue::new(vec![ - Some(1i32.to_scalar_value()), - Some(2i32.to_scalar_value()), - ]) - .to_protobuf_owned(); - assert_eq!(output, expect); - - let list_ref = arr.value_at(1).unwrap(); - let output = list_ref.to_protobuf_owned(); - let expect = ListValue::new(vec![ - Some(3i32.to_scalar_value()), - Some(4i32.to_scalar_value()), - ]) - .to_protobuf_owned(); - assert_eq!(output, expect); - } - #[test] fn test_serialize_deserialize() { let value = ListValue::new(vec![ diff --git a/src/common/src/array/stream_chunk.rs b/src/common/src/array/stream_chunk.rs index 17d7159f6d0cc..162f7ea31ed8e 100644 --- a/src/common/src/array/stream_chunk.rs +++ b/src/common/src/array/stream_chunk.rs @@ -29,7 +29,7 @@ use crate::types::DataType; /// but always appear in pairs to represent an update operation. /// For example, table source, aggregation and outer join can generate updates by themselves, /// while most of the other operators only pass through updates with best effort. -#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] pub enum Op { Insert, Delete, diff --git a/src/common/src/array/struct_array.rs b/src/common/src/array/struct_array.rs index 120bf0ef65c43..68295738179f7 100644 --- a/src/common/src/array/struct_array.rs +++ b/src/common/src/array/struct_array.rs @@ -20,11 +20,7 @@ use std::sync::Arc; use bytes::{Buf, BufMut}; use itertools::Itertools; -use prost::Message; -use risingwave_pb::data::{ - Array as ProstArray, ArrayType as ProstArrayType, DataType as ProstDataType, StructArrayData, -}; -use risingwave_pb::expr::StructValue as ProstStructValue; +use risingwave_pb::data::{Array as ProstArray, ArrayType as ProstArrayType, StructArrayData}; use super::{ Array, ArrayBuilder, ArrayBuilderImpl, ArrayImpl, ArrayIterator, ArrayMeta, ArrayResult, @@ -34,7 +30,7 @@ use crate::array::ArrayRef; use crate::buffer::{Bitmap, BitmapBuilder}; use crate::types::{ deserialize_datum_from, display_datum_ref, serialize_datum_ref_into, to_datum_ref, DataType, - Datum, DatumRef, Scalar, ScalarImpl, ScalarRefImpl, ToOwnedDatum, + Datum, DatumRef, Scalar, ScalarRefImpl, }; #[derive(Debug)] @@ -342,27 +338,6 @@ impl StructValue { &self.fields } - pub fn to_protobuf_owned(&self) -> Vec { - self.as_scalar_ref().to_protobuf_owned() - } - - pub fn from_protobuf_bytes(data_type: ProstDataType, b: &Vec) -> ArrayResult { - let struct_value: ProstStructValue = Message::decode(b.as_slice())?; - let fields: Vec = struct_value - .fields - .iter() - .zip_eq(data_type.field_type.iter()) - .map(|(b, d)| { - if b.is_empty() { - Ok(None) - } else { - Ok(Some(ScalarImpl::from_proto_bytes(b, d)?)) - } - }) - .collect::>>()?; - Ok(StructValue::new(fields)) - } - pub fn deserialize( fields: &[DataType], deserializer: &mut memcomparable::Deserializer, @@ -396,42 +371,11 @@ macro_rules! iter_fields_ref { }; } -macro_rules! 
iter_fields { - ($self:ident, $it:ident, { $($body:tt)* }) => { - match &$self { - StructRef::Indexed { arr, idx } => { - let $it = arr - .children - .iter() - .map(|a| a.value_at(*idx).to_owned_datum()); - $($body)* - } - StructRef::ValueRef { val } => { - let $it = val.fields.iter(); - $($body)* - } - } - }; -} - impl<'a> StructRef<'a> { pub fn fields_ref(&self) -> Vec> { iter_fields_ref!(self, it, { it.collect() }) } - pub fn to_protobuf_owned(self) -> Vec { - let fields = iter_fields!(self, it, { - it.map(|f| match f { - None => { - vec![] - } - Some(s) => s.to_protobuf(), - }) - .collect_vec() - }); - ProstStructValue { fields }.encode_to_vec() - } - pub fn serialize( &self, serializer: &mut memcomparable::Serializer, @@ -623,27 +567,6 @@ mod tests { ); } - #[test] - fn test_to_protobuf_owned() { - use crate::array::*; - let arr = StructArray::from_slices( - &[true], - vec![ - array! { I32Array, [Some(1)] }.into(), - array! { F32Array, [Some(2.0)] }.into(), - ], - vec![DataType::Int32, DataType::Float32], - ); - let struct_ref = arr.value_at(0).unwrap(); - let output = struct_ref.to_protobuf_owned(); - let expect = StructValue::new(vec![ - Some(1i32.to_scalar_value()), - Some(OrderedF32::from(2.0f32).to_scalar_value()), - ]) - .to_protobuf_owned(); - assert_eq!(output, expect); - } - #[test] fn test_serialize_deserialize() { let value = StructValue::new(vec![ diff --git a/src/common/src/catalog/column.rs b/src/common/src/catalog/column.rs index 9b167fcf129fb..1b026578dcb65 100644 --- a/src/common/src/catalog/column.rs +++ b/src/common/src/catalog/column.rs @@ -54,6 +54,12 @@ impl From for i32 { } } +impl From<&ColumnId> for i32 { + fn from(id: &ColumnId) -> i32 { + id.0 + } +} + impl std::fmt::Display for ColumnId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.0) diff --git a/src/common/src/types/decimal.rs b/src/common/src/types/decimal.rs index 59d5c8eafa4af..c1e437f240ec2 100644 --- a/src/common/src/types/decimal.rs +++ b/src/common/src/types/decimal.rs @@ -25,9 +25,9 @@ pub enum Decimal { Normalized(RustDecimal), #[display("NaN")] NaN, - #[display("+Inf")] + #[display("Infinity")] PositiveInf, - #[display("-Inf")] + #[display("-Infinity")] NegativeInf, } @@ -532,10 +532,10 @@ impl FromStr for Decimal { type Err = Error; fn from_str(s: &str) -> Result { - match s { - "nan" | "NaN" | "NAN" => Ok(Decimal::NaN), - "inf" | "INF" | "+inf" | "+INF" | "+Inf" => Ok(Decimal::PositiveInf), - "-inf" | "-INF" | "-Inf" => Ok(Decimal::NegativeInf), + match s.to_ascii_lowercase().as_str() { + "nan" => Ok(Decimal::NaN), + "inf" | "+inf" | "infinity" | "+infinity" => Ok(Decimal::PositiveInf), + "-inf" | "-infinity" => Ok(Decimal::NegativeInf), s => RustDecimal::from_str(s).map(Decimal::Normalized), } } @@ -611,36 +611,41 @@ mod tests { assert_eq!(Decimal::from_str("nan").unwrap(), Decimal::NaN,); assert_eq!(Decimal::from_str("NaN").unwrap(), Decimal::NaN,); assert_eq!(Decimal::from_str("NAN").unwrap(), Decimal::NaN,); + assert_eq!(Decimal::from_str("nAn").unwrap(), Decimal::NaN,); + assert_eq!(Decimal::from_str("nAN").unwrap(), Decimal::NaN,); + assert_eq!(Decimal::from_str("Nan").unwrap(), Decimal::NaN,); + assert_eq!(Decimal::from_str("NAn").unwrap(), Decimal::NaN,); assert_eq!(Decimal::from_str("inf").unwrap(), Decimal::PositiveInf,); assert_eq!(Decimal::from_str("INF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("iNF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("inF").unwrap(), Decimal::PositiveInf,); + 
assert_eq!(Decimal::from_str("InF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("INf").unwrap(), Decimal::PositiveInf,); assert_eq!(Decimal::from_str("+inf").unwrap(), Decimal::PositiveInf,); assert_eq!(Decimal::from_str("+INF").unwrap(), Decimal::PositiveInf,); assert_eq!(Decimal::from_str("+Inf").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("+iNF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("+inF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("+InF").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("+INf").unwrap(), Decimal::PositiveInf,); + assert_eq!(Decimal::from_str("inFINity").unwrap(), Decimal::PositiveInf,); + assert_eq!( + Decimal::from_str("+infiNIty").unwrap(), + Decimal::PositiveInf, + ); assert_eq!(Decimal::from_str("-inf").unwrap(), Decimal::NegativeInf,); assert_eq!(Decimal::from_str("-INF").unwrap(), Decimal::NegativeInf,); assert_eq!(Decimal::from_str("-Inf").unwrap(), Decimal::NegativeInf,); - - assert!(Decimal::from_str("nAn").is_err()); - assert!(Decimal::from_str("nAN").is_err()); - assert!(Decimal::from_str("Nan").is_err()); - assert!(Decimal::from_str("NAn").is_err()); - - assert!(Decimal::from_str("iNF").is_err()); - assert!(Decimal::from_str("inF").is_err()); - assert!(Decimal::from_str("InF").is_err()); - assert!(Decimal::from_str("INf").is_err()); - - assert!(Decimal::from_str("+iNF").is_err()); - assert!(Decimal::from_str("+inF").is_err()); - assert!(Decimal::from_str("+InF").is_err()); - assert!(Decimal::from_str("+INf").is_err()); - - assert!(Decimal::from_str("-iNF").is_err()); - assert!(Decimal::from_str("-inF").is_err()); - assert!(Decimal::from_str("-InF").is_err()); - assert!(Decimal::from_str("-INf").is_err()); + assert_eq!(Decimal::from_str("-iNF").unwrap(), Decimal::NegativeInf,); + assert_eq!(Decimal::from_str("-inF").unwrap(), Decimal::NegativeInf,); + assert_eq!(Decimal::from_str("-InF").unwrap(), Decimal::NegativeInf,); + assert_eq!(Decimal::from_str("-INf").unwrap(), Decimal::NegativeInf,); + assert_eq!( + Decimal::from_str("-INfinity").unwrap(), + Decimal::NegativeInf, + ); assert_eq!( Decimal::from_f32(10.0).unwrap() / Decimal::PositiveInf, diff --git a/src/common/src/types/interval.rs b/src/common/src/types/interval.rs index 89a1e242b9638..ab3b0ad17b026 100644 --- a/src/common/src/types/interval.rs +++ b/src/common/src/types/interval.rs @@ -762,7 +762,7 @@ impl IntervalUnit { result = result + (|| match interval_unit { Second => { // TODO: IntervalUnit only support millisecond precision so the part smaller than millisecond will be truncated. 
- if second < OrderedF64::from(0.001) { + if second > OrderedF64::from(0) && second < OrderedF64::from(0.001) { return None; } let ms = (second * 1000_f64).round() as i64; @@ -804,6 +804,9 @@ mod tests { #[test] fn test_parse() { + let interval = "04:00:00".parse::().unwrap(); + assert_eq!(interval, IntervalUnit::from_millis(4 * 3600 * 1000)); + let interval = "1 year 2 months 3 days 00:00:01" .parse::() .unwrap(); diff --git a/src/common/src/types/mod.rs b/src/common/src/types/mod.rs index 377186db97c6a..3048c86bfe3fd 100644 --- a/src/common/src/types/mod.rs +++ b/src/common/src/types/mod.rs @@ -16,7 +16,6 @@ use std::convert::TryFrom; use std::hash::Hash; use std::sync::Arc; -use anyhow::anyhow; use bytes::{Buf, BufMut, Bytes, BytesMut}; use parse_display::Display; use risingwave_pb::data::DataType as ProstDataType; @@ -902,99 +901,6 @@ impl ScalarImpl { Ok(deserializer.position() - base_position) } - - pub fn to_protobuf(&self) -> Vec { - let body = match self { - ScalarImpl::Int16(v) => v.to_be_bytes().to_vec(), - ScalarImpl::Int32(v) => v.to_be_bytes().to_vec(), - ScalarImpl::Int64(v) => v.to_be_bytes().to_vec(), - ScalarImpl::Float32(v) => v.to_be_bytes().to_vec(), - ScalarImpl::Float64(v) => v.to_be_bytes().to_vec(), - ScalarImpl::Utf8(s) => s.as_bytes().to_vec(), - ScalarImpl::Bool(v) => (*v as i8).to_be_bytes().to_vec(), - ScalarImpl::Decimal(v) => v.to_string().as_bytes().to_vec(), - ScalarImpl::Interval(v) => v.to_protobuf_owned(), - ScalarImpl::NaiveDate(v) => v.to_protobuf_owned(), - ScalarImpl::NaiveDateTime(v) => v.to_protobuf_owned(), - ScalarImpl::NaiveTime(v) => v.to_protobuf_owned(), - ScalarImpl::Struct(v) => v.to_protobuf_owned(), - ScalarImpl::List(v) => v.to_protobuf_owned(), - }; - body - } - - /// This encoding should only be used in proto - /// TODO: replace with value encoding? - pub fn from_proto_bytes(b: &Vec, data_type: &ProstDataType) -> ArrayResult { - let value = match data_type.get_type_name()? { - TypeName::Boolean => ScalarImpl::Bool( - i8::from_be_bytes( - b.as_slice() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize bool, reason: {:?}", e))?, - ) == 1, - ), - TypeName::Int16 => ScalarImpl::Int16(i16::from_be_bytes( - b.as_slice() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize i16, reason: {:?}", e))?, - )), - TypeName::Int32 => ScalarImpl::Int32(i32::from_be_bytes( - b.as_slice() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))?, - )), - t @ (TypeName::Int64 | TypeName::Timestampz) => { - ScalarImpl::Int64(i64::from_be_bytes(b.as_slice().try_into().map_err( - |e| anyhow!("Failed to deserialize i64 for {:?}, reason: {:?}", t, e), - )?)) - } - TypeName::Float => ScalarImpl::Float32( - f32::from_be_bytes( - b.as_slice() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize f32, reason: {:?}", e))?, - ) - .into(), - ), - TypeName::Double => ScalarImpl::Float64( - f64::from_be_bytes( - b.as_slice() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize f64, reason: {:?}", e))?, - ) - .into(), - ), - TypeName::Varchar => ScalarImpl::Utf8( - std::str::from_utf8(b) - .map_err(|e| anyhow!("Failed to deserialize varchar, reason: {:?}", e))? 
- .to_string(), - ), - TypeName::Decimal => ScalarImpl::Decimal( - Decimal::from_str(std::str::from_utf8(b).unwrap()) - .map_err(|e| anyhow!("Failed to deserialize decimal, reason: {:?}", e))?, - ), - TypeName::Interval => ScalarImpl::Interval(IntervalUnit::from_protobuf_bytes( - b, - data_type.get_interval_type().unwrap_or(Unspecified), - )?), - TypeName::Timestamp => { - ScalarImpl::NaiveDateTime(NaiveDateTimeWrapper::from_protobuf_bytes(b)?) - } - TypeName::Time => ScalarImpl::NaiveTime(NaiveTimeWrapper::from_protobuf_bytes(b)?), - TypeName::Date => ScalarImpl::NaiveDate(NaiveDateWrapper::from_protobuf_bytes(b)?), - TypeName::Struct => { - ScalarImpl::Struct(StructValue::from_protobuf_bytes(data_type.clone(), b)?) - } - TypeName::List => { - ScalarImpl::List(ListValue::from_protobuf_bytes(data_type.clone(), b)?) - } - TypeName::TypeUnspecified => { - bail!("Unrecognized type name: {:?}", data_type.get_type_name()) - } - }; - Ok(value) - } } pub fn literal_type_match(data_type: &DataType, literal: Option<&ScalarImpl>) -> bool { @@ -1114,31 +1020,6 @@ mod tests { assert_eq!(std::mem::size_of::(), 32); } - #[test] - fn test_protobuf_conversion() { - let v = ScalarImpl::NaiveDateTime(NaiveDateTimeWrapper::default()); - let actual = - ScalarImpl::from_proto_bytes(&v.to_protobuf(), &DataType::Timestamp.to_protobuf()) - .unwrap(); - assert_eq!(v, actual); - - let v = ScalarImpl::NaiveDate(NaiveDateWrapper::default()); - let actual = - ScalarImpl::from_proto_bytes(&v.to_protobuf(), &DataType::Date.to_protobuf()).unwrap(); - assert_eq!(v, actual); - - let v = ScalarImpl::NaiveTime(NaiveTimeWrapper::default()); - let actual = - ScalarImpl::from_proto_bytes(&v.to_protobuf(), &DataType::Time.to_protobuf()).unwrap(); - assert_eq!(v, actual); - - let v = ScalarImpl::Int64(1); - let actual = - ScalarImpl::from_proto_bytes(&v.to_protobuf(), &DataType::Timestampz.to_protobuf()) - .unwrap(); - assert_eq!(v, actual); - } - #[test] fn test_data_type_display() { let d: DataType = DataType::new_struct( diff --git a/src/common/src/util/scan_range.rs b/src/common/src/util/scan_range.rs index 673aee5e01dd2..34194e5f4b3e7 100644 --- a/src/common/src/util/scan_range.rs +++ b/src/common/src/util/scan_range.rs @@ -19,6 +19,7 @@ use paste::paste; use risingwave_pb::batch_plan::scan_range::Bound as BoundProst; use risingwave_pb::batch_plan::ScanRange as ScanRangeProst; +use super::value_encoding::serialize_datum_to_bytes; use crate::array::Row; use crate::types::{Datum, ScalarImpl, VirtualNode}; use crate::util::hash_util::Crc32FastBuilder; @@ -34,11 +35,11 @@ pub struct ScanRange { fn bound_to_proto(bound: &Bound) -> Option { match bound { Bound::Included(literal) => Some(BoundProst { - value: literal.to_protobuf(), + value: serialize_datum_to_bytes(Some(literal)), inclusive: true, }), Bound::Excluded(literal) => Some(BoundProst { - value: literal.to_protobuf(), + value: serialize_datum_to_bytes(Some(literal)), inclusive: false, }), Bound::Unbounded => None, diff --git a/src/common/src/util/value_encoding/mod.rs b/src/common/src/util/value_encoding/mod.rs index ec4a930e6f1a6..7f7e478b09655 100644 --- a/src/common/src/util/value_encoding/mod.rs +++ b/src/common/src/util/value_encoding/mod.rs @@ -31,13 +31,11 @@ use error::ValueEncodingError; pub type Result = std::result::Result; -/// Serialize datum into cell bytes (Not order guarantee, used in value encoding). -pub fn serialize_cell(cell: &Datum) -> Result> { +/// Serialize a datum into bytes and return (Not order guarantee, used in value encoding). 
+pub fn serialize_datum_to_bytes(cell: Option<&ScalarImpl>) -> Vec { let mut buf: Vec = vec![]; - if let Some(datum) = cell { - serialize_value(datum.as_scalar_ref_impl(), &mut buf) - } - Ok(buf) + serialize_datum_ref(&cell.map(|scala| scala.as_scalar_ref_impl()), &mut buf); + buf } /// Serialize a datum into bytes (Not order guarantee, used in value encoding). diff --git a/src/compute/src/lib.rs b/src/compute/src/lib.rs index ba57146767e21..1130af170039a 100644 --- a/src/compute/src/lib.rs +++ b/src/compute/src/lib.rs @@ -26,9 +26,17 @@ extern crate tracing; pub mod rpc; pub mod server; +use clap::clap_derive::ArgEnum; use clap::Parser; use serde::{Deserialize, Serialize}; +#[derive(Debug, Clone, ArgEnum)] +pub enum AsyncStackTraceOption { + Off, + On, + Verbose, +} + /// Command-line arguments for compute-node. #[derive(Parser, Debug)] pub struct ComputeNodeOpts { @@ -64,8 +72,8 @@ pub struct ComputeNodeOpts { pub enable_jaeger_tracing: bool, /// Enable async stack tracing for risectl. - #[clap(long)] - pub enable_async_stack_trace: bool, + #[clap(long, arg_enum, default_value_t = AsyncStackTraceOption::Off)] + pub async_stack_trace: AsyncStackTraceOption, /// Path to file cache data directory. /// Left empty to disable file cache. @@ -89,7 +97,7 @@ pub fn start(opts: ComputeNodeOpts) -> Pin + Send>> // WARNING: don't change the function signature. Making it `async fn` will cause // slow compile in release mode. Box::pin(async move { - tracing::info!("meta address: {}", opts.meta_address.clone()); + tracing::info!("Compute node options: {:?}", opts); let listen_address = opts.host.parse().unwrap(); tracing::info!("Server Listening at {}", listen_address); diff --git a/src/compute/src/rpc/service/monitor_service.rs b/src/compute/src/rpc/service/monitor_service.rs index f618274eca4a6..ec6c70bec8834 100644 --- a/src/compute/src/rpc/service/monitor_service.rs +++ b/src/compute/src/rpc/service/monitor_service.rs @@ -108,9 +108,8 @@ pub mod grpc_middleware { use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::Arc; use std::task::{Context, Poll}; - use std::time::Duration; - use async_stack_trace::{SpanValue, StackTraceManager}; + use async_stack_trace::{SpanValue, StackTraceManager, TraceConfig}; use futures::Future; use hyper::Body; use tokio::sync::Mutex; @@ -124,19 +123,20 @@ pub mod grpc_middleware { #[derive(Clone)] pub struct StackTraceMiddlewareLayer { manager: GrpcStackTraceManagerRef, + config: TraceConfig, } pub type OptionalStackTraceMiddlewareLayer = Either; impl StackTraceMiddlewareLayer { - pub fn new(manager: GrpcStackTraceManagerRef) -> Self { - Self { manager } + pub fn new(manager: GrpcStackTraceManagerRef, config: TraceConfig) -> Self { + Self { manager, config } } pub fn new_optional( - manager: Option, + optional: Option<(GrpcStackTraceManagerRef, TraceConfig)>, ) -> OptionalStackTraceMiddlewareLayer { - if let Some(manager) = manager { - Either::A(Self::new(manager)) + if let Some((manager, config)) = optional { + Either::A(Self::new(manager, config)) } else { Either::B(Identity::new()) } @@ -150,6 +150,7 @@ pub mod grpc_middleware { StackTraceMiddleware { inner: service, manager: self.manager.clone(), + config: self.config.clone(), next_id: Default::default(), } } @@ -159,6 +160,7 @@ pub mod grpc_middleware { pub struct StackTraceMiddleware { inner: S, manager: GrpcStackTraceManagerRef, + config: TraceConfig, next_id: Arc, } @@ -185,19 +187,13 @@ pub mod grpc_middleware { let id = self.next_id.fetch_add(1, Ordering::SeqCst); let manager = self.manager.clone(); + 
let config = self.config.clone(); async move { let sender = manager.lock().await.register(id); let root_span: SpanValue = format!("{}:{}", req.uri().path(), id).into(); - sender - .trace( - inner.call(req), - root_span, - false, - Duration::from_millis(100), - ) - .await + sender.trace(inner.call(req), root_span, config).await } } } diff --git a/src/compute/src/server.rs b/src/compute/src/server.rs index b17d5b02b9015..022e6528a64a1 100644 --- a/src/compute/src/server.rs +++ b/src/compute/src/server.rs @@ -53,7 +53,7 @@ use crate::rpc::service::monitor_service::{ GrpcStackTraceManagerRef, MonitorServiceImpl, StackTraceMiddlewareLayer, }; use crate::rpc::service::stream_service::StreamServiceImpl; -use crate::{ComputeNodeConfig, ComputeNodeOpts}; +use crate::{AsyncStackTraceOption, ComputeNodeConfig, ComputeNodeOpts}; /// Bootstraps the compute-node. pub async fn compute_node_serve( @@ -177,6 +177,15 @@ pub async fn compute_node_serve( extra_info_sources, )); + let async_stack_trace_config = match opts.async_stack_trace { + AsyncStackTraceOption::Off => None, + c => Some(async_stack_trace::TraceConfig { + report_detached: true, + verbose: matches!(c, AsyncStackTraceOption::Verbose), + interval: Duration::from_secs(1), + }), + }; + // Initialize the managers. let batch_mgr = Arc::new(BatchManager::new(config.batch.worker_threads_num)); let stream_mgr = Arc::new(LocalStreamManager::new( @@ -184,7 +193,7 @@ pub async fn compute_node_serve( state_store.clone(), streaming_metrics.clone(), config.streaming.clone(), - opts.enable_async_stack_trace, + async_stack_trace_config.clone(), opts.enable_managed_cache, )); let source_mgr = Arc::new(TableSourceManager::new( @@ -237,8 +246,7 @@ pub async fn compute_node_serve( .initial_connection_window_size(MAX_CONNECTION_WINDOW_SIZE) .tcp_nodelay(true) .layer(StackTraceMiddlewareLayer::new_optional( - opts.enable_async_stack_trace - .then_some(grpc_stack_trace_mgr), + async_stack_trace_config.map(|c| (grpc_stack_trace_mgr, c)), )) .add_service(TaskServiceServer::new(batch_srv)) .add_service(ExchangeServiceServer::new(exchange_srv)) diff --git a/src/compute/tests/integration_tests.rs b/src/compute/tests/integration_tests.rs index 60ef766d9c595..7dd1b7ee63086 100644 --- a/src/compute/tests/integration_tests.rs +++ b/src/compute/tests/integration_tests.rs @@ -34,7 +34,8 @@ use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::{DataType, IntoOrdered}; use risingwave_common::util::epoch::EpochPair; use risingwave_common::util::sort_util::{OrderPair, OrderType}; -use risingwave_source::{SourceDescBuilder, TableSourceManager, TableSourceManagerRef}; +use risingwave_source::table_test_utils::create_table_source_desc_builder; +use risingwave_source::{TableSourceManager, TableSourceManagerRef}; use risingwave_storage::memory::MemoryStateStore; use risingwave_storage::table::batch_table::storage_table::StorageTable; use risingwave_storage::table::streaming_table::state_table::StateTable; @@ -88,7 +89,6 @@ impl SingleChunkExecutor { #[tokio::test] async fn test_table_materialize() -> StreamResult<()> { use risingwave_common::types::DataType; - use risingwave_source::table_test_utils::create_table_info; use risingwave_stream::executor::state_table_handler::default_source_internal_table; let memory_state_store = MemoryStateStore::new(); @@ -100,8 +100,13 @@ async fn test_table_materialize() -> StreamResult<()> { Field::unnamed(DataType::Float64), ], }; - let info = create_table_info(&schema, Some(0), vec![0]); - let source_builder = 
SourceDescBuilder::new(source_table_id, &info, &source_manager); + let source_builder = create_table_source_desc_builder( + &schema, + source_table_id, + Some(0), + vec![0], + source_manager.clone(), + ); // Ensure the source exists let source_desc = source_builder.build().await.unwrap(); diff --git a/src/ctl/src/cmd_impl/hummock/list_kv.rs b/src/ctl/src/cmd_impl/hummock/list_kv.rs index 049bcb991c400..1139dbe902dcc 100644 --- a/src/ctl/src/cmd_impl/hummock/list_kv.rs +++ b/src/ctl/src/cmd_impl/hummock/list_kv.rs @@ -28,7 +28,6 @@ pub async fn list_kv(epoch: u64, table_id: u32) -> anyhow::Result<()> { } let scan_result = { let mut buf = BytesMut::with_capacity(5); - buf.put_u8(b't'); buf.put_u32(table_id); let range = buf.to_vec()..next_key(buf.to_vec().as_slice()); hummock @@ -45,21 +44,8 @@ pub async fn list_kv(epoch: u64, table_id: u32) -> anyhow::Result<()> { .await? }; for (k, v) in scan_result { - let print_string = match k[0] { - b't' => { - let mut buf = &k[1..]; - format!("[t{}]", buf.get_u32()) // table id - } - b's' => { - let mut buf = &k[1..]; - format!("[s{}]", buf.get_u64()) // shared executor root - } - b'e' => { - let mut buf = &k[1..]; - format!("[e{}]", buf.get_u64()) // executor id - } - _ => "no title".to_string(), - }; + let mut buf = &k[..]; + let print_string = format!("[t{}]", buf.get_u32()); println!("{} {:?} => {:?}", print_string, k, v) } hummock_opts.shutdown().await; diff --git a/src/ctl/src/cmd_impl/hummock/sst_dump.rs b/src/ctl/src/cmd_impl/hummock/sst_dump.rs index a197c84ebbcc2..05fd61985285d 100644 --- a/src/ctl/src/cmd_impl/hummock/sst_dump.rs +++ b/src/ctl/src/cmd_impl/hummock/sst_dump.rs @@ -184,10 +184,7 @@ fn print_table_column( table_data: &TableData, is_put: bool, ) -> anyhow::Result<()> { - let table_id = match get_table_id(full_key) { - None => return Ok(()), - Some(table_id) => table_id, - }; + let table_id = get_table_id(full_key); print!("\t\t table: {} - ", table_id); let table_catalog = match table_data.get(&table_id) { diff --git a/src/ctl/src/cmd_impl/trace.rs b/src/ctl/src/cmd_impl/trace.rs index 0e621735f8f9c..7ba721e7a7feb 100644 --- a/src/ctl/src/cmd_impl/trace.rs +++ b/src/ctl/src/cmd_impl/trace.rs @@ -54,9 +54,7 @@ pub async fn trace() -> anyhow::Result<()> { } if all_actor_traces.is_empty() && all_rpc_traces.is_empty() { - println!( - "No traces found. No actors are running, or `--enable-async-stack-trace` not set?" - ); + println!("No traces found. No actors are running, or `--async-stack-trace` not set?"); } else { println!("--- Actor Traces ---"); for (key, trace) in all_actor_traces { diff --git a/src/expr/src/expr/build_expr_from_prost.rs b/src/expr/src/expr/build_expr_from_prost.rs index cb89028121370..499132db8fc32 100644 --- a/src/expr/src/expr/build_expr_from_prost.rs +++ b/src/expr/src/expr/build_expr_from_prost.rs @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-use risingwave_common::types::{DataType, ScalarImpl}; +use risingwave_common::types::DataType; +use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_pb::expr::expr_node::RexNode; use risingwave_pb::expr::ExprNode; @@ -215,7 +216,7 @@ pub fn build_to_char_expr(prost: &ExprNode) -> Result { let data_expr = expr_build_from_prost(&children[0])?; let tmpl_node = &children[1]; if let RexNode::Constant(tmpl_value) = tmpl_node.get_rex_node().unwrap() - && let Ok(tmpl) = ScalarImpl::from_proto_bytes(tmpl_value.get_body(), tmpl_node.get_return_type().unwrap()) + && let Ok(Some(tmpl)) = deserialize_datum(tmpl_value.get_body().as_slice(), &DataType::from(tmpl_node.get_return_type().unwrap())) { let tmpl = tmpl.as_utf8(); let pattern = compile_pattern_to_chrono(tmpl); @@ -237,6 +238,8 @@ mod tests { use std::vec; use risingwave_common::array::{ArrayImpl, DataChunk, Utf8Array}; + use risingwave_common::types::Scalar; + use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType as ProstDataType; use risingwave_pb::expr::expr_node::{RexNode, Type}; @@ -255,7 +258,9 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: "foo".as_bytes().to_vec(), + body: serialize_datum_to_bytes( + Some("foo".to_owned().to_scalar_value()).as_ref(), + ), })), }, ExprNode { @@ -265,7 +270,9 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: "bar".as_bytes().to_vec(), + body: serialize_datum_to_bytes( + Some("bar".to_owned().to_scalar_value()).as_ref(), + ), })), }, ], @@ -291,7 +298,7 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: vec![0, 0, 0, 1], + body: serialize_datum_to_bytes(Some(1_i32.to_scalar_value()).as_ref()), })), }, ], @@ -321,7 +328,7 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: "DAY".as_bytes().to_vec(), + body: serialize_datum_to_bytes(Some("DAY".to_string().to_scalar_value()).as_ref()), })), }; let right_date = ExprNode { diff --git a/src/expr/src/expr/expr_field.rs b/src/expr/src/expr/expr_field.rs index eb02c5df4d6d3..823c116bcc6a7 100644 --- a/src/expr/src/expr/expr_field.rs +++ b/src/expr/src/expr/expr_field.rs @@ -17,6 +17,7 @@ use std::convert::TryFrom; use anyhow::anyhow; use risingwave_common::array::{ArrayImpl, ArrayRef, DataChunk, Row}; use risingwave_common::types::{DataType, Datum}; +use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; @@ -79,13 +80,12 @@ impl<'a> TryFrom<&'a ExprNode> for FieldExpression { let RexNode::Constant(value) = second.get_rex_node().unwrap() else { bail!("Expected Constant as 1st argument"); }; - let index = i32::from_be_bytes( - value - .body - .clone() - .try_into() - .map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))?, - ); + let index = deserialize_datum(value.body.as_slice(), &DataType::Int32) + .map_err(|e| anyhow!("Failed to deserialize i32, reason: {:?}", e))? 
+ .unwrap() + .as_int32() + .to_owned(); + Ok(FieldExpression::new(ret_type, input, index as usize)) } } diff --git a/src/expr/src/expr/expr_in.rs b/src/expr/src/expr/expr_in.rs index 3fc3559daee85..d930b9f74cdba 100644 --- a/src/expr/src/expr/expr_in.rs +++ b/src/expr/src/expr/expr_in.rs @@ -123,6 +123,7 @@ mod tests { use risingwave_common::array::{DataChunk, Row}; use risingwave_common::test_prelude::DataChunkTestExt; use risingwave_common::types::{DataType, Scalar, ScalarImpl}; + use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::DataType as ProstDataType; use risingwave_pb::expr::expr_node::{RexNode, Type}; @@ -150,7 +151,9 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: "ABC".as_bytes().to_vec(), + body: serialize_datum_to_bytes( + Some("ABC".to_string().to_scalar_value()).as_ref(), + ), })), }, ExprNode { @@ -160,7 +163,9 @@ mod tests { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: "def".as_bytes().to_vec(), + body: serialize_datum_to_bytes( + Some("def".to_string().to_scalar_value()).as_ref(), + ), })), }, ]; diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index 1678ef0534063..e0754f65c539f 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use risingwave_common::array::{Array, ArrayBuilder, ArrayBuilderImpl, ArrayRef, DataChunk, Row}; use risingwave_common::for_all_variants; use risingwave_common::types::{literal_type_match, DataType, Datum, Scalar, ScalarImpl}; +use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; @@ -112,13 +113,14 @@ impl<'a> TryFrom<&'a ExprNode> for LiteralExpression { if let RexNode::Constant(prost_value) = prost.get_rex_node().unwrap() { // TODO: We need to unify these - let value = ScalarImpl::from_proto_bytes( - prost_value.get_body(), - prost.get_return_type().unwrap(), - )?; + let value = deserialize_datum( + prost_value.get_body().as_slice(), + &DataType::from(prost.get_return_type().unwrap()), + ) + .map_err(|e| ExprError::Internal(e.into()))?; Ok(Self { return_type: ret_type, - literal: Some(value), + literal: value, }) } else { bail!("Cannot parse the RexNode"); @@ -131,6 +133,7 @@ mod tests { use risingwave_common::array::{I32Array, StructValue}; use risingwave_common::array_nonnull; use risingwave_common::types::{Decimal, IntervalUnit, IntoOrdered}; + use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_pb::data::data_type::{IntervalType, TypeName}; use risingwave_pb::data::DataType as ProstDataType; use risingwave_pb::expr::expr_node::RexNode::Constant; @@ -146,7 +149,7 @@ mod tests { Some(2.into()), None, ]); - let body = value.to_protobuf_owned(); + let body = serialize_datum_to_bytes(Some(value.clone().to_scalar_value()).as_ref()); let expr = ExprNode { expr_type: Type::ConstantValue as i32, return_type: Some(ProstDataType { @@ -177,57 +180,39 @@ mod tests { fn test_expr_literal_from() { let v = true; let t = TypeName::Boolean; - let bytes = (v as i8).to_be_bytes().to_vec(); - // construct LiteralExpression in various types below with value 1i8, and expect Err - for typ in [ - TypeName::Int16, - TypeName::Int32, - TypeName::Int64, - TypeName::Float, - TypeName::Double, - TypeName::Interval, - TypeName::Date, - TypeName::Struct, - ] { 
- assert!( - LiteralExpression::try_from(&make_expression(Some(bytes.clone()), typ)).is_err() - ); - } + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); + let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = 1i16; let t = TypeName::Int16; - let bytes = v.to_be_bytes().to_vec(); - assert!(LiteralExpression::try_from(&make_expression( - Some(bytes.clone()), - TypeName::Boolean, - )) - .is_err()); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); + let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = 1i32; let t = TypeName::Int32; - let bytes = v.to_be_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = 1i64; let t = TypeName::Int64; - let bytes = v.to_be_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = 1f32.into_ordered(); let t = TypeName::Float; - let bytes = v.to_be_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = 1f64.into_ordered(); let t = TypeName::Double; - let bytes = v.to_be_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); @@ -238,24 +223,20 @@ mod tests { let v = String::from("varchar"); let t = TypeName::Varchar; - let bytes = v.as_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.clone().to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); let v = Decimal::new(3141, 3); let t = TypeName::Decimal; - let bytes = v.to_string().as_bytes().to_vec(); + let bytes = serialize_datum_to_bytes(Some(v.to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!(v.to_scalar_value(), expr.literal().unwrap()); - let v = String::from("NaN"); - let t = TypeName::Decimal; - let bytes = v.as_bytes().to_vec(); - assert!(LiteralExpression::try_from(&make_expression(Some(bytes), t)).is_ok()); - let v = 32i32; let t = TypeName::Interval; - let bytes = v.to_be_bytes().to_vec(); + let bytes = + serialize_datum_to_bytes(Some(IntervalUnit::from_month(v).to_scalar_value()).as_ref()); let expr = LiteralExpression::try_from(&make_expression(Some(bytes), t)).unwrap(); assert_eq!( IntervalUnit::from_month(v).to_scalar_value(), diff --git a/src/expr/src/expr/expr_regexp.rs b/src/expr/src/expr/expr_regexp.rs index ed84d6ecb4f43..48f1cbf976904 100644 --- a/src/expr/src/expr/expr_regexp.rs +++ b/src/expr/src/expr/expr_regexp.rs @@ -21,6 +21,7 @@ use risingwave_common::array::{ Utf8Array, }; use risingwave_common::types::{DataType, Datum, Scalar, ScalarImpl}; +use risingwave_common::util::value_encoding::deserialize_datum; use 
risingwave_pb::expr::expr_node::{RexNode, Type}; use risingwave_pb::expr::ExprNode; @@ -61,10 +62,13 @@ impl<'a> TryFrom<&'a ExprNode> for RegexpMatchExpression { let RexNode::Constant(pattern_value) = pattern_node.get_rex_node().unwrap() else { return Err(ExprError::UnsupportedFunction("non-constant pattern in regexp_match".to_string())) }; - let pattern_scalar = ScalarImpl::from_proto_bytes( - pattern_value.get_body(), - pattern_node.get_return_type().unwrap(), - )?; + let pattern_scalar = deserialize_datum( + pattern_value.get_body().as_slice(), + &DataType::from(pattern_node.get_return_type().unwrap()), + ) + .map_err(|e| ExprError::Internal(e.into()))? + .unwrap(); + let ScalarImpl::Utf8(pattern) = pattern_scalar else { bail!("Expected pattern to be an String"); }; diff --git a/src/expr/src/expr/test_utils.rs b/src/expr/src/expr/test_utils.rs index 58d072d984444..75ba29986ec98 100644 --- a/src/expr/src/expr/test_utils.rs +++ b/src/expr/src/expr/test_utils.rs @@ -13,6 +13,8 @@ // limitations under the License. use itertools::Itertools; +use risingwave_common::types::ScalarImpl; +use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_pb::data::data_type::TypeName; use risingwave_pb::data::{DataType as ProstDataType, DataType}; use risingwave_pb::expr::expr_node::Type::{Field, InputRef}; @@ -55,7 +57,7 @@ pub fn make_i32_literal(data: i32) -> ExprNode { ..Default::default() }), rex_node: Some(RexNode::Constant(ConstantValue { - body: data.to_be_bytes().to_vec(), + body: serialize_datum_to_bytes(Some(ScalarImpl::Int32(data)).as_ref()), })), } } diff --git a/src/expr/src/table_function/regexp_matches.rs b/src/expr/src/table_function/regexp_matches.rs index 4675aa45d4d34..27e90e1e22be7 100644 --- a/src/expr/src/table_function/regexp_matches.rs +++ b/src/expr/src/table_function/regexp_matches.rs @@ -17,6 +17,7 @@ use std::sync::Arc; use regex::Regex; use risingwave_common::array::{Array, ArrayRef, DataChunk, ListValue, Utf8Array}; use risingwave_common::types::{Scalar, ScalarImpl}; +use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_common::{bail, ensure}; use risingwave_pb::expr::expr_node::RexNode; @@ -140,10 +141,13 @@ pub fn new_regexp_matches( let RexNode::Constant(pattern_value) = pattern_node.get_rex_node().unwrap() else { return Err(ExprError::UnsupportedFunction("non-constant pattern in regexp_match".to_string())) }; - let pattern_scalar = ScalarImpl::from_proto_bytes( - pattern_value.get_body(), - pattern_node.get_return_type().unwrap(), - )?; + let pattern_scalar = deserialize_datum( + pattern_value.get_body().as_slice(), + &DataType::from(pattern_node.get_return_type().unwrap()), + ) + .map_err(|e| ExprError::Internal(e.into()))? + .unwrap(); + let ScalarImpl::Utf8(pattern) = pattern_scalar else { bail!("Expected pattern to be an String"); }; diff --git a/src/frontend/src/catalog/source_catalog.rs b/src/frontend/src/catalog/source_catalog.rs index 5f0fb5ff7b83d..fc119576a45b8 100644 --- a/src/frontend/src/catalog/source_catalog.rs +++ b/src/frontend/src/catalog/source_catalog.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
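The expression changes above replace the ad-hoc constant encodings (big-endian integers, raw UTF-8, protobuf bytes) with the unified value encoding for `ConstantValue::body`. A minimal round trip under that scheme, using the `serialize_datum_to_bytes` / `deserialize_datum` helpers exactly as they appear in the hunks above; treat this as a sketch rather than the authoritative API surface:

```rust
use risingwave_common::types::{DataType, Scalar, ScalarImpl};
use risingwave_common::util::value_encoding::{deserialize_datum, serialize_datum_to_bytes};

fn main() {
    // Encode an i32 literal the way `make_i32_literal` now builds `ConstantValue::body`.
    let body = serialize_datum_to_bytes(Some(ScalarImpl::Int32(1)).as_ref());

    // Decode it back, as `LiteralExpression::try_from` and `build_to_char_expr` do.
    let datum = deserialize_datum(body.as_slice(), &DataType::Int32).unwrap();
    assert_eq!(datum, Some(1_i32.to_scalar_value()));
}
```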
+use std::collections::HashMap; + use risingwave_pb::catalog::source::Info; use risingwave_pb::catalog::{Source as ProstSource, StreamSourceInfo, TableSourceInfo}; @@ -38,6 +40,8 @@ pub struct SourceCatalog { pub append_only: bool, pub owner: u32, pub info: SourceCatalogInfo, + pub row_id_index: Option, + pub properties: HashMap, } impl SourceCatalog { @@ -54,32 +58,28 @@ impl From<&ProstSource> for SourceCatalog { fn from(prost: &ProstSource) -> Self { let id = prost.id; let name = prost.name.clone(); - let (prost_columns, pk_col_ids, with_options, info) = match &prost.info { - Some(Info::StreamSource(source)) => ( - source.columns.clone(), - source - .pk_column_ids - .clone() - .into_iter() - .map(Into::into) - .collect(), - WithOptions::new(source.properties.clone()), - SourceCatalogInfo::StreamSource(source.clone()), - ), - Some(Info::TableSource(source)) => ( - source.columns.clone(), - source - .pk_column_ids - .clone() - .into_iter() - .map(Into::into) - .collect(), - WithOptions::new(source.properties.clone()), - SourceCatalogInfo::TableSource(source.clone()), - ), + let prost_columns = prost.columns.clone(); + let pk_col_ids = prost + .pk_column_ids + .clone() + .into_iter() + .map(Into::into) + .collect(); + let with_options = WithOptions::new(prost.properties.clone()); + let info = match &prost.info { + Some(Info::StreamSource(info_inner)) => { + SourceCatalogInfo::StreamSource(info_inner.clone()) + } + Some(Info::TableSource(info_inner)) => { + SourceCatalogInfo::TableSource(info_inner.clone()) + } None => unreachable!(), }; let columns = prost_columns.into_iter().map(ColumnCatalog::from).collect(); + let row_id_index = prost + .row_id_index + .clone() + .map(|row_id_index| row_id_index.index as _); let append_only = with_options.append_only(); let owner = prost.owner; @@ -92,6 +92,8 @@ impl From<&ProstSource> for SourceCatalog { append_only, owner, info, + row_id_index, + properties: with_options.into_inner(), } } } diff --git a/src/frontend/src/expr/literal.rs b/src/frontend/src/expr/literal.rs index 60e4048f2e5ff..11f52288ad31a 100644 --- a/src/frontend/src/expr/literal.rs +++ b/src/frontend/src/expr/literal.rs @@ -13,6 +13,7 @@ // limitations under the License. use risingwave_common::types::{literal_type_match, DataType, Datum, ScalarImpl}; +use risingwave_common::util::value_encoding::serialize_datum_to_bytes; use risingwave_pb::expr::expr_node::RexNode; use super::Expr; @@ -68,18 +69,18 @@ impl Expr for Literal { ExprNode { expr_type: self.get_expr_type() as i32, return_type: Some(self.return_type().to_protobuf()), - rex_node: literal_to_protobuf(self.get_data()), + rex_node: literal_to_value_encoding(self.get_data()), } } } /// Convert a literal value (datum) into protobuf. 
-fn literal_to_protobuf(d: &Datum) -> Option { - let Some(d) = d.as_ref() else { +fn literal_to_value_encoding(d: &Datum) -> Option { + if d.is_none() { return None; - }; + } use risingwave_pb::expr::*; - let body = d.to_protobuf(); + let body = serialize_datum_to_bytes(d.as_ref()); Some(RexNode::Constant(ConstantValue { body })) } @@ -87,46 +88,51 @@ fn literal_to_protobuf(d: &Datum) -> Option { mod tests { use risingwave_common::array::{ListValue, StructValue}; use risingwave_common::types::{DataType, ScalarImpl}; + use risingwave_common::util::value_encoding::deserialize_datum; use risingwave_pb::expr::expr_node::RexNode; - use crate::expr::literal::literal_to_protobuf; + use crate::expr::literal::literal_to_value_encoding; #[test] - fn test_struct_to_protobuf() { + fn test_struct_to_value_encoding() { let value = StructValue::new(vec![ - Some(ScalarImpl::Utf8("12222".to_string())), + Some(ScalarImpl::Utf8("".to_string())), Some(2.into()), Some(3.into()), ]); let data = Some(ScalarImpl::Struct(value.clone())); - let node = literal_to_protobuf(&data); + let node = literal_to_value_encoding(&data); if let RexNode::Constant(prost) = node.as_ref().unwrap() { - let data2 = ScalarImpl::from_proto_bytes( - prost.get_body(), + let data2 = deserialize_datum( + prost.get_body().as_slice(), &DataType::new_struct( vec![DataType::Varchar, DataType::Int32, DataType::Int32], vec![], - ) - .to_protobuf(), + ), ) + .unwrap() .unwrap(); assert_eq!(ScalarImpl::Struct(value), data2); } } #[test] - fn test_list_to_protobuf() { - let value = ListValue::new(vec![Some(1.into()), Some(2.into()), Some(3.into())]); + fn test_list_to_value_encoding() { + let value = ListValue::new(vec![ + Some(ScalarImpl::Utf8("1".to_owned())), + Some(ScalarImpl::Utf8("2".to_owned())), + Some(ScalarImpl::Utf8("".to_owned())), + ]); let data = Some(ScalarImpl::List(value.clone())); - let node = literal_to_protobuf(&data); + let node = literal_to_value_encoding(&data); if let RexNode::Constant(prost) = node.as_ref().unwrap() { - let data2 = ScalarImpl::from_proto_bytes( - prost.get_body(), + let data2 = deserialize_datum( + prost.get_body().as_slice(), &DataType::List { - datatype: Box::new(DataType::Int32), - } - .to_protobuf(), + datatype: Box::new(DataType::Varchar), + }, ) + .unwrap() .unwrap(); assert_eq!(ScalarImpl::List(value), data2); } diff --git a/src/frontend/src/handler/create_mv.rs b/src/frontend/src/handler/create_mv.rs index ca35f5f8507ca..cbd4b0e29cdad 100644 --- a/src/frontend/src/handler/create_mv.rs +++ b/src/frontend/src/handler/create_mv.rs @@ -64,7 +64,7 @@ pub fn gen_create_mv_plan( (db_id, schema.id()) }; - let definition = format!("{}", query); + let definition = query.to_string(); let bound = { let mut binder = Binder::new(session); diff --git a/src/frontend/src/handler/create_source.rs b/src/frontend/src/handler/create_source.rs index d57e096ec0e18..538dbea23d648 100644 --- a/src/frontend/src/handler/create_source.rs +++ b/src/frontend/src/handler/create_source.rs @@ -44,6 +44,10 @@ use crate::stream_fragmenter::build_graph; pub(crate) fn make_prost_source( session: &SessionImpl, name: ObjectName, + row_id_index: Option, + columns: Vec, + pk_column_ids: Vec, + properties: HashMap, source_info: Info, ) -> Result { let db_name = session.database(); @@ -79,6 +83,10 @@ pub(crate) fn make_prost_source( schema_id, database_id, name, + row_id_index, + columns, + pk_column_ids, + properties, info: Some(source_info), owner: session.user_id(), }) @@ -101,8 +109,16 @@ async fn extract_avro_table_schema( } /// Map a 
protobuf schema to a relational schema. -fn extract_protobuf_table_schema(schema: &ProtobufSchema) -> Result> { - let parser = ProtobufParser::new(&schema.row_schema_location.0, &schema.message_name.0)?; +async fn extract_protobuf_table_schema( + schema: &ProtobufSchema, + with_properties: HashMap, +) -> Result> { + let parser = ProtobufParser::new( + &schema.row_schema_location.0, + &schema.message_name.0, + with_properties, + ) + .await?; let column_descs = parser.map_to_columns()?; Ok(column_descs @@ -123,45 +139,54 @@ pub async fn handle_create_source( let (mut columns, pk_column_ids, row_id_index) = bind_sql_table_constraints(column_descs, pk_column_id_from_columns, stmt.constraints)?; - let with_properties = context.with_options.inner().clone(); + let mut with_properties = context.with_options.inner().clone(); - let source = match &stmt.source_schema { + let (columns, source_info) = match &stmt.source_schema { SourceSchema::Protobuf(protobuf_schema) => { + // the key is identified with SourceParserImpl::create + const PROTOBUF_MESSAGE_KEY: &str = "proto.message"; + assert_eq!(columns.len(), 1); assert_eq!(pk_column_ids, vec![0.into()]); assert_eq!(row_id_index, Some(0)); - columns.extend(extract_protobuf_table_schema(protobuf_schema)?); - StreamSourceInfo { - properties: with_properties.clone(), - row_format: RowFormatType::Protobuf as i32, - row_schema_location: protobuf_schema.row_schema_location.0.clone(), - row_id_index: row_id_index.map(|index| ProstColumnIndex { index: index as _ }), + + // unlike other formats, there are multiple messages in one file. Will insert a key to + // identify the desired message. + with_properties + .entry(PROTOBUF_MESSAGE_KEY.into()) + .or_insert_with(|| protobuf_schema.message_name.0.clone()); + columns.extend( + extract_protobuf_table_schema(protobuf_schema, with_properties.clone()).await?, + ); + + ( columns, - pk_column_ids: pk_column_ids.into_iter().map(Into::into).collect(), - } + StreamSourceInfo { + row_format: RowFormatType::Protobuf as i32, + row_schema_location: protobuf_schema.row_schema_location.0.clone(), + }, + ) } SourceSchema::Avro(avro_schema) => { assert_eq!(columns.len(), 1); assert_eq!(pk_column_ids, vec![0.into()]); assert_eq!(row_id_index, Some(0)); columns.extend(extract_avro_table_schema(avro_schema, with_properties.clone()).await?); - StreamSourceInfo { - properties: with_properties.clone(), - row_format: RowFormatType::Avro as i32, - row_schema_location: avro_schema.row_schema_location.0.clone(), - row_id_index: row_id_index.map(|index| ProstColumnIndex { index: index as _ }), + ( columns, - pk_column_ids: pk_column_ids.into_iter().map(Into::into).collect(), - } + StreamSourceInfo { + row_format: RowFormatType::Avro as i32, + row_schema_location: avro_schema.row_schema_location.0.clone(), + }, + ) } - SourceSchema::Json => StreamSourceInfo { - properties: with_properties.clone(), - row_format: RowFormatType::Json as i32, - row_schema_location: "".to_string(), - row_id_index: row_id_index.map(|index| ProstColumnIndex { index: index as _ }), + SourceSchema::Json => ( columns, - pk_column_ids: pk_column_ids.into_iter().map(Into::into).collect(), - }, + StreamSourceInfo { + row_format: RowFormatType::Json as i32, + row_schema_location: "".to_string(), + }, + ), SourceSchema::DebeziumJson => { // return err if user has not specified a pk if row_id_index.is_some() { @@ -170,17 +195,19 @@ pub async fn handle_create_source( .to_string(), ))); } - StreamSourceInfo { - properties: with_properties.clone(), - row_format: 
RowFormatType::DebeziumJson as i32, - row_schema_location: "".to_string(), - row_id_index: row_id_index.map(|index| ProstColumnIndex { index: index as _ }), + ( columns, - pk_column_ids: pk_column_ids.into_iter().map(Into::into).collect(), - } + StreamSourceInfo { + row_format: RowFormatType::DebeziumJson as i32, + row_schema_location: "".to_string(), + }, + ) } }; + let row_id_index = row_id_index.map(|index| ProstColumnIndex { index: index as _ }); + let pk_column_ids = pk_column_ids.into_iter().map(Into::into).collect(); + let session = context.session_ctx.clone(); { let db_name = session.database(); @@ -200,7 +227,15 @@ pub async fn handle_create_source( }; catalog_reader.check_relation_name_duplicated(db_name, &schema_name, &source_name)?; } - let source = make_prost_source(&session, stmt.source_name, Info::StreamSource(source))?; + let source = make_prost_source( + &session, + stmt.source_name, + row_id_index, + columns, + pk_column_ids, + with_properties, + Info::StreamSource(source_info), + )?; let catalog_writer = session.env().catalog_writer(); if is_materialized { let (graph, table) = { diff --git a/src/frontend/src/handler/create_table.rs b/src/frontend/src/handler/create_table.rs index 0dbc8b7a88744..081caf3c40421 100644 --- a/src/frontend/src/handler/create_table.rs +++ b/src/frontend/src/handler/create_table.rs @@ -22,8 +22,7 @@ use risingwave_common::catalog::ColumnDesc; use risingwave_common::error::{ErrorCode, Result}; use risingwave_pb::catalog::source::Info; use risingwave_pb::catalog::{ - ColumnIndex as ProstColumnIndex, Source as ProstSource, StreamSourceInfo, Table as ProstTable, - TableSourceInfo, + ColumnIndex as ProstColumnIndex, Source as ProstSource, Table as ProstTable, TableSourceInfo, }; use risingwave_pb::plan_common::ColumnCatalog as ProstColumnCatalog; use risingwave_sqlparser::ast::{ @@ -207,15 +206,17 @@ pub(crate) fn gen_create_table_plan( let (column_descs, pk_column_id_from_columns) = bind_sql_columns(columns)?; let (columns, pk_column_ids, row_id_index) = bind_sql_table_constraints(column_descs, pk_column_id_from_columns, constraints)?; + let row_id_index = row_id_index.map(|index| ProstColumnIndex { index: index as _ }); + let pk_column_ids = pk_column_ids.into_iter().map(Into::into).collect(); + let properties = context.inner().with_options.inner().clone(); let source = make_prost_source( session, table_name, - Info::TableSource(TableSourceInfo { - row_id_index: row_id_index.map(|index| ProstColumnIndex { index: index as _ }), - columns, - pk_column_ids: pk_column_ids.into_iter().map(Into::into).collect(), - properties: context.inner().with_options.inner().clone(), - }), + row_id_index, + columns, + pk_column_ids, + properties, + Info::TableSource(TableSourceInfo {}), )?; let (plan, table) = gen_materialized_source_plan(context, source.clone(), session.user_id())?; Ok((plan, source, table)) @@ -232,11 +233,7 @@ pub(crate) fn gen_materialized_source_plan( // Manually assemble the materialization plan for the table. let source_node: PlanRef = StreamSource::new(LogicalSource::new(Rc::new((&source).into()), context)).into(); - let row_id_index = { - let (Info::StreamSource(StreamSourceInfo { row_id_index, .. }) - | Info::TableSource(TableSourceInfo { row_id_index, .. 
})) = source.info.unwrap(); - row_id_index.as_ref().map(|index| index.index as _) - }; + let row_id_index = source.row_id_index.as_ref().map(|index| index.index as _); let mut required_cols = FixedBitSet::with_capacity(source_node.schema().len()); required_cols.toggle_range(..); let mut out_names = source_node.schema().names(); diff --git a/src/frontend/src/handler/query.rs b/src/frontend/src/handler/query.rs index 23f0984a44f9b..d6d43c9485bf2 100644 --- a/src/frontend/src/handler/query.rs +++ b/src/frontend/src/handler/query.rs @@ -32,7 +32,7 @@ use crate::planner::Planner; use crate::scheduler::plan_fragmenter::Query; use crate::scheduler::{ BatchPlanFragmenter, DistributedQueryStream, ExecutionContext, ExecutionContextRef, - LocalQueryExecution, LocalQueryStream, + HummockSnapshotGuard, LocalQueryExecution, LocalQueryStream, }; use crate::session::{OptimizerContext, OptimizerContextRef, SessionImpl}; use crate::PlanRef; @@ -123,19 +123,27 @@ pub async fn handle_query( .map(|f| f.data_type()) .collect_vec(); - let mut row_stream = match query_mode { - QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new( - local_execute(session.clone(), query).await?, - column_types, - format, - )), - // Local mode do not support cancel tasks. - QueryMode::Distributed => { - PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new( - distribute_execute(session.clone(), query).await?, + let mut row_stream = { + // Acquire hummock snapshot for execution. + // TODO: if there's no table scan, we don't need to acquire snapshot. + let hummock_snapshot_manager = session.env().hummock_snapshot_manager(); + let query_id = query.query_id().clone(); + let pinned_snapshot = hummock_snapshot_manager.acquire(&query_id).await?; + + match query_mode { + QueryMode::Local => PgResponseStream::LocalQuery(DataChunkToRowSetAdapter::new( + local_execute(session.clone(), query, pinned_snapshot).await?, column_types, format, - )) + )), + // Local mode do not support cancel tasks. + QueryMode::Distributed => { + PgResponseStream::DistributedQuery(DataChunkToRowSetAdapter::new( + distribute_execute(session.clone(), query, pinned_snapshot).await?, + column_types, + format, + )) + } } }; @@ -203,29 +211,30 @@ fn to_statement_type(stmt: &Statement) -> Result { pub async fn distribute_execute( session: Arc, query: Query, + pinned_snapshot: HummockSnapshotGuard, ) -> Result { let execution_context: ExecutionContextRef = ExecutionContext::new(session.clone()).into(); - let query_manager = execution_context.session().env().query_manager().clone(); + let query_manager = session.env().query_manager().clone(); query_manager - .schedule(execution_context, query) + .schedule(execution_context, query, pinned_snapshot) .await .map_err(|err| err.into()) } -async fn local_execute(session: Arc, query: Query) -> Result { +#[expect(clippy::unused_async)] +async fn local_execute( + session: Arc, + query: Query, + pinned_snapshot: HummockSnapshotGuard, +) -> Result { let front_env = session.env(); - // Acquire hummock snapshot for local execution. 
- let hummock_snapshot_manager = front_env.hummock_snapshot_manager(); - let query_id = query.query_id().clone(); - let pinned_snapshot = hummock_snapshot_manager.acquire(&query_id).await?; - // TODO: Passing sql here let execution = LocalQueryExecution::new( query, front_env.clone(), "", - pinned_snapshot.snapshot.committed_epoch, + pinned_snapshot, session.auth_context(), ); diff --git a/src/frontend/src/optimizer/plan_node/mod.rs b/src/frontend/src/optimizer/plan_node/mod.rs index f831fa964e42c..bf7a23d0a8366 100644 --- a/src/frontend/src/optimizer/plan_node/mod.rs +++ b/src/frontend/src/optimizer/plan_node/mod.rs @@ -236,6 +236,7 @@ mod predicate_pushdown; pub use predicate_pushdown::*; pub mod generic; +pub mod stream; pub use generic::{PlanAggCall, PlanAggCallDisplay}; diff --git a/src/frontend/src/optimizer/plan_node/stream.rs b/src/frontend/src/optimizer/plan_node/stream.rs new file mode 100644 index 0000000000000..865d118da1d13 --- /dev/null +++ b/src/frontend/src/optimizer/plan_node/stream.rs @@ -0,0 +1,199 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::catalog::Schema; + +use super::{generic, EqJoinPredicate, PlanNodeId}; +use crate::optimizer::property::{Distribution, FunctionalDependencySet}; +use crate::session::OptimizerContextRef; +use crate::utils::Condition; +use crate::{TableCatalog, WithOptions}; + +macro_rules! impl_node { +($base:ident, $($t:ident),*) => { + #[derive(Debug, Clone)] + pub enum Node { + $($t(Box<$t>),)* + } + $( + impl From<$t> for Node { + fn from(o: $t) -> Node { + Node::$t(Box::new(o)) + } + } + )* + pub type PlanOwned = ($base, Node); + pub type PlanRef = std::rc::Rc; +}; +} + +/// Implements [`generic::Join`] with delta join. It requires its two +/// inputs to be indexes. 
+#[derive(Debug, Clone)] +pub struct DeltaJoin { + pub logical: generic::Join, + + /// The join condition must be equivalent to `logical.on`, but separated into equal and + /// non-equal parts to facilitate execution later + pub eq_join_predicate: EqJoinPredicate, +} + +#[derive(Clone, Debug)] +pub struct DynamicFilter { + /// The predicate (formed with exactly one of < , <=, >, >=) + pub predicate: Condition, + // dist_key_l: Distribution, + pub left_index: usize, + pub left: PlanRef, + pub right: PlanRef, +} + +#[derive(Debug, Clone)] +pub struct Exchange(pub PlanRef); + +#[derive(Debug, Clone)] +pub struct Expand(pub generic::Expand); + +#[derive(Debug, Clone)] +pub struct Filter(pub generic::Filter); + +#[derive(Debug, Clone)] +pub struct GlobalSimpleAgg(pub generic::Agg); + +#[derive(Debug, Clone)] +pub struct GroupTopN { + pub logical: generic::TopN, + /// an optional column index which is the vnode of each row computed by the input's consistent + /// hash distribution + pub vnode_col_idx: Option, +} + +#[derive(Debug, Clone)] +pub struct HashAgg { + /// an optional column index which is the vnode of each row computed by the input's consistent + /// hash distribution + pub vnode_col_idx: Option, + pub logical: generic::Agg, +} + +/// Implements [`generic::Join`] with hash table. It builds a hash table +/// from inner (right-side) relation and probes with data from outer (left-side) relation to +/// get output rows. +#[derive(Debug, Clone)] +pub struct HashJoin { + pub logical: generic::Join, + + /// The join condition must be equivalent to `logical.on`, but separated into equal and + /// non-equal parts to facilitate execution later + pub eq_join_predicate: EqJoinPredicate, + + /// Whether can optimize for append-only stream. + /// It is true if input of both side is append-only + pub is_append_only: bool, +} + +#[derive(Debug, Clone)] +pub struct HopWindow(pub generic::HopWindow); + +/// [`IndexScan`] is a virtual plan node to represent a stream table scan. It will be converted +/// to chain + merge node (for upstream materialize) + batch table scan when converting to `MView` +/// creation request. Compared with [`TableScan`], it will reorder columns, and the chain node +/// doesn't allow rearrange. +#[derive(Debug, Clone)] +pub struct IndexScan { + pub logical: generic::Scan, + pub batch_plan_id: PlanNodeId, +} + +/// Local simple agg. +/// +/// Should only be used for stateless agg, including `sum`, `count` and *append-only* `min`/`max`. +/// +/// The output of `LocalSimpleAgg` doesn't have pk columns, so the result can only +/// be used by `GlobalSimpleAgg` with `ManagedValueState`s. +#[derive(Debug, Clone)] +pub struct LocalSimpleAgg(pub generic::Agg); + +#[derive(Debug, Clone)] +pub struct Materialize { + /// Child of Materialize plan + pub input: PlanRef, + pub table: TableCatalog, +} + +#[derive(Debug, Clone)] +pub struct ProjectSet(pub generic::ProjectSet); + +/// `Project` implements [`super::LogicalProject`] to evaluate specified expressions on input +/// rows. +#[derive(Debug, Clone)] +pub struct Project(pub generic::Project); + +/// [`Sink`] represents a table/connector sink at the very end of the graph. +#[derive(Debug, Clone)] +pub struct Sink { + pub input: PlanRef, + pub properties: WithOptions, +} + +/// [`Source`] represents a table/connector source at the very beginning of the graph. +#[derive(Debug, Clone)] +pub struct Source(pub generic::Source); + +/// `TableScan` is a virtual plan node to represent a stream table scan. 
It will be converted +/// to chain + merge node (for upstream materialize) + batch table scan when converting to `MView` +/// creation request. +#[derive(Debug, Clone)] +pub struct TableScan { + pub logical: generic::Scan, + pub batch_plan_id: PlanNodeId, +} + +/// `TopN` implements [`super::LogicalTopN`] to find the top N elements with a heap +#[derive(Debug, Clone)] +pub struct TopN(pub generic::TopN); + +#[derive(Clone, Debug)] +pub struct PlanBase { + pub id: PlanNodeId, + pub ctx: OptimizerContextRef, + pub schema: Schema, + pub logical_pk: Vec, + pub dist: Distribution, + pub append_only: bool, + pub functional_dependency: FunctionalDependencySet, +} + +impl_node!( + PlanBase, + Exchange, + DynamicFilter, + DeltaJoin, + Expand, + Filter, + GlobalSimpleAgg, + GroupTopN, + HashAgg, + HashJoin, + HopWindow, + IndexScan, + LocalSimpleAgg, + Materialize, + ProjectSet, + Project, + Sink, + Source, + TableScan, + TopN +); diff --git a/src/frontend/src/optimizer/plan_node/stream_materialize.rs b/src/frontend/src/optimizer/plan_node/stream_materialize.rs index 396dd7035ea78..e76a589934389 100644 --- a/src/frontend/src/optimizer/plan_node/stream_materialize.rs +++ b/src/frontend/src/optimizer/plan_node/stream_materialize.rs @@ -265,6 +265,7 @@ impl StreamNode for StreamMaterialize { .map(FieldOrder::to_protobuf) .collect(), table: Some(self.table().to_internal_table_prost()), + ignore_on_conflict: true, }) } } diff --git a/src/frontend/src/optimizer/plan_node/stream_source.rs b/src/frontend/src/optimizer/plan_node/stream_source.rs index d7e42d14da557..43afda4f28d59 100644 --- a/src/frontend/src/optimizer/plan_node/stream_source.rs +++ b/src/frontend/src/optimizer/plan_node/stream_source.rs @@ -14,6 +14,8 @@ use std::fmt; +use itertools::Itertools; +use risingwave_pb::catalog::ColumnIndex; use risingwave_pb::stream_plan::source_node::Info; use risingwave_pb::stream_plan::stream_node::NodeBody as ProstStreamNode; use risingwave_pb::stream_plan::SourceNode; @@ -72,11 +74,6 @@ impl StreamNode for StreamSource { let source_catalog = self.logical.source_catalog(); ProstStreamNode::Source(SourceNode { source_id: source_catalog.id, - column_ids: source_catalog - .columns - .iter() - .map(|c| c.column_id().into()) - .collect(), state_table: Some( self.logical .infer_internal_table_catalog() @@ -87,6 +84,20 @@ impl StreamNode for StreamSource { SourceCatalogInfo::StreamSource(info) => Info::StreamSource(info.to_owned()), SourceCatalogInfo::TableSource(info) => Info::TableSource(info.to_owned()), }), + row_id_index: source_catalog + .row_id_index + .map(|index| ColumnIndex { index: index as _ }), + columns: source_catalog + .columns + .iter() + .map(|c| c.to_protobuf()) + .collect_vec(), + pk_column_ids: source_catalog + .pk_col_ids + .iter() + .map(Into::into) + .collect_vec(), + properties: source_catalog.properties.clone(), }) } } diff --git a/src/frontend/src/scheduler/distributed/query.rs b/src/frontend/src/scheduler/distributed/query.rs index d94eee6640bff..81a48b872219e 100644 --- a/src/frontend/src/scheduler/distributed/query.rs +++ b/src/frontend/src/scheduler/distributed/query.rs @@ -36,7 +36,7 @@ use crate::scheduler::distributed::StageExecution; use crate::scheduler::plan_fragmenter::{Query, StageId, ROOT_TASK_ID, ROOT_TASK_OUTPUT_ID}; use crate::scheduler::worker_node_manager::WorkerNodeManagerRef; use crate::scheduler::{ - ExecutionContextRef, HummockSnapshotManagerRef, PinnedHummockSnapshot, SchedulerError, + ExecutionContextRef, HummockSnapshotGuard, PinnedHummockSnapshot, SchedulerError, 
SchedulerResult, }; @@ -115,7 +115,7 @@ impl QueryExecution { &self, context: ExecutionContextRef, worker_node_manager: WorkerNodeManagerRef, - hummock_snapshot_manager: HummockSnapshotManagerRef, + pinned_snapshot: HummockSnapshotGuard, compute_client_pool: ComputeClientPoolRef, catalog_reader: CatalogReader, query_execution_info: QueryExecutionInfoRef, @@ -123,11 +123,6 @@ impl QueryExecution { let mut state = self.state.write().await; let cur_state = mem::replace(&mut *state, QueryState::Failed); - // Acquired a pinned `HummockSnapshot`. - let pinned_snapshot = hummock_snapshot_manager - .acquire(&self.query.query_id) - .await?; - // Because the snapshot may be released before all stages are scheduled, we only pass a // reference of `pinned_snapshot`. Its ownership will be moved into `QueryRunner` so that it // can control when to release the snapshot. @@ -213,7 +208,8 @@ impl QueryExecution { .collect::>>(); let stage_exec = Arc::new(StageExecution::new( - pinned_snapshot.snapshot.committed_epoch, + // TODO: Add support to use current epoch when needed + pinned_snapshot.get_committed_epoch(), self.query.stage_graph.stages[&stage_id].clone(), worker_node_manager.clone(), self.shutdown_tx.clone(), @@ -258,10 +254,7 @@ impl QueryRunner { // thus they all successfully pinned a HummockVersion. // So we can now unpin their epoch. tracing::trace!("Query {:?} has scheduled all of its stages that have table scan (iterator creation).", self.query.query_id); - if let Some(pinned_snapshot) = pinned_snapshot_to_drop { - drop(pinned_snapshot); - pinned_snapshot_to_drop = None; - } + pinned_snapshot_to_drop.take(); } // For root stage, we execute in frontend local. We will pass the root fragment @@ -421,15 +414,17 @@ pub(crate) mod tests { CatalogReader::new(Arc::new(parking_lot::RwLock::new(Catalog::default()))); let query = create_query().await; let query_id = query.query_id().clone(); + let pinned_snapshot = hummock_snapshot_manager.acquire(&query_id).await.unwrap(); let query_execution = Arc::new(QueryExecution::new(query, (0, 0))); let query_execution_info = Arc::new(RwLock::new(QueryExecutionInfo::new_from_map( HashMap::from([(query_id, query_execution.clone())]), ))); + assert!(query_execution .start( ExecutionContext::new(SessionImpl::mock().into()).into(), worker_node_manager, - hummock_snapshot_manager, + pinned_snapshot, compute_client_pool, catalog_reader, query_execution_info, diff --git a/src/frontend/src/scheduler/distributed/query_manager.rs b/src/frontend/src/scheduler/distributed/query_manager.rs index 8264787582d29..c64280707c63f 100644 --- a/src/frontend/src/scheduler/distributed/query_manager.rs +++ b/src/frontend/src/scheduler/distributed/query_manager.rs @@ -33,7 +33,9 @@ use super::QueryExecution; use crate::catalog::catalog_service::CatalogReader; use crate::scheduler::plan_fragmenter::{Query, QueryId}; use crate::scheduler::worker_node_manager::WorkerNodeManagerRef; -use crate::scheduler::{ExecutionContextRef, HummockSnapshotManagerRef, SchedulerResult}; +use crate::scheduler::{ + ExecutionContextRef, HummockSnapshotGuard, HummockSnapshotManagerRef, SchedulerResult, +}; pub struct DistributedQueryStream { chunk_rx: tokio::sync::mpsc::Receiver>, @@ -158,6 +160,7 @@ impl QueryManager { &self, context: ExecutionContextRef, query: Query, + pinned_snapshot: HummockSnapshotGuard, ) -> SchedulerResult { let query_id = query.query_id.clone(); let query_execution = Arc::new(QueryExecution::new(query, context.session().id())); @@ -174,7 +177,7 @@ impl QueryManager { .start( 
context.clone(), self.worker_node_manager.clone(), - self.hummock_snapshot_manager.clone(), + pinned_snapshot, self.compute_client_pool.clone(), self.catalog_reader.clone(), self.query_execution_info.clone(), diff --git a/src/frontend/src/scheduler/hummock_snapshot_manager.rs b/src/frontend/src/scheduler/hummock_snapshot_manager.rs index 5760bd02b500d..9a24ebf1b931e 100644 --- a/src/frontend/src/scheduler/hummock_snapshot_manager.rs +++ b/src/frontend/src/scheduler/hummock_snapshot_manager.rs @@ -18,6 +18,7 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use anyhow::anyhow; +use arc_swap::ArcSwap; use risingwave_common::util::epoch::INVALID_EPOCH; use risingwave_pb::hummock::HummockSnapshot; use tokio::sync::mpsc::UnboundedSender; @@ -33,20 +34,23 @@ const UNPIN_INTERVAL_SECS: u64 = 10; pub type HummockSnapshotManagerRef = Arc; pub type PinnedHummockSnapshot = HummockSnapshotGuard; +type SnapshotRef = Arc>; + /// Cache of hummock snapshot in meta. pub struct HummockSnapshotManager { /// Send epoch-related operations to `HummockSnapshotManagerCore` for async batch handling. sender: UnboundedSender, + /// The latest snapshot synced from the meta service. + /// /// The `max_committed_epoch` and `max_current_epoch` are pushed from meta node to reduce rpc /// number. - max_committed_epoch: Arc, - + /// /// We have two epoch(committed and current), We only use `committed_epoch` to pin or unpin, /// because `committed_epoch` always less or equal `current_epoch`, and the data with /// `current_epoch` is always in the shared buffer, so it will never be gc before the data /// of `committed_epoch`. - max_current_epoch: Arc, + latest_snapshot: SnapshotRef, } #[derive(Debug)] @@ -63,11 +67,21 @@ enum EpochOperation { } pub struct HummockSnapshotGuard { - pub snapshot: HummockSnapshot, + snapshot: HummockSnapshot, query_id: QueryId, unpin_snapshot_sender: UnboundedSender, } +impl HummockSnapshotGuard { + pub fn get_committed_epoch(&self) -> u64 { + self.snapshot.committed_epoch + } + + pub fn get_current_epoch(&self) -> u64 { + self.snapshot.current_epoch + } +} + impl Drop for HummockSnapshotGuard { fn drop(&mut self) { self.unpin_snapshot_sender @@ -82,16 +96,15 @@ impl Drop for HummockSnapshotGuard { impl HummockSnapshotManager { pub fn new(meta_client: Arc) -> Self { let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel(); - let max_committed_epoch = Arc::new(AtomicU64::new(INVALID_EPOCH)); - let max_committed_epoch_cloned = max_committed_epoch.clone(); - let max_current_epoch = Arc::new(AtomicU64::new(INVALID_EPOCH)); - let max_current_epoch_cloned = max_current_epoch.clone(); + + let latest_snapshot = Arc::new(ArcSwap::from_pointee(HummockSnapshot { + committed_epoch: INVALID_EPOCH, + current_epoch: INVALID_EPOCH, + })); + let latest_snapshot_cloned = latest_snapshot.clone(); + tokio::spawn(async move { - let mut manager = HummockSnapshotManagerCore::new( - meta_client, - max_committed_epoch_cloned, - max_current_epoch_cloned, - ); + let mut manager = HummockSnapshotManagerCore::new(meta_client, latest_snapshot_cloned); let mut unpin_batches = vec![]; let mut pin_batches = vec![]; let mut unpin_interval = @@ -145,8 +158,7 @@ impl HummockSnapshotManager { }); Self { sender, - max_committed_epoch, - max_current_epoch, + latest_snapshot, } } @@ -173,11 +185,13 @@ impl HummockSnapshotManager { }) } - pub fn update_epoch(&self, epoch: HummockSnapshot) { - self.max_committed_epoch - .fetch_max(epoch.committed_epoch, Ordering::Relaxed); - self.max_current_epoch - 
.fetch_max(epoch.current_epoch, Ordering::Relaxed); + pub fn update_epoch(&self, snapshot: HummockSnapshot) { + // Note: currently the snapshot is not only updated from the observer, so we need to take + // the `max` here instead of directly replace the snapshot. + self.latest_snapshot.rcu(|prev| HummockSnapshot { + committed_epoch: std::cmp::max(prev.committed_epoch, snapshot.committed_epoch), + current_epoch: std::cmp::max(prev.current_epoch, snapshot.current_epoch), + }); } } @@ -187,23 +201,17 @@ struct HummockSnapshotManagerCore { epoch_to_query_ids: BTreeMap>, meta_client: Arc, last_unpin_snapshot: Arc, - max_committed_epoch: Arc, - max_current_epoch: Arc, + latest_snapshot: SnapshotRef, } impl HummockSnapshotManagerCore { - fn new( - meta_client: Arc, - max_committed_epoch: Arc, - max_current_epoch: Arc, - ) -> Self { + fn new(meta_client: Arc, latest_snapshot: SnapshotRef) -> Self { Self { // Initialize by setting `is_outdated` to `true`. meta_client, epoch_to_query_ids: BTreeMap::default(), last_unpin_snapshot: Arc::new(AtomicU64::new(INVALID_EPOCH)), - max_committed_epoch, - max_current_epoch, + latest_snapshot, } } @@ -241,13 +249,7 @@ impl HummockSnapshotManagerCore { &mut self, batches: &mut Vec<(QueryId, Callback>)>, ) -> HummockSnapshot { - let committed_epoch = self.max_committed_epoch.load(Ordering::Relaxed); - let current_epoch = self.max_current_epoch.load(Ordering::Relaxed); - let snapshot = HummockSnapshot { - committed_epoch, - current_epoch, - }; - + let snapshot = HummockSnapshot::clone(&self.latest_snapshot.load()); self.notify_epoch_assigned_for_queries(&snapshot, batches); snapshot } diff --git a/src/frontend/src/scheduler/local.rs b/src/frontend/src/scheduler/local.rs index ac26d221b1983..fa0619664c268 100644 --- a/src/frontend/src/scheduler/local.rs +++ b/src/frontend/src/scheduler/local.rs @@ -38,6 +38,7 @@ use tracing::debug; use uuid::Uuid; use super::plan_fragmenter::{PartitionInfo, QueryStageRef}; +use super::HummockSnapshotGuard; use crate::optimizer::plan_node::PlanNodeType; use crate::scheduler::plan_fragmenter::{ExecutionPlanNode, Query, StageId}; use crate::scheduler::task_context::FrontendBatchTaskContext; @@ -69,8 +70,8 @@ pub struct LocalQueryExecution { sql: String, query: Query, front_env: FrontendEnv, - epoch: u64, - + // The snapshot will be released when LocalQueryExecution is dropped. 
+ snapshot: HummockSnapshotGuard, auth_context: Arc, } @@ -79,14 +80,14 @@ impl LocalQueryExecution { query: Query, front_env: FrontendEnv, sql: S, - epoch: u64, + snapshot: HummockSnapshotGuard, auth_context: Arc, ) -> Self { Self { sql: sql.into(), query, front_env, - epoch, + snapshot, auth_context, } } @@ -109,7 +110,13 @@ impl LocalQueryExecution { let plan_fragment = self.create_plan_fragment()?; let plan_node = plan_fragment.root.unwrap(); - let executor = ExecutorBuilder::new(&plan_node, &task_id, context, self.epoch); + let executor = ExecutorBuilder::new( + &plan_node, + &task_id, + context, + // TODO: Add support to use current epoch when needed + self.snapshot.get_committed_epoch(), + ); let executor = executor.build().await?; #[for_await] @@ -238,7 +245,8 @@ impl LocalQueryExecution { }; let local_execute_plan = LocalExecutePlan { plan: Some(second_stage_plan_fragment), - epoch: self.epoch, + // TODO: Add support to use current epoch when needed + epoch: self.snapshot.get_committed_epoch(), }; let exchange_source = ExchangeSource { task_output_id: Some(TaskOutputId { @@ -266,8 +274,9 @@ impl LocalQueryExecution { }; let local_execute_plan = LocalExecutePlan { - plan: Some(second_stage_plan_fragment), - epoch: self.epoch, + plan: Some(second_stage_plan_fragment), + // TODO: Add support to use current epoch when needed + epoch: self.snapshot.get_committed_epoch(), }; let workers = if second_stage.parallelism == 1 { diff --git a/src/meta/Cargo.toml b/src/meta/Cargo.toml index e9bc876a7ca1e..e97ca9822dd74 100644 --- a/src/meta/Cargo.toml +++ b/src/meta/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] anyhow = "1" +arc-swap = "1" assert_matches = "1" async-stream = "0.3" async-trait = "0.1" diff --git a/src/meta/src/hummock/compaction_scheduler.rs b/src/meta/src/hummock/compaction_scheduler.rs index b7db2964dc56b..c3e903b32ff1d 100644 --- a/src/meta/src/hummock/compaction_scheduler.rs +++ b/src/meta/src/hummock/compaction_scheduler.rs @@ -343,7 +343,7 @@ mod tests { // No task let compactor = hummock_manager.get_idle_compactor().await.unwrap(); assert_eq!( - ScheduleStatus::PickFailure, + ScheduleStatus::NoTask, compaction_scheduler .pick_and_assign( StaticCompactionGroupId::StateDefault.into(), diff --git a/src/meta/src/hummock/manager/mod.rs b/src/meta/src/hummock/manager/mod.rs index 017f0b5e829a8..d620f82568941 100644 --- a/src/meta/src/hummock/manager/mod.rs +++ b/src/meta/src/hummock/manager/mod.rs @@ -16,10 +16,10 @@ use std::borrow::{Borrow, BorrowMut}; use std::collections::{BTreeMap, HashMap, HashSet}; use std::ops::Bound::{Excluded, Included}; use std::ops::DerefMut; -use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, LazyLock}; use std::time::Instant; +use arc_swap::ArcSwap; use fail::fail_point; use function_name::named; use itertools::Itertools; @@ -76,6 +76,8 @@ use versioning::*; mod compaction; use compaction::*; +type Snapshot = ArcSwap; + // Update to states are performed as follow: // - Initialize ValTransaction for the meta state to update // - Make changes on the ValTransaction. @@ -89,8 +91,7 @@ pub struct HummockManager { // be requested before versioning lock. 
compaction: MonitoredRwLock, versioning: MonitoredRwLock, - max_committed_epoch: AtomicU64, - max_current_epoch: AtomicU64, + latest_snapshot: Snapshot, metrics: Arc, @@ -223,8 +224,10 @@ where compaction_request_channel: parking_lot::RwLock::new(None), compaction_resume_notifier: parking_lot::RwLock::new(None), compactor_manager, - max_committed_epoch: AtomicU64::new(0), - max_current_epoch: AtomicU64::new(0), + latest_snapshot: ArcSwap::from_pointee(HummockSnapshot { + committed_epoch: INVALID_EPOCH, + current_epoch: INVALID_EPOCH, + }), }; instance.load_meta_store_state().await?; @@ -338,10 +341,13 @@ where redo_state.apply_version_delta(version_delta); } } - self.max_committed_epoch - .store(redo_state.max_committed_epoch, Ordering::Relaxed); - self.max_current_epoch - .fetch_max(redo_state.max_committed_epoch, Ordering::Relaxed); + self.latest_snapshot.store( + HummockSnapshot { + committed_epoch: redo_state.max_committed_epoch, + current_epoch: redo_state.max_committed_epoch, + } + .into(), + ); versioning_guard.current_version = redo_state; versioning_guard.branched_ssts = versioning_guard.current_version.build_branched_sst_info(); @@ -479,8 +485,7 @@ where context_id: HummockContextId, epoch: HummockEpoch, ) -> Result { - let max_committed_epoch = self.max_committed_epoch.load(Ordering::Relaxed); - let max_current_epoch = self.max_current_epoch.load(Ordering::Relaxed); + let snapshot = self.latest_snapshot.load(); let mut guard = write_lock!(self, versioning).await; let mut pinned_snapshots = BTreeMapTransaction::new(&mut guard.pinned_snapshots); let mut context_pinned_snapshot = pinned_snapshots.new_entry_txn_or_default( @@ -490,26 +495,18 @@ where minimal_pinned_snapshot: INVALID_EPOCH, }, ); - let epoch_to_pin = if epoch <= max_committed_epoch { - epoch - } else { - max_committed_epoch - }; + let epoch_to_pin = std::cmp::min(epoch, snapshot.committed_epoch); if context_pinned_snapshot.minimal_pinned_snapshot == INVALID_EPOCH { context_pinned_snapshot.minimal_pinned_snapshot = epoch_to_pin; commit_multi_var!(self, Some(context_id), context_pinned_snapshot)?; } - Ok(HummockSnapshot { - committed_epoch: max_committed_epoch, - current_epoch: max_current_epoch, - }) + Ok(HummockSnapshot::clone(&snapshot)) } /// Make sure `max_committed_epoch` is pinned and return it. 
#[named] pub async fn pin_snapshot(&self, context_id: HummockContextId) -> Result { - let max_committed_epoch = self.max_committed_epoch.load(Ordering::Relaxed); - let max_current_epoch = self.max_current_epoch.load(Ordering::Relaxed); + let snapshot = self.latest_snapshot.load(); let mut guard = write_lock!(self, versioning).await; let _timer = start_measure_real_process_timer!(self); let mut pinned_snapshots = BTreeMapTransaction::new(&mut guard.pinned_snapshots); @@ -521,23 +518,16 @@ where }, ); if context_pinned_snapshot.minimal_pinned_snapshot == INVALID_EPOCH { - context_pinned_snapshot.minimal_pinned_snapshot = max_committed_epoch; + context_pinned_snapshot.minimal_pinned_snapshot = snapshot.committed_epoch; commit_multi_var!(self, Some(context_id), context_pinned_snapshot)?; trigger_pin_unpin_snapshot_state(&self.metrics, &guard.pinned_snapshots); } - Ok(HummockSnapshot { - committed_epoch: max_committed_epoch, - current_epoch: max_current_epoch, - }) + Ok(HummockSnapshot::clone(&snapshot)) } pub fn get_last_epoch(&self) -> Result { - let max_committed_epoch = self.max_committed_epoch.load(Ordering::Relaxed); - let max_current_epoch = self.max_current_epoch.load(Ordering::Relaxed); - Ok(HummockSnapshot { - committed_epoch: max_committed_epoch, - current_epoch: max_current_epoch, - }) + let snapshot = self.latest_snapshot.load(); + Ok(HummockSnapshot::clone(&snapshot)) } #[named] @@ -644,12 +634,14 @@ where ); commit_multi_var!(self, None, new_compact_status)?; } - let mut compact_status = VarTransaction::new( - compaction - .compaction_statuses - .get_mut(&compaction_group_id) - .ok_or(Error::InvalidCompactionGroup(compaction_group_id))?, - ); + let mut compact_status = match compaction.compaction_statuses.get_mut(&compaction_group_id) + { + Some(c) => VarTransaction::new(c), + None => { + // sync_group has not been called for this group, which means no data even written. + return Ok(None); + } + }; let (current_version, watermark) = { let versioning_guard = read_lock!(self, versioning).await; let max_committed_epoch = versioning_guard.current_version.max_committed_epoch; @@ -661,7 +653,8 @@ where (versioning_guard.current_version.clone(), watermark) }; if current_version.levels.get(&compaction_group_id).is_none() { - return Err(Error::InvalidCompactionGroup(compaction_group_id)); + // sync_group has not been called for this group, which means no data even written. + return Ok(None); } let can_trivial_move = manual_compaction_option.is_none(); let compact_task = compact_status.get_compact_task( @@ -1390,33 +1383,13 @@ where .await; } - // Warn of table_ids that is not found in expected compaction group. - // It indicates: - // 1. Either these table_ids are never registered to any compaction group. This is FATAL - // since compaction filter will remove these valid states incorrectly. - // 2. Or the owners of these table_ids have been dropped, but their stale states are still - // committed. This is OK since compaction filter will remove these stale states - // later. 
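The `commit_epoch` hunk above replaces the per-table warning loop with a single membership check: an SST is kept for a compaction group only if the group is declared and every `table_id` in the SST belongs to it, otherwise it is set aside for branching. A simplified standalone sketch of that filtering step, with hypothetical stand-in types rather than the real `SstableInfo` / compaction-group structures:

```rust
use std::collections::{HashMap, HashSet};

// Hypothetical stand-in for the real SST metadata type.
struct Sst {
    id: u64,
    table_ids: Vec<u32>,
}

/// Keep only SSTs fully covered by their declared group; return the ids set aside.
fn split_undeclared(
    sstables: &mut Vec<(u64 /* group id */, Sst)>,
    groups: &HashMap<u64, HashSet<u32>>, // group id -> member table ids
) -> Vec<u64> {
    let mut set_aside = vec![];
    sstables.retain_mut(|(group_id, sst)| {
        // Keep the SST only if its group is declared and covers all of its table ids.
        let declared = groups
            .get(group_id)
            .map(|members| sst.table_ids.iter().all(|t| members.contains(t)))
            .unwrap_or(false);
        if !declared {
            set_aside.push(sst.id);
        }
        declared
    });
    set_aside
}

fn main() {
    let mut ssts = vec![
        (1, Sst { id: 10, table_ids: vec![1, 2] }),
        (1, Sst { id: 11, table_ids: vec![3] }),
    ];
    let groups = HashMap::from([(1, HashSet::from([1, 2]))]);
    let moved = split_undeclared(&mut ssts, &groups);
    assert_eq!(moved, vec![11]);
    assert_eq!(ssts.len(), 1);
}
```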
let mut branch_sstables = vec![]; sstables.retain_mut(|(compaction_group_id, sst)| { let is_sst_belong_to_group_declared = match compaction_groups.get(compaction_group_id) { - Some(compaction_group) => { - let mut is_valid = true; - for table_id in sst - .table_ids - .iter() - .filter(|t| !compaction_group.member_table_ids().contains(t)) - { - is_valid = false; - tracing::warn!( - "table {} in SST {} doesn't belong to expected compaction group {}", - table_id, - sst.get_id(), - compaction_group_id - ); - } - is_valid - } + Some(compaction_group) => sst + .table_ids + .iter() + .all(|t| compaction_group.member_table_ids().contains(t)), None => false, }; if !is_sst_belong_to_group_declared { @@ -1521,8 +1494,14 @@ where commit_multi_var!(self, None, new_version_delta)?; branched_ssts.commit_memory(); versioning.current_version = new_hummock_version; - self.max_committed_epoch.store(epoch, Ordering::Release); - self.max_current_epoch.fetch_max(epoch, Ordering::Release); + + let snapshot = HummockSnapshot { + committed_epoch: epoch, + current_epoch: epoch, + }; + let prev_snapshot = self.latest_snapshot.swap(snapshot.clone().into()); + assert!(prev_snapshot.committed_epoch < epoch); + assert!(prev_snapshot.current_epoch < epoch); trigger_version_stat(&self.metrics, &versioning.current_version); for compaction_group_id in &modified_compaction_groups { @@ -1540,10 +1519,7 @@ where .notification_manager() .notify_frontend_asynchronously( Operation::Update, // Frontends don't care about operation. - Info::HummockSnapshot(HummockSnapshot { - committed_epoch: epoch, - current_epoch: self.max_current_epoch.load(Ordering::Relaxed), - }), + Info::HummockSnapshot(snapshot), ); self.env .notification_manager() @@ -1577,17 +1553,19 @@ where /// We don't commit an epoch without checkpoint. We will only update the `max_current_epoch`. pub fn update_current_epoch(&self, max_current_epoch: HummockEpoch) -> Result<()> { // We only update `max_current_epoch`! - let original_epoch = self - .max_current_epoch - .fetch_max(max_current_epoch, Ordering::Release); - assert!(original_epoch < max_current_epoch); + let prev_snapshot = self.latest_snapshot.rcu(|snapshot| HummockSnapshot { + committed_epoch: snapshot.committed_epoch, + current_epoch: max_current_epoch, + }); + assert!(prev_snapshot.current_epoch < max_current_epoch); + tracing::trace!("new current epoch {}", max_current_epoch); self.env .notification_manager() .notify_frontend_asynchronously( Operation::Update, // Frontends don't care about operation. 
Info::HummockSnapshot(HummockSnapshot { - committed_epoch: self.max_committed_epoch.load(Ordering::Relaxed), + committed_epoch: prev_snapshot.committed_epoch, current_epoch: max_current_epoch, }), ); diff --git a/src/meta/src/hummock/manager/tests.rs b/src/meta/src/hummock/manager/tests.rs index 04eb3a93338b9..1ec401d3444e2 100644 --- a/src/meta/src/hummock/manager/tests.rs +++ b/src/meta/src/hummock/manager/tests.rs @@ -22,9 +22,7 @@ use risingwave_hummock_sdk::compact::compact_task_to_string; use risingwave_hummock_sdk::compaction_group::hummock_version_ext::HummockVersionExt; use risingwave_hummock_sdk::compaction_group::StaticCompactionGroupId; // use risingwave_hummock_sdk::key_range::KeyRange; -use risingwave_hummock_sdk::{ - CompactionGroupId, HummockContextId, HummockEpoch, HummockVersionId, FIRST_VERSION_ID, -}; +use risingwave_hummock_sdk::{HummockContextId, HummockEpoch, HummockVersionId, FIRST_VERSION_ID}; use risingwave_pb::common::{HostAddress, WorkerType}; use risingwave_pb::hummock::compact_task::TaskStatus; use risingwave_pb::hummock::pin_version_response::Payload; @@ -107,23 +105,15 @@ async fn test_unpin_snapshot_before() { #[tokio::test] async fn test_hummock_compaction_task() { - use crate::hummock::error::Error; let (_, hummock_manager, _, worker_node) = setup_compute_env(80).await; let sst_num = 2; // No compaction task available. - let task = hummock_manager + assert!(hummock_manager .get_compact_task(StaticCompactionGroupId::StateDefault.into()) .await - .unwrap_err(); - if let Error::InvalidCompactionGroup(group_id) = task { - assert_eq!( - group_id, - StaticCompactionGroupId::StateDefault as CompactionGroupId - ); - } else { - panic!(); - } + .unwrap() + .is_none()); // Add some sstables and commit. let epoch: u64 = 1; @@ -838,7 +828,7 @@ async fn test_trigger_manual_compaction() { .await; assert_eq!( - "Failed to get compaction task: InvalidCompactionGroup(\n 2,\n) compaction_group 2", + "trigger_manual_compaction No compaction_task is available. compaction_group 2", result.err().unwrap().to_string() ); } @@ -968,18 +958,11 @@ async fn test_hummock_compaction_task_heartbeat() { HummockManager::start_compaction_heartbeat(hummock_manager.clone()).await; // No compaction task available. - let task = hummock_manager + assert!(hummock_manager .get_compact_task(StaticCompactionGroupId::StateDefault.into()) .await - .unwrap_err(); - if let Error::InvalidCompactionGroup(group_id) = task { - assert_eq!( - group_id, - StaticCompactionGroupId::StateDefault as CompactionGroupId - ); - } else { - panic!(); - } + .unwrap() + .is_none()); // Add some sstables and commit. let epoch: u64 = 1; @@ -1093,18 +1076,11 @@ async fn test_hummock_compaction_task_heartbeat_removal_on_node_removal() { HummockManager::start_compaction_heartbeat(hummock_manager.clone()).await; // No compaction task available. - let task = hummock_manager + assert!(hummock_manager .get_compact_task(StaticCompactionGroupId::StateDefault.into()) .await - .unwrap_err(); - if let Error::InvalidCompactionGroup(group_id) = task { - assert_eq!( - group_id, - StaticCompactionGroupId::StateDefault as CompactionGroupId - ); - } else { - panic!(); - } + .unwrap() + .is_none()); // Add some sstables and commit. 
let epoch: u64 = 1; diff --git a/src/meta/src/manager/streaming_job.rs b/src/meta/src/manager/streaming_job.rs index c95780cef1743..8e602d3beb3aa 100644 --- a/src/meta/src/manager/streaming_job.rs +++ b/src/meta/src/manager/streaming_job.rs @@ -85,6 +85,14 @@ impl StreamingJob { } } + pub fn mview_definition(&self) -> String { + match self { + Self::MaterializedView(table) => table.definition.clone(), + Self::MaterializedSource(_, table) => table.definition.clone(), + _ => "".to_owned(), + } + } + pub fn properties(&self) -> HashMap { match self { Self::MaterializedView(table) => table.properties.clone(), diff --git a/src/meta/src/rpc/service/ddl_service.rs b/src/meta/src/rpc/service/ddl_service.rs index 28f26aed960d5..504d7820596ea 100644 --- a/src/meta/src/rpc/service/ddl_service.rs +++ b/src/meta/src/rpc/service/ddl_service.rs @@ -465,6 +465,7 @@ where schema_id: stream_job.schema_id(), database_id: stream_job.database_id(), mview_name: stream_job.name(), + mview_definition: stream_job.mview_definition(), table_properties: stream_job.properties(), table_sink_map: self .fragment_manager diff --git a/src/meta/src/stream/scheduler.rs b/src/meta/src/stream/scheduler.rs index 23bfe85bf7133..205f2dade0958 100644 --- a/src/meta/src/stream/scheduler.rs +++ b/src/meta/src/stream/scheduler.rs @@ -324,6 +324,7 @@ mod test { upstream_actor_id: vec![], same_worker_node_as_upstream: false, vnode_bitmap: None, + mview_definition: "".to_owned(), }], ..Default::default() }; @@ -350,6 +351,7 @@ mod test { upstream_actor_id: vec![], same_worker_node_as_upstream: false, vnode_bitmap: None, + mview_definition: "".to_owned(), }) .collect_vec(); actor_id += node_count * parallel_degree as u32; diff --git a/src/meta/src/stream/source_manager.rs b/src/meta/src/stream/source_manager.rs index e9e8dadc2a88a..9b7638e5fba21 100644 --- a/src/meta/src/stream/source_manager.rs +++ b/src/meta/src/stream/source_manager.rs @@ -22,11 +22,9 @@ use std::time::Duration; use anyhow::anyhow; use itertools::Itertools; use risingwave_common::catalog::TableId; -use risingwave_common::try_match_expand; use risingwave_connector::source::{ ConnectorProperties, SplitEnumeratorImpl, SplitId, SplitImpl, SplitMetaData, }; -use risingwave_pb::catalog::source::Info; use risingwave_pb::catalog::source::Info::StreamSource; use risingwave_pb::catalog::Source; use risingwave_pb::source::{ConnectorSplit, ConnectorSplits}; @@ -66,12 +64,7 @@ struct ConnectorSourceWorker { impl ConnectorSourceWorker { pub async fn create(source: &Source, period: Duration) -> MetaResult { - let info = source - .info - .clone() - .ok_or_else(|| anyhow!("source info is empty"))?; - let stream_source_info = try_match_expand!(info, Info::StreamSource)?; - let properties = ConnectorProperties::extract(stream_source_info.properties)?; + let properties = ConnectorProperties::extract(source.properties.clone())?; let enumerator = SplitEnumeratorImpl::create(properties).await?; let splits = Arc::new(Mutex::new(SharedSplitMap { splits: None })); Ok(Self { diff --git a/src/meta/src/stream/stream_graph.rs b/src/meta/src/stream/stream_graph.rs index 4621fdd885322..ae0887a85d2ca 100644 --- a/src/meta/src/stream/stream_graph.rs +++ b/src/meta/src/stream/stream_graph.rs @@ -323,6 +323,8 @@ impl StreamActorBuilder { same_worker_node_as_upstream: self.chain_same_worker_node || self.upstreams.values().any(|u| u.same_worker_node), vnode_bitmap: None, + // To be filled by `StreamGraphBuilder::build` + mview_definition: "".to_owned(), } } } @@ -495,6 +497,8 @@ impl StreamGraphBuilder { 
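
The hunks above add a mview_definition debug string to the job and actor protos, defaulting it to an empty string; the StreamGraphBuilder hunk that follows stamps the real definition onto each actor at build time. A minimal sketch of that threading, with stand-in types rather than the generated protobuf structs:

    #[derive(Debug, Default)]
    struct StreamActor {
        actor_id: u32,
        // Human-readable SQL of the materialized view; for debugging only.
        mview_definition: String,
    }

    struct CreateMaterializedViewContext {
        mview_definition: String,
    }

    // Builders first emit actors with an empty definition, then the graph
    // build step fills it in from the context carried by the DDL request.
    fn build(ctx: &CreateMaterializedViewContext, mut actors: Vec<StreamActor>) -> Vec<StreamActor> {
        for actor in &mut actors {
            actor.mview_definition = ctx.mview_definition.clone();
        }
        actors
    }

    fn main() {
        let ctx = CreateMaterializedViewContext {
            mview_definition: "CREATE MATERIALIZED VIEW mv AS SELECT 1".to_owned(),
        };
        let actors = (1..=3)
            .map(|actor_id| StreamActor { actor_id, ..Default::default() })
            .collect();
        let actors = build(&ctx, actors);
        assert!(actors.iter().all(|a| !a.mview_definition.is_empty()));
        println!("{:#?}", actors);
    }
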
)?; actor.nodes = Some(stream_node); + actor.mview_definition = ctx.mview_definition.clone(); + graph .entry(builder.get_fragment_id()) .or_default() diff --git a/src/meta/src/stream/stream_manager.rs b/src/meta/src/stream/stream_manager.rs index 6cce0326c41c5..c373aded1d2ac 100644 --- a/src/meta/src/stream/stream_manager.rs +++ b/src/meta/src/stream/stream_manager.rs @@ -73,6 +73,9 @@ pub struct CreateMaterializedViewContext { pub database_id: DatabaseId, /// Name of mview, for internal table name generation. pub mview_name: String, + /// The SQL definition of this materialized view. Used for debugging only. + pub mview_definition: String, + pub table_properties: HashMap, } diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 67296792b4c46..d89fc35490d7e 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -15,6 +15,7 @@ use std::collections::{HashMap, HashSet}; use std::vec; +use itertools::Itertools; use risingwave_common::catalog::{DatabaseId, SchemaId, TableId}; use risingwave_pb::catalog::Table as ProstTable; use risingwave_pb::data::data_type::TypeName; @@ -168,12 +169,23 @@ fn make_empty_table(id: u32) -> ProstTable { fn make_stream_fragments() -> Vec { let mut fragments = vec![]; // table source node + let column_ids = vec![1, 2, 0]; + let columns = column_ids + .iter() + .map(|column_id| ColumnCatalog { + column_desc: Some(ColumnDesc { + column_id: *column_id, + ..Default::default() + }), + ..Default::default() + }) + .collect_vec(); let source_node = StreamNode { node_body: Some(NodeBody::Source(SourceNode { source_id: 1, - column_ids: vec![1, 2, 0], state_table: Some(make_source_internal_table(1)), - info: None, + columns, + ..Default::default() })), stream_key: vec![2], ..Default::default() @@ -324,6 +336,7 @@ fn make_stream_fragments() -> Vec { table_id: 1, table: Some(make_internal_table(4, true)), column_orders: vec![make_column_order(1), make_column_order(2)], + ignore_on_conflict: true, })), fields: vec![], // TODO: fill this later operator_id: 7, diff --git a/src/object_store/src/object/mod.rs b/src/object_store/src/object/mod.rs index e281744d96831..efc85e197d389 100644 --- a/src/object_store/src/object/mod.rs +++ b/src/object_store/src/object/mod.rs @@ -524,7 +524,7 @@ impl MonitoredObjectStore { let ret = self .inner .upload(path, obj) - .stack_trace("object_store_upload") + .verbose_stack_trace("object_store_upload") .await; try_update_failure_metric(&self.object_store_metrics, &ret, operation_type); @@ -561,7 +561,7 @@ impl MonitoredObjectStore { let res = self .inner .read(path, block_loc) - .stack_trace("object_store_read") + .verbose_stack_trace("object_store_read") .await .map_err(|err| { ObjectError::internal(format!( @@ -598,7 +598,7 @@ impl MonitoredObjectStore { let res = self .inner .readv(path, block_locs) - .stack_trace("object_store_readv") + .verbose_stack_trace("object_store_readv") .await; try_update_failure_metric(&self.object_store_metrics, &res, operation_type); @@ -636,7 +636,7 @@ impl MonitoredObjectStore { let ret = self .inner .metadata(path) - .stack_trace("object_store_metadata") + .verbose_stack_trace("object_store_metadata") .await; try_update_failure_metric(&self.object_store_metrics, &ret, operation_type); @@ -654,7 +654,7 @@ impl MonitoredObjectStore { let ret = self .inner .delete(path) - .stack_trace("object_store_delete") + .verbose_stack_trace("object_store_delete") .await; try_update_failure_metric(&self.object_store_metrics, &ret, 
operation_type); @@ -672,7 +672,7 @@ impl MonitoredObjectStore { let ret = self .inner .delete_objects(paths) - .stack_trace("object_store_delete_objects") + .verbose_stack_trace("object_store_delete_objects") .await; try_update_failure_metric(&self.object_store_metrics, &ret, operation_type); @@ -690,7 +690,7 @@ impl MonitoredObjectStore { let ret = self .inner .list(prefix) - .stack_trace("object_store_list") + .verbose_stack_trace("object_store_list") .await; try_update_failure_metric(&self.object_store_metrics, &ret, operation_type); diff --git a/src/risedevtool/src/service_config.rs b/src/risedevtool/src/service_config.rs index 4acdb2a7dddf5..2564a4a814ef7 100644 --- a/src/risedevtool/src/service_config.rs +++ b/src/risedevtool/src/service_config.rs @@ -27,7 +27,7 @@ pub struct ComputeNodeConfig { pub port: u16, pub listen_address: String, pub exporter_port: u16, - pub enable_async_stack_trace: bool, + pub async_stack_trace: String, pub enable_managed_cache: bool, pub enable_tiered_cache: bool, diff --git a/src/risedevtool/src/task/compute_node_service.rs b/src/risedevtool/src/task/compute_node_service.rs index a2464c6535071..36f6e010f70b2 100644 --- a/src/risedevtool/src/task/compute_node_service.rs +++ b/src/risedevtool/src/task/compute_node_service.rs @@ -57,11 +57,9 @@ impl ComputeNodeService { .arg("--client-address") .arg(format!("{}:{}", config.address, config.port)) .arg("--metrics-level") - .arg("1"); - - if config.enable_async_stack_trace { - cmd.arg("--enable-async-stack-trace"); - } + .arg("1") + .arg("--async-stack-trace") + .arg(&config.async_stack_trace); if config.enable_managed_cache { cmd.arg("--enable-managed-cache"); diff --git a/src/source/src/lib.rs b/src/source/src/lib.rs index 17f4c27d864d0..e3cc17f2deb08 100644 --- a/src/source/src/lib.rs +++ b/src/source/src/lib.rs @@ -39,13 +39,13 @@ use crate::connector_source::ConnectorSource; pub mod parser; mod manager; +pub use manager::test_utils as table_test_utils; mod common; pub mod connector_source; pub mod monitor; pub mod row_id; mod table; -pub use table::test_utils as table_test_utils; #[derive(Clone, Debug, PartialEq, Eq)] pub enum SourceFormat { diff --git a/src/source/src/manager.rs b/src/source/src/manager.rs index 6153bb821ecb8..36d717d6a8d5a 100644 --- a/src/source/src/manager.rs +++ b/src/source/src/manager.rs @@ -21,10 +21,11 @@ use parking_lot::Mutex; use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; use risingwave_common::error::ErrorCode::{ConnectorError, InternalError, ProtocolError}; use risingwave_common::error::{Result, RwError}; +use risingwave_common::try_match_expand; use risingwave_common::types::DataType; use risingwave_connector::source::ConnectorProperties; -use risingwave_pb::catalog::{StreamSourceInfo, TableSourceInfo}; -use risingwave_pb::plan_common::RowFormatType; +use risingwave_pb::catalog::ColumnIndex as ProstColumnIndex; +use risingwave_pb::plan_common::{ColumnCatalog as ProstColumnCatalog, RowFormatType}; use risingwave_pb::stream_plan::source_node::Info as ProstSourceInfo; use crate::monitor::SourceMetrics; @@ -126,25 +127,24 @@ impl TableSourceManager { pub fn insert_source( &self, - source_id: &TableId, - info: &TableSourceInfo, + source_id: TableId, + row_id_index: Option, + columns: Vec, + pk_column_ids: Vec, ) -> Result { let mut sources = self.sources.lock(); sources.drain_filter(|_, weak_ref| weak_ref.strong_count() == 0); if let Some(strong_ref) = sources - .get(source_id) + .get(&source_id) .and_then(|weak_ref| weak_ref.upgrade()) { Ok(strong_ref) } else 
{ - let columns = info - .columns + let columns = columns .iter() - .cloned() - .map(|c| ColumnDesc::from(c.column_desc.unwrap())) + .map(|c| ColumnDesc::from(c.column_desc.as_ref().unwrap())) .collect_vec(); - let row_id_index = info.row_id_index.as_ref().map(|index| index.index as _); - let pk_column_ids = info.pk_column_ids.clone(); + let row_id_index = row_id_index.map(|index| index.index as _); // Table sources do not need columns and format let strong_ref = Arc::new(SourceDesc { @@ -155,7 +155,7 @@ impl TableSourceManager { pk_column_ids, metrics: self.metrics.clone(), }); - sources.insert(*source_id, Arc::downgrade(&strong_ref)); + sources.insert(source_id, Arc::downgrade(&strong_ref)); Ok(strong_ref) } } @@ -196,40 +196,54 @@ impl TableSourceManager { #[derive(Clone)] pub struct SourceDescBuilder { - id: TableId, + source_id: TableId, + row_id_index: Option, + columns: Vec, + pk_column_ids: Vec, + properties: HashMap, info: ProstSourceInfo, - mgr: TableSourceManagerRef, + source_manager: TableSourceManagerRef, } impl SourceDescBuilder { - pub fn new(id: TableId, info: &ProstSourceInfo, mgr: &TableSourceManagerRef) -> Self { + pub fn new( + source_id: TableId, + row_id_index: Option, + columns: Vec, + pk_column_ids: Vec, + properties: HashMap, + info: ProstSourceInfo, + source_manager: TableSourceManagerRef, + ) -> Self { Self { - id, - info: info.clone(), - mgr: mgr.clone(), + source_id, + row_id_index, + columns, + pk_column_ids, + properties, + info, + source_manager, } } pub async fn build(&self) -> Result { - let Self { id, info, mgr } = self; - match &info { - ProstSourceInfo::TableSource(info) => Self::build_table_source(mgr, id, info), - ProstSourceInfo::StreamSource(info) => Self::build_stream_source(mgr, info).await, + match &self.info { + ProstSourceInfo::TableSource(_) => self.build_table_source(), + ProstSourceInfo::StreamSource(_) => self.build_stream_source().await, } } - fn build_table_source( - mgr: &TableSourceManagerRef, - table_id: &TableId, - info: &TableSourceInfo, - ) -> Result { - mgr.insert_source(table_id, info) + fn build_table_source(&self) -> Result { + self.source_manager.insert_source( + self.source_id, + self.row_id_index.clone(), + self.columns.clone(), + self.pk_column_ids.clone(), + ) } - async fn build_stream_source( - mgr: &TableSourceManagerRef, - info: &StreamSourceInfo, - ) -> Result { + async fn build_stream_source(&self) -> Result { + let info = try_match_expand!(&self.info, ProstSourceInfo::StreamSource).unwrap(); let format = match info.get_row_format()? 
{ RowFormatType::Json => SourceFormat::Json, RowFormatType::Protobuf => SourceFormat::Protobuf, @@ -244,7 +258,7 @@ impl SourceDescBuilder { ))); } let source_parser_rs = - SourceParserImpl::create(&format, &info.properties, info.row_schema_location.as_str()) + SourceParserImpl::create(&format, &self.properties, info.row_schema_location.as_str()) .await; let parser = if let Ok(source_parser) = source_parser_rs { source_parser @@ -252,29 +266,28 @@ impl SourceDescBuilder { return Err(source_parser_rs.err().unwrap()); }; - let mut columns: Vec<_> = info + let mut columns: Vec<_> = self .columns .iter() .map(|c| SourceColumnDesc::from(&ColumnDesc::from(c.column_desc.as_ref().unwrap()))) .collect(); - let row_id_index = info.row_id_index.as_ref().map(|row_id_index| { + let row_id_index = self.row_id_index.as_ref().map(|row_id_index| { columns[row_id_index.index as usize].skip_parse = true; row_id_index.index as usize }); - let pk_column_ids = info.pk_column_ids.clone(); assert!( - !pk_column_ids.is_empty(), + !self.pk_column_ids.is_empty(), "source should have at least one pk column" ); - let config = ConnectorProperties::extract(info.properties.clone()) + let config = ConnectorProperties::extract(self.properties.clone()) .map_err(|e| RwError::from(ConnectorError(e.into())))?; let source = SourceImpl::Connector(ConnectorSource { config, columns: columns.clone(), parser, - connector_message_buffer_size: mgr.msg_buf_size(), + connector_message_buffer_size: self.source_manager.msg_buf_size(), }); Ok(Arc::new(SourceDesc { @@ -282,12 +295,59 @@ impl SourceDescBuilder { format, columns, row_id_index, - pk_column_ids, - metrics: mgr.metrics(), + pk_column_ids: self.pk_column_ids.clone(), + metrics: self.source_manager.metrics(), })) } } +pub mod test_utils { + use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema, TableId}; + use risingwave_pb::catalog::{ColumnIndex, TableSourceInfo}; + use risingwave_pb::plan_common::ColumnCatalog; + use risingwave_pb::stream_plan::source_node::Info as ProstSourceInfo; + + use crate::{SourceDescBuilder, TableSourceManagerRef}; + + pub fn create_table_source_desc_builder( + schema: &Schema, + source_id: TableId, + row_id_index: Option, + pk_column_ids: Vec, + source_manager: TableSourceManagerRef, + ) -> SourceDescBuilder { + let row_id_index = row_id_index.map(|index| ColumnIndex { index }); + let columns = schema + .fields + .iter() + .enumerate() + .map(|(i, f)| ColumnCatalog { + column_desc: Some( + ColumnDesc { + data_type: f.data_type.clone(), + column_id: ColumnId::from(i as i32), // use column index as column id + name: f.name.clone(), + field_descs: vec![], + type_name: "".to_string(), + } + .to_protobuf(), + ), + is_hidden: false, + }) + .collect(); + let info = ProstSourceInfo::TableSource(TableSourceInfo {}); + SourceDescBuilder { + source_id, + row_id_index, + columns, + pk_column_ids, + properties: Default::default(), + info, + source_manager, + } + } +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -318,19 +378,24 @@ mod tests { is_hidden: false, }) .collect(); + let row_id_index = Some(ColumnIndex { index: 0 }); + let pk_column_ids = vec![0]; let info = StreamSourceInfo { - properties, row_format: 0, row_schema_location: "".to_string(), - row_id_index: Some(ColumnIndex { index: 0 }), - pk_column_ids: vec![0], - columns, }; let source_id = TableId::default(); let mem_source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); - let source_builder = - SourceDescBuilder::new(source_id, &Info::StreamSource(info), 
&mem_source_manager); + let source_builder = SourceDescBuilder::new( + source_id, + row_id_index, + columns, + pk_column_ids, + properties, + Info::StreamSource(info), + mem_source_manager, + ); let source = source_builder.build().await; assert!(source.is_ok()); @@ -349,35 +414,39 @@ mod tests { ], }; - let info = TableSourceInfo { - row_id_index: None, - columns: schema - .fields - .iter() - .enumerate() - .map(|(i, f)| ColumnCatalog { - column_desc: Some( - ColumnDesc { - data_type: f.data_type.clone(), - column_id: ColumnId::from(i as i32), // use column index as column id - name: f.name.clone(), - field_descs: vec![], - type_name: "".to_string(), - } - .to_protobuf(), - ), - is_hidden: false, - }) - .collect(), - pk_column_ids: vec![1], - properties: Default::default(), - }; + let columns = schema + .fields + .iter() + .enumerate() + .map(|(i, f)| ColumnCatalog { + column_desc: Some( + ColumnDesc { + data_type: f.data_type.clone(), + column_id: ColumnId::from(i as i32), // use column index as column id + name: f.name.clone(), + field_descs: vec![], + type_name: "".to_string(), + } + .to_protobuf(), + ), + is_hidden: false, + }) + .collect(); + let pk_column_ids = vec![1]; + let info = TableSourceInfo {}; let _keyspace = Keyspace::table_root(MemoryStateStore::new(), &table_id); let mem_source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); - let mut source_builder = - SourceDescBuilder::new(table_id, &Info::TableSource(info), &mem_source_manager); + let mut source_builder = SourceDescBuilder::new( + table_id, + None, + columns, + pk_column_ids, + Default::default(), + Info::TableSource(info), + mem_source_manager.clone(), + ); let res = source_builder.build().await; assert!(res.is_ok()); @@ -392,7 +461,7 @@ mod tests { let result = mem_source_manager.get_source(&table_id); assert!(result.is_err()); - source_builder.id = TableId::new(1u32); + source_builder.source_id = TableId::new(1u32); let _new_source = source_builder.build().await; assert_eq!(mem_source_manager.sources.lock().len(), 1); diff --git a/src/source/src/parser/mod.rs b/src/source/src/parser/mod.rs index 85c5bd09724b5..0031a9fa750ee 100644 --- a/src/source/src/parser/mod.rs +++ b/src/source/src/parser/mod.rs @@ -300,7 +300,9 @@ impl SourceParserImpl { PROTOBUF_MESSAGE_KEY ))) })?; - SourceParserImpl::Protobuf(ProtobufParser::new(schema_location, message_name)?) + SourceParserImpl::Protobuf( + ProtobufParser::new(schema_location, message_name, properties.clone()).await?, + ) } SourceFormat::DebeziumJson => SourceParserImpl::DebeziumJson(DebeziumJsonParser), SourceFormat::Avro => { diff --git a/src/source/src/parser/pb_parser.rs b/src/source/src/parser/pb_parser.rs index bb48dfd467653..c02252bd3a0fa 100644 --- a/src/source/src/parser/pb_parser.rs +++ b/src/source/src/parser/pb_parser.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
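
The manager.rs hunks above flatten what used to travel inside TableSourceInfo/StreamSourceInfo into explicit SourceDescBuilder fields, so both the table and stream paths read row_id_index, columns, pk_column_ids and properties from the builder itself. A condensed sketch of that shape, with stand-in types in place of the prost-generated ones:

    use std::collections::HashMap;

    enum SourceInfo {
        Table,
        Stream { row_schema_location: String },
    }

    struct SourceDescBuilder {
        source_id: u32,
        row_id_index: Option<usize>,
        columns: Vec<String>,
        pk_column_ids: Vec<i32>,
        properties: HashMap<String, String>,
        info: SourceInfo,
    }

    impl SourceDescBuilder {
        fn build(&self) -> String {
            // Shared fields are read from `self` regardless of the source kind.
            let row_id = self.row_id_index.map_or("none".to_owned(), |i| i.to_string());
            match &self.info {
                SourceInfo::Table => format!(
                    "table source {}: {} columns, {} pk columns, row id index {}",
                    self.source_id, self.columns.len(), self.pk_column_ids.len(), row_id
                ),
                SourceInfo::Stream { row_schema_location } => format!(
                    "stream source {}: schema at {}, {} properties",
                    self.source_id, row_schema_location, self.properties.len()
                ),
            }
        }
    }

    fn main() {
        let builder = SourceDescBuilder {
            source_id: 1,
            row_id_index: Some(0),
            columns: vec!["_row_id".to_owned(), "v1".to_owned()],
            pk_column_ids: vec![0],
            properties: HashMap::new(),
            info: SourceInfo::Table,
        };
        println!("{}", builder.build());
    }
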
+use std::collections::HashMap; use std::path::Path; use itertools::Itertools; @@ -20,21 +21,30 @@ use prost_reflect::{ ReflectMessage, Value, }; use risingwave_common::array::{ListValue, StructValue}; -use risingwave_common::error::ErrorCode::{InternalError, NotImplemented, ProtocolError}; +use risingwave_common::error::ErrorCode::{ + InternalError, InvalidConfigValue, NotImplemented, ProtocolError, +}; use risingwave_common::error::{Result, RwError}; use risingwave_common::types::{DataType, Datum, Decimal, OrderedF32, OrderedF64, ScalarImpl}; +use risingwave_connector::aws_utils::{default_conn_config, s3_client, AwsConfigV2}; use risingwave_pb::plan_common::ColumnDesc; use url::Url; use crate::{SourceParser, WriteGuard}; +const PB_SCHEMA_LOCATION_S3_REGION: &str = "region"; + #[derive(Debug, Clone)] pub struct ProtobufParser { pub message_descriptor: MessageDescriptor, } impl ProtobufParser { - pub fn new(location: &str, message_name: &str) -> Result { + pub async fn new( + location: &str, + message_name: &str, + props: HashMap, + ) -> Result { let url = Url::parse(location) .map_err(|e| InternalError(format!("failed to parse url ({}): {}", location, e)))?; @@ -51,12 +61,7 @@ impl ProtobufParser { } Self::local_read_to_bytes(&path) } - "s3" => { - // TODO(tabVersion): Support load from s3. - return Err(RwError::from(ProtocolError( - "s3 schema location is not supported".to_string(), - ))); - } + "s3" => load_bytes_from_s3(&url, props).await, scheme => Err(RwError::from(ProtocolError(format!( "path scheme {} is not supported", scheme @@ -134,6 +139,50 @@ impl ProtobufParser { } } +async fn load_bytes_from_s3( + location: &Url, + properties: HashMap, +) -> Result> { + let bucket = location.domain().ok_or_else(|| { + RwError::from(InternalError(format!( + "Illegal Protobuf schema path {}", + location + ))) + })?; + if properties.get(PB_SCHEMA_LOCATION_S3_REGION).is_none() { + return Err(RwError::from(InvalidConfigValue { + config_entry: PB_SCHEMA_LOCATION_S3_REGION.to_string(), + config_value: "NONE".to_string(), + })); + } + let key = location.path().replace('/', ""); + let config = AwsConfigV2::from(properties.clone()); + let sdk_config = config.load_config(None).await; + let s3_client = s3_client(&sdk_config, Some(default_conn_config())); + let schema_content = s3_client + .get_object() + .bucket(bucket.to_string()) + .key(&key) + .send() + .await; + match schema_content { + Ok(response) => { + let body = response.body.collect().await; + if let Ok(body_bytes) = body { + let schema_bytes = body_bytes.into_bytes().to_vec(); + Ok(schema_bytes) + } else { + let read_schema_err = body.err().unwrap().to_string(); + Err(RwError::from(InternalError(format!( + "Read Protobuf schema file from s3 {}", + read_schema_err + )))) + } + } + Err(err) => Err(RwError::from(InternalError(err.to_string()))), + } +} + fn from_protobuf_value(field_desc: &FieldDescriptor, value: &Value) -> Result { let v = match value { Value::Bool(v) => ScalarImpl::Bool(*v), @@ -289,12 +338,12 @@ mod test { // Date: "2021-01-01" static PRE_GEN_PROTO_DATA: &[u8] = b"\x08\x7b\x12\x0c\x74\x65\x73\x74\x20\x61\x64\x64\x72\x65\x73\x73\x1a\x09\x74\x65\x73\x74\x20\x63\x69\x74\x79\x20\xc8\x03\x2d\x19\x04\x9e\x3f\x32\x0a\x32\x30\x32\x31\x2d\x30\x31\x2d\x30\x31"; - #[test] - fn test_simple_schema() -> Result<()> { + #[tokio::test] + async fn test_simple_schema() -> Result<()> { let location = schema_dir() + "/simple-schema"; let message_name = "test.TestRecord"; println!("location: {}", location); - let parser = 
ProtobufParser::new(&location, message_name)?; + let parser = ProtobufParser::new(&location, message_name, HashMap::new()).await?; let value = DynamicMessage::decode(parser.message_descriptor, PRE_GEN_PROTO_DATA).unwrap(); assert_eq!( @@ -325,12 +374,12 @@ mod test { Ok(()) } - #[test] - fn test_complex_schema() -> Result<()> { + #[tokio::test] + async fn test_complex_schema() -> Result<()> { let location = schema_dir() + "/complex-schema"; let message_name = "test.User"; - let parser = ProtobufParser::new(&location, message_name)?; + let parser = ProtobufParser::new(&location, message_name, HashMap::new()).await?; let columns = parser.map_to_columns().unwrap(); assert_eq!(columns[0].name, "id".to_string()); diff --git a/src/source/src/table.rs b/src/source/src/table.rs index 5348251366577..1e29b0518cca2 100644 --- a/src/source/src/table.rs +++ b/src/source/src/table.rs @@ -165,43 +165,6 @@ impl TableSource { } } -pub mod test_utils { - use risingwave_common::catalog::{ColumnDesc, ColumnId, Schema}; - use risingwave_pb::catalog::{ColumnIndex, TableSourceInfo}; - use risingwave_pb::plan_common::ColumnCatalog; - use risingwave_pb::stream_plan::source_node::Info as ProstSourceInfo; - - pub fn create_table_info( - schema: &Schema, - row_id_index: Option, - pk_column_ids: Vec, - ) -> ProstSourceInfo { - ProstSourceInfo::TableSource(TableSourceInfo { - row_id_index: row_id_index.map(|index| ColumnIndex { index }), - columns: schema - .fields - .iter() - .enumerate() - .map(|(i, f)| ColumnCatalog { - column_desc: Some( - ColumnDesc { - data_type: f.data_type.clone(), - column_id: ColumnId::from(i as i32), // use column index as column id - name: f.name.clone(), - field_descs: vec![], - type_name: "".to_string(), - } - .to_protobuf(), - ), - is_hidden: false, - }) - .collect(), - pk_column_ids, - properties: Default::default(), - }) - } -} - #[cfg(test)] mod tests { use std::sync::Arc; diff --git a/src/storage/benches/bench_block_iter.rs b/src/storage/benches/bench_block_iter.rs index 0c273bddaadb9..497849958be1a 100644 --- a/src/storage/benches/bench_block_iter.rs +++ b/src/storage/benches/bench_block_iter.rs @@ -116,7 +116,6 @@ fn build_block_data(t: u32, i: u64) -> Bytes { fn key(t: u32, i: u64) -> Bytes { let mut buf = BytesMut::new(); - buf.put_u8(b't'); buf.put_u32(t); buf.put_u64(i); buf.freeze() diff --git a/src/storage/benches/bench_compactor.rs b/src/storage/benches/bench_compactor.rs index 0221e058646de..85e81fef582b2 100644 --- a/src/storage/benches/bench_compactor.rs +++ b/src/storage/benches/bench_compactor.rs @@ -15,6 +15,7 @@ use std::ops::Range; use std::sync::Arc; +use bytes::BufMut; use criterion::async_executor::FuturesExecutor; use criterion::{criterion_group, criterion_main, Criterion}; use risingwave_hummock_sdk::key::key_with_epoch; @@ -63,7 +64,9 @@ pub fn default_writer_opts() -> SstableWriterOptions { } pub fn test_key_of(idx: usize, epoch: u64) -> Vec { - let user_key = format!("key_test_{:08}", idx * 2).as_bytes().to_vec(); + let mut user_key = Vec::new(); + user_key.put_u32(0); + user_key.put_slice(format!("key_test_{:08}", idx * 2).as_bytes()); key_with_epoch(user_key, epoch) } @@ -104,12 +107,9 @@ async fn build_table( .unwrap(); } let output = builder.finish().await.unwrap(); - let runtime = tokio::runtime::Builder::new_current_thread() - .build() - .unwrap(); let handle = output.writer_output; let sst = output.sst_info; - runtime.block_on(handle).unwrap().unwrap(); + handle.await.unwrap().unwrap(); sst } @@ -129,8 +129,11 @@ async fn scan_all_table(info: 
&SstableInfo, sstable_store: SstableStoreRef) { fn bench_table_build(c: &mut Criterion) { c.bench_function("bench_table_build", |b| { let sstable_store = mock_sstable_store(); - b.iter(|| { - let _ = build_table(sstable_store.clone(), 0, 0..(MAX_KEY_COUNT as u64), 1); + let runtime = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + b.to_async(&runtime).iter(|| async { + build_table(sstable_store.clone(), 0, 0..(MAX_KEY_COUNT as u64), 1).await; }); }); } diff --git a/src/storage/benches/bench_compression.rs b/src/storage/benches/bench_compression.rs index 4c26f1d56afb5..6a16937f239b4 100644 --- a/src/storage/benches/bench_compression.rs +++ b/src/storage/benches/bench_compression.rs @@ -30,7 +30,6 @@ fn gen_dataset(vsize: usize) -> Vec> { let mut v = vec![0; vsize]; rng.fill(&mut v[..]); let mut buf = vec![]; - buf.put_u8(b't'); // prefix buf.put_u32(t); // table id buf.put_u64(i); // key buf.put_u32(0); // cell idx diff --git a/src/storage/benches/bench_multi_builder.rs b/src/storage/benches/bench_multi_builder.rs index bf1ab861927cd..a8945194dae24 100644 --- a/src/storage/benches/bench_multi_builder.rs +++ b/src/storage/benches/bench_multi_builder.rs @@ -19,6 +19,7 @@ use std::sync::atomic::Ordering::SeqCst; use std::sync::Arc; use std::time::Duration; +use bytes::BufMut; use criterion::{criterion_group, criterion_main, Criterion}; use futures::future::try_join_all; use itertools::Itertools; @@ -32,7 +33,7 @@ use risingwave_storage::hummock::{ }; use risingwave_storage::monitor::ObjectStoreMetrics; -const RANGE: Range = 0..2500000; +const RANGE: Range = 0..1500000; const VALUE: &[u8] = &[0; 400]; const SAMPLE_COUNT: usize = 10; const ESTIMATED_MEASUREMENT_TIME: Duration = Duration::from_secs(60); @@ -89,12 +90,19 @@ fn get_builder_options(capacity_mb: usize) -> SstableBuilderOptions { } } +fn test_user_key_of(idx: u64) -> Vec { + let mut user_key = Vec::new(); + user_key.put_u32(0); + user_key.put_u64(idx); + user_key +} + async fn build_tables( mut builder: CapacitySplitTableBuilder>, ) { for i in RANGE { builder - .add_user_key(i.to_be_bytes().to_vec(), HummockValue::put(VALUE), 1) + .add_user_key(test_user_key_of(i), HummockValue::put(VALUE), 1) .await .unwrap(); } diff --git a/src/storage/hummock_sdk/src/filter_key_extractor.rs b/src/storage/hummock_sdk/src/filter_key_extractor.rs index ddce61c3e80f5..e3eb2caef8a3a 100644 --- a/src/storage/hummock_sdk/src/filter_key_extractor.rs +++ b/src/storage/hummock_sdk/src/filter_key_extractor.rs @@ -228,7 +228,7 @@ impl FilterKeyExtractor for MultiFilterKeyExtractor { return full_key; } - let table_id = get_table_id(full_key).unwrap(); + let table_id = get_table_id(full_key); self.id_to_filter_key_extractor .get(&table_id) .unwrap() @@ -496,7 +496,6 @@ mod tests { let table_prefix = { let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN); - buf.put_u8(b't'); buf.put_u32(1); buf.to_vec() }; @@ -535,7 +534,6 @@ mod tests { let table_prefix = { let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN); - buf.put_u8(b't'); buf.put_u32(1); buf.to_vec() }; @@ -579,7 +577,6 @@ mod tests { let table_prefix = { let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN); - buf.put_u8(b't'); buf.put_u32(2); buf.to_vec() }; @@ -615,7 +612,6 @@ mod tests { let table_prefix = { let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN); - buf.put_u8(b't'); buf.put_u32(3); buf.to_vec() }; diff --git a/src/storage/hummock_sdk/src/key.rs b/src/storage/hummock_sdk/src/key.rs index 541011b3ae8cd..986c63215323c 100644 --- 
a/src/storage/hummock_sdk/src/key.rs +++ b/src/storage/hummock_sdk/src/key.rs @@ -21,8 +21,8 @@ use bytes::{Buf, BufMut, BytesMut}; use super::version_cmp::VersionedComparator; use crate::HummockEpoch; -const EPOCH_LEN: usize = std::mem::size_of::(); -pub const TABLE_PREFIX_LEN: usize = 5; +pub const EPOCH_LEN: usize = std::mem::size_of::(); +pub const TABLE_PREFIX_LEN: usize = std::mem::size_of::(); /// Converts user key to full key by appending `u64::MAX - epoch` to the user key. /// @@ -76,24 +76,13 @@ pub fn user_key(full_key: &[u8]) -> &[u8] { /// Extract table id in key prefix #[inline(always)] -pub fn get_table_id(full_key: &[u8]) -> Option { - if full_key[0] == b't' { - let mut buf = &full_key[1..]; - Some(buf.get_u32()) - } else { - None - } +pub fn get_table_id(full_key: &[u8]) -> u32 { + let mut buf = full_key; + buf.get_u32() } -pub fn extract_table_id_and_epoch(full_key: &[u8]) -> (Option, HummockEpoch) { - match get_table_id(full_key) { - Some(table_id) => { - let epoch = get_epoch(full_key); - (Some(table_id), epoch) - } - - None => (None, 0), - } +pub fn extract_table_id_and_epoch(full_key: &[u8]) -> (u32, HummockEpoch) { + (get_table_id(full_key), get_epoch(full_key)) } // Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0. @@ -229,7 +218,6 @@ pub fn prefixed_range>( pub fn table_prefix(table_id: u32) -> Vec { let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN); - buf.put_u8(b't'); buf.put_u32(table_id); buf.to_vec() } diff --git a/src/storage/hummock_test/Cargo.toml b/src/storage/hummock_test/Cargo.toml index 1dbf63daf5a84..39f338584e1ab 100644 --- a/src/storage/hummock_test/Cargo.toml +++ b/src/storage/hummock_test/Cargo.toml @@ -28,6 +28,10 @@ workspace-hack = { version = "0.2.0-alpha", path = "../../workspace-hack" } [dev-dependencies] futures = { version = "0.3", default-features = false, features = ["alloc", "executor"] } +risingwave_test_runner = { path = "../../test_runner" } +serial_test = "0.9" +sync-point = { path = "../../utils/sync-point" } [features] failpoints = ["risingwave_storage/failpoints"] +sync_point = ["sync-point/sync_point"] diff --git a/src/storage/hummock_test/src/compactor_tests.rs b/src/storage/hummock_test/src/compactor_tests.rs index dd60b2bfcca97..7801a7f11cd2d 100644 --- a/src/storage/hummock_test/src/compactor_tests.rs +++ b/src/storage/hummock_test/src/compactor_tests.rs @@ -53,7 +53,7 @@ mod tests { use risingwave_storage::store::{ReadOptions, WriteOptions}; use risingwave_storage::{Keyspace, StateStore}; - use crate::test_utils::get_test_notification_client; + use crate::test_utils::{get_test_notification_client, prefixed_key}; async fn get_hummock_storage( hummock_meta_client: Arc, @@ -167,11 +167,11 @@ mod tests { let compact_ctx = get_compactor_context(&storage, &hummock_meta_client); // 1. add sstables - let mut key = b"t".to_vec(); + let mut key = Vec::new(); key.extend_from_slice(&1u32.to_be_bytes()); key.extend_from_slice(&0u64.to_be_bytes()); let key = Bytes::from(key); - let table_id = get_table_id(&key).unwrap(); + let table_id = get_table_id(&key); assert_eq!(table_id, 1); hummock_manager_ref @@ -304,7 +304,7 @@ mod tests { let compact_ctx = get_compactor_context(&storage, &hummock_meta_client); // 1. 
add sstables with 1MB value - let key = Bytes::from(&b"same_key"[..]); + let key = prefixed_key(Bytes::from(&b"same_key"[..])); let mut val = b"0"[..].repeat(1 << 20); val.extend_from_slice(&128u64.to_be_bytes()); prepare_test_put_data( @@ -322,7 +322,7 @@ mod tests { .await .unwrap() .unwrap(); - let compaction_filter_flag = CompactionFilterFlag::STATE_CLEAN | CompactionFilterFlag::TTL; + let compaction_filter_flag = CompactionFilterFlag::NONE; compact_task.compaction_filter_mask = compaction_filter_flag.bits(); compact_task.current_epoch_time = 0; @@ -689,7 +689,7 @@ mod tests { .unwrap(); let mut scan_count = 0; for (k, _) in scan_result { - let table_id = get_table_id(&k).unwrap(); + let table_id = get_table_id(&k); assert_eq!(table_id, existing_table_ids); scan_count += 1; } @@ -858,7 +858,7 @@ mod tests { .unwrap(); let mut scan_count = 0; for (k, _) in scan_result { - let table_id = get_table_id(&k).unwrap(); + let table_id = get_table_id(&k); assert_eq!(table_id, existing_table_id); scan_count += 1; } @@ -1029,7 +1029,7 @@ mod tests { let mut scan_count = 0; for (k, _) in scan_result { - let table_id = get_table_id(&k).unwrap(); + let table_id = get_table_id(&k); assert_eq!(table_id, existing_table_id); scan_count += 1; } diff --git a/src/storage/hummock_test/src/hummock_storage_tests.rs b/src/storage/hummock_test/src/hummock_storage_tests.rs index be68bc3954299..96c944c1b49c3 100644 --- a/src/storage/hummock_test/src/hummock_storage_tests.rs +++ b/src/storage/hummock_test/src/hummock_storage_tests.rs @@ -13,17 +13,19 @@ // limitations under the License. use std::ops::Bound::{Included, Unbounded}; +use std::sync::atomic::AtomicU64; use std::sync::Arc; +use std::time::Duration; use bytes::Bytes; use parking_lot::RwLock; -use risingwave_hummock_sdk::HummockReadEpoch; +use risingwave_hummock_sdk::HummockEpoch; use risingwave_meta::hummock::test_utils::setup_compute_env; use risingwave_meta::hummock::MockHummockMetaClient; use risingwave_rpc_client::HummockMetaClient; +use risingwave_storage::hummock::conflict_detector::ConflictDetector; use risingwave_storage::hummock::event_handler::HummockEventHandler; use risingwave_storage::hummock::iterator::test_utils::mock_sstable_store; -use risingwave_storage::hummock::local_version::local_version_manager::LocalVersionManager; use risingwave_storage::hummock::store::state_store::HummockStorage; use risingwave_storage::hummock::store::version::HummockReadVersion; use risingwave_storage::hummock::store::{ReadOptions, StateStore}; @@ -32,13 +34,35 @@ use risingwave_storage::storage_value::StorageValue; use risingwave_storage::store::WriteOptions; use risingwave_storage::StateStoreIter; -use crate::test_utils::prepare_local_version_manager_new; +use crate::test_utils::{prefixed_key, prepare_local_version_manager_new}; -async fn try_wait_epoch_for_test(wait_epoch: u64, uploader: Arc) { - uploader - .try_wait_epoch(HummockReadEpoch::Committed(wait_epoch)) - .await - .unwrap() +async fn try_wait_epoch_for_test( + wait_epoch: u64, + version_update_notifier_tx: Arc>, +) { + let mut receiver = version_update_notifier_tx.subscribe(); + let max_committed_epoch = *receiver.borrow(); + if max_committed_epoch >= wait_epoch { + return; + } + + match tokio::time::timeout(Duration::from_secs(1), receiver.changed()).await { + Err(elapsed) => { + panic!( + "wait_epoch {:?} timeout when waiting for version update elapsed {:?}s", + wait_epoch, elapsed + ); + } + Ok(Err(_)) => { + panic!("tx dropped"); + } + Ok(Ok(_)) => { + let max_committed_epoch = 
*receiver.borrow(); + if max_committed_epoch < wait_epoch { + panic!("max_committed_epoch {:?} update fail", max_committed_epoch); + } + } + } } #[tokio::test] @@ -65,9 +89,27 @@ async fn test_storage_basic() { uploader.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = uploader.get_pinned_version().max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(0)), + ) + }; + tokio::spawn( - HummockEventHandler::new(uploader.clone(), event_rx, read_version.clone()) - .start_hummock_event_handler_worker(), + HummockEventHandler::new( + uploader.clone(), + event_rx, + read_version.clone(), + version_update_notifier_tx, + seal_epoch, + ConflictDetector::new_from_config(hummock_options.clone()), + ) + .start_hummock_event_handler_worker(), ); let hummock_storage = HummockStorage::for_test( @@ -415,9 +457,27 @@ async fn test_state_store_sync() { uploader.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = uploader.get_pinned_version().max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(0)), + ) + }; + tokio::spawn( - HummockEventHandler::new(uploader.clone(), event_rx, read_version.clone()) - .start_hummock_event_handler_worker(), + HummockEventHandler::new( + uploader.clone(), + event_rx, + read_version.clone(), + version_update_notifier_tx.clone(), + seal_epoch, + ConflictDetector::new_from_config(hummock_options.clone()), + ) + .start_hummock_event_handler_worker(), ); let hummock_storage = HummockStorage::for_test( @@ -431,10 +491,16 @@ async fn test_state_store_sync() { let epoch1: _ = uploader.get_pinned_version().max_committed_epoch() + 1; - // ingest 16B batch + // ingest 26B batch let mut batch1 = vec![ - (Bytes::from("aaaa"), StorageValue::new_put("1111")), - (Bytes::from("bbbb"), StorageValue::new_put("2222")), + ( + prefixed_key(Bytes::from("aaaa")), + StorageValue::new_put("1111"), + ), + ( + prefixed_key(Bytes::from("bbbb")), + StorageValue::new_put("2222"), + ), ]; // Make sure the batch is sorted. 
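
The rewritten try_wait_epoch_for_test above waits on a tokio::sync::watch channel instead of polling the uploader. A self-contained sketch of that pattern, assuming only tokio, with a small retry loop added for clarity:

    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::watch;

    // Wait until the published epoch reaches `wait_epoch`, or panic on timeout.
    async fn wait_epoch(wait_epoch: u64, tx: Arc<watch::Sender<u64>>) {
        let mut rx = tx.subscribe();
        if *rx.borrow() >= wait_epoch {
            return;
        }
        loop {
            match tokio::time::timeout(Duration::from_secs(1), rx.changed()).await {
                Err(_) => panic!("timed out waiting for epoch {}", wait_epoch),
                Ok(Err(_)) => panic!("sender dropped"),
                Ok(Ok(())) => {
                    if *rx.borrow() >= wait_epoch {
                        return;
                    }
                }
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, _rx) = watch::channel(0u64);
        let tx = Arc::new(tx);
        let waiter = tokio::spawn(wait_epoch(2, tx.clone()));
        tx.send(1).unwrap();
        tx.send(2).unwrap();
        waiter.await.unwrap();
    }

Unlike this sketch, the helper in the hunk above panics if the first observed change is still below the target, which is sufficient for the single-writer tests it serves.
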
@@ -450,11 +516,20 @@ async fn test_state_store_sync() { .await .unwrap(); - // ingest 24B batch + // ingest 39B batch let mut batch2 = vec![ - (Bytes::from("cccc"), StorageValue::new_put("3333")), - (Bytes::from("dddd"), StorageValue::new_put("4444")), - (Bytes::from("eeee"), StorageValue::new_put("5555")), + ( + prefixed_key(Bytes::from("cccc")), + StorageValue::new_put("3333"), + ), + ( + prefixed_key(Bytes::from("dddd")), + StorageValue::new_put("4444"), + ), + ( + prefixed_key(Bytes::from("eeee")), + StorageValue::new_put("5555"), + ), ]; batch2.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); hummock_storage @@ -470,8 +545,11 @@ async fn test_state_store_sync() { let epoch2 = epoch1 + 1; - // ingest more 8B then will trigger a sync behind the scene - let mut batch3 = vec![(Bytes::from("eeee"), StorageValue::new_put("6666"))]; + // ingest more 13B then will trigger a sync behind the scene + let mut batch3 = vec![( + prefixed_key(Bytes::from("eeee")), + StorageValue::new_put("6666"), + )]; batch3.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); hummock_storage .ingest_batch( @@ -493,7 +571,7 @@ async fn test_state_store_sync() { .commit_epoch(epoch1, ssts) .await .unwrap(); - try_wait_epoch_for_test(epoch1, uploader.clone()).await; + try_wait_epoch_for_test(epoch1, version_update_notifier_tx.clone()).await; { // after sync 1 epoch let read_version = hummock_storage.read_version(); @@ -513,7 +591,7 @@ async fn test_state_store_sync() { for (k, v) in kv_map { let value = hummock_storage .get( - k.as_bytes(), + &prefixed_key(k.as_bytes()), epoch1, ReadOptions { table_id: Default::default(), @@ -539,7 +617,7 @@ async fn test_state_store_sync() { .commit_epoch(epoch2, ssts) .await .unwrap(); - try_wait_epoch_for_test(epoch2, uploader.clone()).await; + try_wait_epoch_for_test(epoch2, version_update_notifier_tx.clone()).await; { // after sync all epoch let read_version = hummock_storage.read_version(); @@ -559,7 +637,7 @@ async fn test_state_store_sync() { for (k, v) in kv_map { let value = hummock_storage .get( - k.as_bytes(), + &prefixed_key(k.as_bytes()), epoch2, ReadOptions { table_id: Default::default(), @@ -579,7 +657,7 @@ async fn test_state_store_sync() { { let mut iter = hummock_storage .iter( - (Unbounded, Included(b"eeee".to_vec())), + (Unbounded, Included(prefixed_key(b"eeee").to_vec())), epoch1, ReadOptions { table_id: Default::default(), @@ -601,7 +679,7 @@ async fn test_state_store_sync() { for (k, v) in kv_map { let result = iter.next().await.unwrap(); - assert_eq!(result, Some((Bytes::from(k), Bytes::from(v)))); + assert_eq!(result, Some((prefixed_key(Bytes::from(k)), Bytes::from(v)))); } assert!(iter.next().await.unwrap().is_none()); @@ -610,7 +688,7 @@ async fn test_state_store_sync() { { let mut iter = hummock_storage .iter( - (Unbounded, Included(b"eeee".to_vec())), + (Unbounded, Included(prefixed_key(b"eeee").to_vec())), epoch2, ReadOptions { table_id: Default::default(), @@ -632,7 +710,7 @@ async fn test_state_store_sync() { for (k, v) in kv_map { let result = iter.next().await.unwrap(); - assert_eq!(result, Some((Bytes::from(k), Bytes::from(v)))); + assert_eq!(result, Some((prefixed_key(Bytes::from(k)), Bytes::from(v)))); } } } @@ -661,9 +739,27 @@ async fn test_delete_get() { uploader.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = uploader.get_pinned_version().max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + 
Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(0)), + ) + }; + tokio::spawn( - HummockEventHandler::new(uploader.clone(), event_rx, read_version.clone()) - .start_hummock_event_handler_worker(), + HummockEventHandler::new( + uploader.clone(), + event_rx, + read_version.clone(), + version_update_notifier_tx.clone(), + seal_epoch, + ConflictDetector::new_from_config(hummock_options.clone()), + ) + .start_hummock_event_handler_worker(), ); let hummock_storage = HummockStorage::for_test( @@ -678,8 +774,14 @@ async fn test_delete_get() { let initial_epoch = uploader.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch + 1; let batch1 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), ]; hummock_storage .ingest_batch( @@ -701,7 +803,7 @@ async fn test_delete_get() { .await .unwrap(); let epoch2 = initial_epoch + 2; - let batch2 = vec![(Bytes::from("bb"), StorageValue::new_delete())]; + let batch2 = vec![(prefixed_key(Bytes::from("bb")), StorageValue::new_delete())]; hummock_storage .ingest_batch( batch2, @@ -722,10 +824,10 @@ async fn test_delete_get() { .await .unwrap(); - try_wait_epoch_for_test(epoch2, uploader.clone()).await; + try_wait_epoch_for_test(epoch2, version_update_notifier_tx).await; assert!(hummock_storage .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), epoch2, ReadOptions { prefix_hint: None, @@ -763,9 +865,27 @@ async fn test_multiple_epoch_sync() { uploader.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = uploader.get_pinned_version().max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(0)), + ) + }; + tokio::spawn( - HummockEventHandler::new(uploader.clone(), event_rx, read_version.clone()) - .start_hummock_event_handler_worker(), + HummockEventHandler::new( + uploader.clone(), + event_rx, + read_version.clone(), + version_update_notifier_tx.clone(), + seal_epoch, + ConflictDetector::new_from_config(hummock_options.clone()), + ) + .start_hummock_event_handler_worker(), ); let hummock_storage = HummockStorage::for_test( @@ -780,8 +900,14 @@ async fn test_multiple_epoch_sync() { let initial_epoch = uploader.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch + 1; let batch1 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), ]; hummock_storage .ingest_batch( @@ -795,7 +921,7 @@ async fn test_multiple_epoch_sync() { .unwrap(); let epoch2 = initial_epoch + 2; - let batch2 = vec![(Bytes::from("bb"), StorageValue::new_delete())]; + let batch2 = vec![(prefixed_key(Bytes::from("bb")), StorageValue::new_delete())]; hummock_storage .ingest_batch( batch2, @@ -809,8 +935,14 @@ async fn test_multiple_epoch_sync() { let epoch3 = initial_epoch + 3; let batch3 = vec![ - (Bytes::from("aa"), StorageValue::new_put("444")), - (Bytes::from("bb"), StorageValue::new_put("555")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("444"), + ), + ( + prefixed_key(Bytes::from("bb")), + 
StorageValue::new_put("555"), + ), ]; hummock_storage .ingest_batch( @@ -828,7 +960,7 @@ async fn test_multiple_epoch_sync() { assert_eq!( hummock_storage_clone .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), epoch1, ReadOptions { table_id: Default::default(), @@ -844,7 +976,7 @@ async fn test_multiple_epoch_sync() { ); assert!(hummock_storage_clone .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), epoch2, ReadOptions { table_id: Default::default(), @@ -859,7 +991,7 @@ async fn test_multiple_epoch_sync() { assert_eq!( hummock_storage_clone .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), epoch3, ReadOptions { table_id: Default::default(), @@ -888,7 +1020,7 @@ async fn test_multiple_epoch_sync() { .await .unwrap(); - try_wait_epoch_for_test(epoch3, uploader.clone()).await; + try_wait_epoch_for_test(epoch3, version_update_notifier_tx).await; test_get().await; } @@ -916,9 +1048,27 @@ async fn test_iter_with_min_epoch() { uploader.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = uploader.get_pinned_version().max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(0)), + ) + }; + tokio::spawn( - HummockEventHandler::new(uploader.clone(), event_rx, read_version.clone()) - .start_hummock_event_handler_worker(), + HummockEventHandler::new( + uploader.clone(), + event_rx, + read_version.clone(), + version_update_notifier_tx.clone(), + seal_epoch, + ConflictDetector::new_from_config(hummock_options.clone()), + ) + .start_hummock_event_handler_worker(), ); let hummock_storage = HummockStorage::for_test( @@ -941,7 +1091,7 @@ async fn test_iter_with_min_epoch() { .into_iter() .map(|index| { ( - Bytes::from(gen_key(index)), + prefixed_key(Bytes::from(gen_key(index))), StorageValue::new_put(gen_val(index)), ) }) @@ -964,7 +1114,7 @@ async fn test_iter_with_min_epoch() { .into_iter() .map(|index| { ( - Bytes::from(gen_key(index)), + prefixed_key(Bytes::from(gen_key(index))), StorageValue::new_put(gen_val(index)), ) }) @@ -1055,7 +1205,7 @@ async fn test_iter_with_min_epoch() { .await .unwrap(); - try_wait_epoch_for_test(epoch2, uploader.clone()).await; + try_wait_epoch_for_test(epoch2, version_update_notifier_tx).await; { let iter = hummock_storage diff --git a/src/storage/hummock_test/src/lib.rs b/src/storage/hummock_test/src/lib.rs index 001f8ae26c794..d0d8f5f632c78 100644 --- a/src/storage/hummock_test/src/lib.rs +++ b/src/storage/hummock_test/src/lib.rs @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
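
These tests wrap every user key in a prefixed_key helper because, with the b't' marker byte removed elsewhere in this patch, a full key must start with the table id as a big-endian u32. A standalone sketch of what such a helper has to produce; the real helper in test_utils takes only the key, so the explicit table id here is for illustration, and only the bytes crate is assumed:

    use bytes::{BufMut, Bytes, BytesMut};

    // The table prefix is now just the table id: 4 bytes, no marker byte.
    const TABLE_PREFIX_LEN: usize = std::mem::size_of::<u32>();

    fn prefixed_key(table_id: u32, key: impl AsRef<[u8]>) -> Bytes {
        let mut buf = BytesMut::with_capacity(TABLE_PREFIX_LEN + key.as_ref().len());
        buf.put_u32(table_id);
        buf.put_slice(key.as_ref());
        buf.freeze()
    }

    // The reverse direction: the table id can be read back from any full key.
    fn get_table_id(full_key: &[u8]) -> u32 {
        u32::from_be_bytes(full_key[..TABLE_PREFIX_LEN].try_into().unwrap())
    }

    fn main() {
        let key = prefixed_key(1, b"aa");
        assert_eq!(key.len(), TABLE_PREFIX_LEN + 2);
        assert_eq!(get_table_id(&key), 1);
    }
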
+#![feature(custom_test_frameworks)] +#![test_runner(risingwave_test_runner::test_runner::run_failpont_tests)] + #[cfg(test)] mod compactor_tests; #[cfg(all(test, feature = "failpoints"))] @@ -32,3 +35,5 @@ mod hummock_read_version_tests; #[cfg(test)] mod hummock_storage_tests; +#[cfg(all(test, feature = "sync_point"))] +mod sync_point_tests; diff --git a/src/storage/hummock_test/src/local_version_manager_tests.rs b/src/storage/hummock_test/src/local_version_manager_tests.rs index a0ce7c24d9999..9a77d1484da04 100644 --- a/src/storage/hummock_test/src/local_version_manager_tests.rs +++ b/src/storage/hummock_test/src/local_version_manager_tests.rs @@ -345,7 +345,6 @@ async fn test_update_uncommitted_ssts() { }; assert!(local_version_manager .try_update_pinned_version(Payload::PinnedVersion(version.clone())) - .0 .is_some()); let local_version = local_version_manager.get_local_version(); // Check shared buffer diff --git a/src/storage/hummock_test/src/snapshot_tests.rs b/src/storage/hummock_test/src/snapshot_tests.rs index 706b9a4703946..4a067baa946e0 100644 --- a/src/storage/hummock_test/src/snapshot_tests.rs +++ b/src/storage/hummock_test/src/snapshot_tests.rs @@ -26,12 +26,12 @@ use risingwave_storage::storage_value::StorageValue; use risingwave_storage::store::{ReadOptions, StateStoreIter, WriteOptions}; use risingwave_storage::StateStore; -use crate::test_utils::get_test_notification_client; +use crate::test_utils::{get_test_notification_client, prefixed_key}; macro_rules! assert_count_range_scan { ($storage:expr, $range:expr, $expect_count:expr, $epoch:expr) => {{ let mut it = $storage - .iter::<_, Vec>( + .iter::<_, Bytes>( None, $range, ReadOptions { @@ -56,7 +56,7 @@ macro_rules! assert_count_range_scan { macro_rules! assert_count_backward_range_scan { ($storage:expr, $range:expr, $expect_count:expr, $epoch:expr) => {{ let mut it = $storage - .backward_iter::<_, Vec>( + .backward_iter::<_, Bytes>( $range, ReadOptions { epoch: $epoch, @@ -100,8 +100,14 @@ async fn test_snapshot_inner(enable_sync: bool, enable_commit: bool) { hummock_storage .ingest_batch( vec![ - (Bytes::from("1"), StorageValue::new_put("test")), - (Bytes::from("2"), StorageValue::new_put("test")), + ( + prefixed_key(Bytes::from("1")), + StorageValue::new_put("test"), + ), + ( + prefixed_key(Bytes::from("2")), + StorageValue::new_put("test"), + ), ], WriteOptions { epoch: epoch1, @@ -133,9 +139,15 @@ async fn test_snapshot_inner(enable_sync: bool, enable_commit: bool) { hummock_storage .ingest_batch( vec![ - (Bytes::from("1"), StorageValue::new_delete()), - (Bytes::from("3"), StorageValue::new_put("test")), - (Bytes::from("4"), StorageValue::new_put("test")), + (prefixed_key(Bytes::from("1")), StorageValue::new_delete()), + ( + prefixed_key(Bytes::from("3")), + StorageValue::new_put("test"), + ), + ( + prefixed_key(Bytes::from("4")), + StorageValue::new_put("test"), + ), ], WriteOptions { epoch: epoch2, @@ -168,9 +180,9 @@ async fn test_snapshot_inner(enable_sync: bool, enable_commit: bool) { hummock_storage .ingest_batch( vec![ - (Bytes::from("2"), StorageValue::new_delete()), - (Bytes::from("3"), StorageValue::new_delete()), - (Bytes::from("4"), StorageValue::new_delete()), + (prefixed_key(Bytes::from("2")), StorageValue::new_delete()), + (prefixed_key(Bytes::from("3")), StorageValue::new_delete()), + (prefixed_key(Bytes::from("4")), StorageValue::new_delete()), ], WriteOptions { epoch: epoch3, @@ -224,10 +236,22 @@ async fn test_snapshot_range_scan_inner(enable_sync: bool, enable_commit: bool) hummock_storage 
.ingest_batch( vec![ - (Bytes::from("1"), StorageValue::new_put("test")), - (Bytes::from("2"), StorageValue::new_put("test")), - (Bytes::from("3"), StorageValue::new_put("test")), - (Bytes::from("4"), StorageValue::new_put("test")), + ( + prefixed_key(Bytes::from("1")), + StorageValue::new_put("test"), + ), + ( + prefixed_key(Bytes::from("2")), + StorageValue::new_put("test"), + ), + ( + prefixed_key(Bytes::from("3")), + StorageValue::new_put("test"), + ), + ( + prefixed_key(Bytes::from("4")), + StorageValue::new_put("test"), + ), ], WriteOptions { epoch, @@ -255,7 +279,7 @@ async fn test_snapshot_range_scan_inner(enable_sync: bool, enable_commit: bool) } macro_rules! key { ($idx:expr) => { - Bytes::from(stringify!($idx)).to_vec() + prefixed_key(Bytes::from(stringify!($idx))) }; } @@ -356,7 +380,7 @@ async fn test_snapshot_backward_range_scan_inner(enable_sync: bool, enable_commi } macro_rules! key { ($idx:expr) => { - Bytes::from(stringify!($idx)).to_vec() + Bytes::from(stringify!($idx)) }; } diff --git a/src/storage/hummock_test/src/state_store_tests.rs b/src/storage/hummock_test/src/state_store_tests.rs index 13eea5cb8d6eb..c28b4710b5543 100644 --- a/src/storage/hummock_test/src/state_store_tests.rs +++ b/src/storage/hummock_test/src/state_store_tests.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use bytes::Bytes; +use risingwave_hummock_sdk::key::{EPOCH_LEN, TABLE_PREFIX_LEN}; use risingwave_hummock_sdk::{HummockEpoch, HummockReadEpoch}; use risingwave_meta::hummock::test_utils::setup_compute_env; use risingwave_meta::hummock::MockHummockMetaClient; @@ -26,7 +27,7 @@ use risingwave_storage::storage_value::StorageValue; use risingwave_storage::store::{ReadOptions, StateStore, WriteOptions}; use risingwave_storage::StateStoreIter; -use crate::test_utils::get_test_notification_client; +use crate::test_utils::{get_test_notification_client, prefixed_key}; #[tokio::test] async fn test_basic() { @@ -47,12 +48,15 @@ async fn test_basic() { .await .unwrap(); - let anchor = Bytes::from("aa"); + let anchor = prefixed_key(Bytes::from("aa")); // First batch inserts the anchor and others. let mut batch1 = vec![ (anchor.clone(), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), ]; // Make sure the batch is sorted. @@ -60,7 +64,10 @@ async fn test_basic() { // Second batch modifies the anchor. let mut batch2 = vec![ - (Bytes::from("cc"), StorageValue::new_put("333")), + ( + prefixed_key(Bytes::from("cc")), + StorageValue::new_put("333"), + ), (anchor.clone(), StorageValue::new_put("111111")), ]; @@ -69,8 +76,14 @@ async fn test_basic() { // Third batch deletes the anchor let mut batch3 = vec![ - (Bytes::from("dd"), StorageValue::new_put("444")), - (Bytes::from("ee"), StorageValue::new_put("555")), + ( + prefixed_key(Bytes::from("dd")), + StorageValue::new_put("444"), + ), + ( + prefixed_key(Bytes::from("ee")), + StorageValue::new_put("555"), + ), (anchor.clone(), StorageValue::new_delete()), ]; @@ -109,7 +122,7 @@ async fn test_basic() { assert_eq!(value, Bytes::from("111")); let value = hummock_storage .get( - &Bytes::from("bb"), + &prefixed_key(Bytes::from("bb")), true, ReadOptions { epoch: epoch1, @@ -125,7 +138,7 @@ async fn test_basic() { // Test looking for a nonexistent key. `next()` would return the next key. 
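
Full keys in these tests also carry an 8-byte epoch suffix (EPOCH_LEN), appended as u64::MAX - epoch so that newer versions of a key sort first, as the key.rs doc comment earlier in this patch describes. A small sketch of that encoding; get_epoch here is an illustrative inverse, not necessarily the crate's implementation:

    use bytes::{Buf, BufMut};

    const EPOCH_LEN: usize = std::mem::size_of::<u64>();

    // Append `u64::MAX - epoch` so that, under bytewise ordering, larger
    // (newer) epochs sort before smaller ones for the same user key.
    fn key_with_epoch(mut user_key: Vec<u8>, epoch: u64) -> Vec<u8> {
        user_key.put_u64(u64::MAX - epoch);
        user_key
    }

    fn get_epoch(full_key: &[u8]) -> u64 {
        let mut buf = &full_key[full_key.len() - EPOCH_LEN..];
        u64::MAX - buf.get_u64()
    }

    fn main() {
        let k_new = key_with_epoch(b"key".to_vec(), 10);
        let k_old = key_with_epoch(b"key".to_vec(), 5);
        assert!(k_new < k_old); // newer epoch sorts first
        assert_eq!(get_epoch(&k_new), 10);
    }
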
let value = hummock_storage .get( - &Bytes::from("ab"), + &prefixed_key(Bytes::from("ab")), true, ReadOptions { epoch: epoch1, @@ -197,7 +210,7 @@ async fn test_basic() { // Get non-existent maximum key. let value = hummock_storage .get( - &Bytes::from("ff"), + &prefixed_key(Bytes::from("ff")), true, ReadOptions { epoch: epoch3, @@ -213,7 +226,7 @@ async fn test_basic() { let mut iter = hummock_storage .iter( None, - ..=b"ee".to_vec(), + ..=prefixed_key(b"ee"), ReadOptions { epoch: epoch1, table_id: Default::default(), @@ -260,7 +273,7 @@ async fn test_basic() { let mut iter = hummock_storage .iter( None, - ..=b"ee".to_vec(), + ..=prefixed_key(b"ee"), ReadOptions { epoch: epoch2, table_id: Default::default(), @@ -276,7 +289,7 @@ async fn test_basic() { let mut iter = hummock_storage .iter( None, - ..=b"ee".to_vec(), + ..=prefixed_key(b"ee"), ReadOptions { epoch: epoch3, table_id: Default::default(), @@ -299,7 +312,7 @@ async fn test_basic() { .unwrap(); let value = hummock_storage .get( - &Bytes::from("bb"), + &prefixed_key(Bytes::from("bb")), true, ReadOptions { epoch: epoch2, @@ -313,7 +326,7 @@ async fn test_basic() { assert_eq!(value, Bytes::from("222")); let value = hummock_storage .get( - &Bytes::from("dd"), + &prefixed_key(Bytes::from("dd")), true, ReadOptions { epoch: epoch2, @@ -352,10 +365,16 @@ async fn test_state_store_sync() { let mut epoch: HummockEpoch = hummock_storage.get_pinned_version().max_committed_epoch() + 1; - // ingest 16B batch + // ingest 26B batch let mut batch1 = vec![ - (Bytes::from("aaaa"), StorageValue::new_put("1111")), - (Bytes::from("bbbb"), StorageValue::new_put("2222")), + ( + prefixed_key(Bytes::from("aaaa")), + StorageValue::new_put("1111"), + ), + ( + prefixed_key(Bytes::from("bbbb")), + StorageValue::new_put("2222"), + ), ]; // Make sure the batch is sorted. 
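
The next hunk updates the expected shared-buffer size for this two-entry batch. As a sanity check, here is the arithmetic it encodes, under the assumption that each entry is accounted as table prefix + user key + epoch suffix + value:

    const TABLE_PREFIX_LEN: usize = std::mem::size_of::<u32>(); // 4-byte table id
    const EPOCH_LEN: usize = std::mem::size_of::<u64>(); // 8-byte epoch suffix

    fn batch_size(entries: &[(&[u8], &[u8])]) -> usize {
        entries
            .iter()
            .map(|(user_key, value)| TABLE_PREFIX_LEN + user_key.len() + EPOCH_LEN + value.len())
            .sum()
    }

    fn main() {
        // The batch from the test: ("aaaa", "1111") and ("bbbb", "2222").
        let entries: [(&[u8], &[u8]); 2] = [(b"aaaa", b"1111"), (b"bbbb", b"2222")];
        let size = batch_size(&entries);
        // 2 * (4 + 4 + 8 + 4) = 40, i.e. TABLE_PREFIX_LEN * 2 + 16 + EPOCH_LEN * 2.
        assert_eq!(size, TABLE_PREFIX_LEN * 2 + 16 + EPOCH_LEN * 2);
        println!("{} bytes", size);
    }
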
@@ -374,15 +393,24 @@ async fn test_state_store_sync() { // check sync state store metrics // Note: epoch(8B) will be appended to each kv pair assert_eq!( - (16 + (8) * 2) as usize, + (TABLE_PREFIX_LEN * 2 + 16 + (EPOCH_LEN) * 2) as usize, hummock_storage.get_shared_buffer_size() ); - // ingest 24B batch + // ingest 39B batch let mut batch2 = vec![ - (Bytes::from("cccc"), StorageValue::new_put("3333")), - (Bytes::from("dddd"), StorageValue::new_put("4444")), - (Bytes::from("eeee"), StorageValue::new_put("5555")), + ( + prefixed_key(Bytes::from("cccc")), + StorageValue::new_put("3333"), + ), + ( + prefixed_key(Bytes::from("dddd")), + StorageValue::new_put("4444"), + ), + ( + prefixed_key(Bytes::from("eeee")), + StorageValue::new_put("5555"), + ), ]; batch2.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); hummock_storage @@ -408,7 +436,10 @@ async fn test_state_store_sync() { epoch += 1; // ingest more 8B then will trigger a sync behind the scene - let mut batch3 = vec![(Bytes::from("eeee"), StorageValue::new_put("5555"))]; + let mut batch3 = vec![( + prefixed_key(Bytes::from("eeee")), + StorageValue::new_put("5555"), + )]; batch3.sort_by(|(k1, _), (k2, _)| k1.cmp(k2)); hummock_storage .ingest_batch( @@ -667,7 +698,7 @@ async fn test_write_anytime() { "111".as_bytes(), hummock_storage .get( - "aa".as_bytes(), + &prefixed_key("aa".as_bytes()), true, ReadOptions { epoch, @@ -683,7 +714,7 @@ async fn test_write_anytime() { "222".as_bytes(), hummock_storage .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), true, ReadOptions { epoch, @@ -699,7 +730,7 @@ async fn test_write_anytime() { "333".as_bytes(), hummock_storage .get( - "cc".as_bytes(), + &prefixed_key("cc".as_bytes()), true, ReadOptions { epoch, @@ -715,7 +746,7 @@ async fn test_write_anytime() { let mut iter = hummock_storage .iter( None, - "aa".as_bytes()..="cc".as_bytes(), + prefixed_key("aa".as_bytes())..=prefixed_key("cc".as_bytes()), ReadOptions { epoch, table_id: Default::default(), @@ -725,15 +756,15 @@ async fn test_write_anytime() { .await .unwrap(); assert_eq!( - (Bytes::from("aa"), Bytes::from("111")), + (prefixed_key(Bytes::from("aa")), Bytes::from("111")), iter.next().await.unwrap().unwrap() ); assert_eq!( - (Bytes::from("bb"), Bytes::from("222")), + (prefixed_key(Bytes::from("bb")), Bytes::from("222")), iter.next().await.unwrap().unwrap() ); assert_eq!( - (Bytes::from("cc"), Bytes::from("333")), + (prefixed_key(Bytes::from("cc")), Bytes::from("333")), iter.next().await.unwrap().unwrap() ); assert!(iter.next().await.unwrap().is_none()); @@ -741,9 +772,18 @@ async fn test_write_anytime() { }; let batch1 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), - (Bytes::from("cc"), StorageValue::new_put("333")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), + ( + prefixed_key(Bytes::from("cc")), + StorageValue::new_put("333"), + ), ]; hummock_storage @@ -766,7 +806,7 @@ async fn test_write_anytime() { "111_new".as_bytes(), hummock_storage .get( - "aa".as_bytes(), + &prefixed_key("aa".as_bytes()), true, ReadOptions { epoch, @@ -780,7 +820,7 @@ async fn test_write_anytime() { ); assert!(hummock_storage .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), true, ReadOptions { epoch, @@ -795,7 +835,7 @@ async fn test_write_anytime() { "333".as_bytes(), hummock_storage .get( - "cc".as_bytes(), + &prefixed_key("cc".as_bytes()), true, ReadOptions { epoch, @@ -810,7 
+850,7 @@ async fn test_write_anytime() { let mut iter = hummock_storage .iter( None, - "aa".as_bytes()..="cc".as_bytes(), + prefixed_key("aa".as_bytes())..=prefixed_key("cc".as_bytes()), ReadOptions { epoch, table_id: Default::default(), @@ -820,11 +860,11 @@ async fn test_write_anytime() { .await .unwrap(); assert_eq!( - (Bytes::from("aa"), Bytes::from("111_new")), + (prefixed_key(Bytes::from("aa")), Bytes::from("111_new")), iter.next().await.unwrap().unwrap() ); assert_eq!( - (Bytes::from("cc"), Bytes::from("333")), + (prefixed_key(Bytes::from("cc")), Bytes::from("333")), iter.next().await.unwrap().unwrap() ); assert!(iter.next().await.unwrap().is_none()); @@ -833,8 +873,11 @@ async fn test_write_anytime() { // Update aa, delete bb, cc unchanged let batch2 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111_new")), - (Bytes::from("bb"), StorageValue::new_delete()), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111_new"), + ), + (prefixed_key(Bytes::from("bb")), StorageValue::new_delete()), ]; hummock_storage @@ -911,8 +954,14 @@ async fn test_delete_get() { let initial_epoch = hummock_storage.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch + 1; let batch1 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), ]; hummock_storage .ingest_batch( @@ -931,7 +980,7 @@ async fn test_delete_get() { .uncommitted_ssts; meta_client.commit_epoch(epoch1, ssts).await.unwrap(); let epoch2 = initial_epoch + 2; - let batch2 = vec![(Bytes::from("bb"), StorageValue::new_delete())]; + let batch2 = vec![(prefixed_key(Bytes::from("bb")), StorageValue::new_delete())]; hummock_storage .ingest_batch( batch2, @@ -954,7 +1003,7 @@ async fn test_delete_get() { .unwrap(); assert!(hummock_storage .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), true, ReadOptions { epoch: epoch2, @@ -990,8 +1039,14 @@ async fn test_multiple_epoch_sync() { let initial_epoch = hummock_storage.get_pinned_version().max_committed_epoch(); let epoch1 = initial_epoch + 1; let batch1 = vec![ - (Bytes::from("aa"), StorageValue::new_put("111")), - (Bytes::from("bb"), StorageValue::new_put("222")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("111"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("222"), + ), ]; hummock_storage .ingest_batch( @@ -1005,7 +1060,7 @@ async fn test_multiple_epoch_sync() { .unwrap(); let epoch2 = initial_epoch + 2; - let batch2 = vec![(Bytes::from("bb"), StorageValue::new_delete())]; + let batch2 = vec![(prefixed_key(Bytes::from("bb")), StorageValue::new_delete())]; hummock_storage .ingest_batch( batch2, @@ -1019,8 +1074,14 @@ async fn test_multiple_epoch_sync() { let epoch3 = initial_epoch + 3; let batch3 = vec![ - (Bytes::from("aa"), StorageValue::new_put("444")), - (Bytes::from("bb"), StorageValue::new_put("555")), + ( + prefixed_key(Bytes::from("aa")), + StorageValue::new_put("444"), + ), + ( + prefixed_key(Bytes::from("bb")), + StorageValue::new_put("555"), + ), ]; hummock_storage .ingest_batch( @@ -1038,7 +1099,7 @@ async fn test_multiple_epoch_sync() { assert_eq!( hummock_storage_clone .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), true, ReadOptions { epoch: epoch1, @@ -1053,7 +1114,7 @@ async fn test_multiple_epoch_sync() { ); assert!(hummock_storage_clone .get( - "bb".as_bytes(), + 
&prefixed_key("bb".as_bytes()), true, ReadOptions { epoch: epoch2, @@ -1067,7 +1128,7 @@ async fn test_multiple_epoch_sync() { assert_eq!( hummock_storage_clone .get( - "bb".as_bytes(), + &prefixed_key("bb".as_bytes()), true, ReadOptions { epoch: epoch3, diff --git a/src/storage/hummock_test/sync_point_unit_test/src/tests.rs b/src/storage/hummock_test/src/sync_point_tests.rs similarity index 96% rename from src/storage/hummock_test/sync_point_unit_test/src/tests.rs rename to src/storage/hummock_test/src/sync_point_tests.rs index 28a4127d89ed1..6a139d359f761 100644 --- a/src/storage/hummock_test/sync_point_unit_test/src/tests.rs +++ b/src/storage/hummock_test/src/sync_point_tests.rs @@ -27,8 +27,7 @@ use serial_test::serial; #[tokio::test] #[serial] -async fn test_sstable_id_manager() { - sync_point::reset(); +async fn test_syncpoints_sstable_id_manager() { let (_env, hummock_manager_ref, _cluster_manager_ref, worker_node) = setup_compute_env(8080).await; let hummock_meta_client: Arc = Arc::new(MockHummockMetaClient::new( @@ -78,8 +77,7 @@ async fn test_sstable_id_manager() { #[cfg(feature = "failpoints")] #[tokio::test] #[serial] -async fn test_failpoints_fetch_ids() { - sync_point::reset(); +async fn test_syncpoints_test_failpoints_fetch_ids() { let (_env, hummock_manager_ref, _cluster_manager_ref, worker_node) = setup_compute_env(8080).await; let hummock_meta_client: Arc = Arc::new(MockHummockMetaClient::new( @@ -132,9 +130,7 @@ async fn test_failpoints_fetch_ids() { #[tokio::test] #[serial] -async fn test_local_notification_receiver() { - sync_point::reset(); - +async fn test_syncpoints_test_local_notification_receiver() { let (env, hummock_manager, _cluster_manager, worker_node) = setup_compute_env(80).await; let context_id = worker_node.id; let (join_handle, shutdown_sender) = start_local_notification_receiver( diff --git a/src/storage/hummock_test/src/test_utils.rs b/src/storage/hummock_test/src/test_utils.rs index c980386d7c13f..1b8426383938e 100644 --- a/src/storage/hummock_test/src/test_utils.rs +++ b/src/storage/hummock_test/src/test_utils.rs @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::sync::atomic::AtomicU64; use std::sync::Arc; +use bytes::{BufMut, Bytes}; use parking_lot::RwLock; use risingwave_common::config::StorageConfig; use risingwave_common::error::Result; @@ -27,6 +29,7 @@ use risingwave_pb::common::WorkerNode; use risingwave_pb::hummock::pin_version_response; use risingwave_pb::meta::subscribe_response::{Info, Operation}; use risingwave_pb::meta::{MetaSnapshot, SubscribeResponse, SubscribeType}; +use risingwave_storage::hummock::conflict_detector::ConflictDetector; use risingwave_storage::hummock::event_handler::{HummockEvent, HummockEventHandler}; use risingwave_storage::hummock::iterator::test_utils::mock_sstable_store; use risingwave_storage::hummock::local_version::local_version_manager::{ @@ -146,6 +149,19 @@ pub async fn prepare_local_version_manager( event_tx, ); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = local_version_manager + .get_pinned_version() + .max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(basic_max_committed_epoch)), + ) + }; + tokio::spawn( HummockEventHandler::new( local_version_manager.clone(), @@ -153,6 +169,9 @@ pub async fn prepare_local_version_manager( Arc::new(RwLock::new(HummockReadVersion::new( local_version_manager.get_pinned_version(), ))), + version_update_notifier_tx, + seal_epoch, + ConflictDetector::new_from_config(opt.clone()), ) .start_hummock_event_handler_worker(), ); @@ -206,3 +225,15 @@ pub async fn prepare_local_version_manager_new( (local_version_manager, event_tx, event_rx) } + +/// Prefix the `key` with a dummy table id. +/// We use `0` because: +/// - This value is used in the code to identify unit tests and prevent some parameters that are not +/// easily constructible in tests from breaking the test. +/// - When calling state store interfaces, we normally pass `TableId::default()`, which is `0`. 
+pub fn prefixed_key>(key: T) -> Bytes { + let mut buf = Vec::new(); + buf.put_u32(0); + buf.put_slice(key.as_ref()); + buf.into() +} diff --git a/src/storage/hummock_test/sync_point_unit_test/Cargo.toml b/src/storage/hummock_test/sync_point_unit_test/Cargo.toml deleted file mode 100644 index 4473d38ec62de..0000000000000 --- a/src/storage/hummock_test/sync_point_unit_test/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -name = "sync_point_unit_test" -version = "0.2.0-alpha" -edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -async-trait = "0.1" -bytes = { version = "1" } -fail = "0.5" -futures = { version = "0.3", default-features = false, features = ["alloc"] } -itertools = "0.10" -parking_lot = "0.12" -rand = "0.8" -risingwave_common = { path = "../../../common" } -risingwave_common_service = { path = "../../../common/common_service" } -risingwave_hummock_sdk = { path = "../../hummock_sdk" } -risingwave_meta = { path = "../../../meta", features = ["test"] } -risingwave_pb = { path = "../../../prost" } -risingwave_rpc_client = { path = "../../../rpc_client" } -risingwave_storage = { path = "../..", features = ["test"] } -tokio = { version = "0.2", package = "madsim-tokio" } - -[dev-dependencies] -serial_test = "0.9" -sync-point = { path = "../../../utils/sync-point" } - -[features] -sync_point = ["sync-point/sync_point"] -failpoints = ["risingwave_storage/failpoints"] diff --git a/src/storage/hummock_test/sync_point_unit_test/src/lib.rs b/src/storage/hummock_test/sync_point_unit_test/src/lib.rs deleted file mode 100644 index f88e74ee525a7..0000000000000 --- a/src/storage/hummock_test/sync_point_unit_test/src/lib.rs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#[cfg(all(test, feature = "sync_point"))] -mod tests; diff --git a/src/storage/src/hummock/block_cache.rs b/src/storage/src/hummock/block_cache.rs index 3763c5c7ca88d..1df1f8545fe45 100644 --- a/src/storage/src/hummock/block_cache.rs +++ b/src/storage/src/hummock/block_cache.rs @@ -168,7 +168,7 @@ impl BlockCache { Ok((block, len)) } }) - .stack_trace("block_cache_lookup") + .verbose_stack_trace("block_cache_lookup") .await { // Return when meet IO error, or retry again. 
Because this error may be caused by diff --git a/src/storage/src/hummock/compactor/compaction_filter.rs b/src/storage/src/hummock/compactor/compaction_filter.rs index b51f5405f7645..3707f6f892ecc 100644 --- a/src/storage/src/hummock/compactor/compaction_filter.rs +++ b/src/storage/src/hummock/compactor/compaction_filter.rs @@ -48,20 +48,15 @@ impl StateCleanUpCompactionFilter { impl CompactionFilter for StateCleanUpCompactionFilter { fn should_delete(&mut self, key: &[u8]) -> bool { - let table_id_option = get_table_id(key); - match table_id_option { - None => false, - Some(table_id) => { - if let Some((last_table_id, removed)) = self.last_table.as_ref() { - if *last_table_id == table_id { - return *removed; - } - } - let removed = !self.existing_table_ids.contains(&table_id); - self.last_table = Some((table_id, removed)); - removed + let table_id = get_table_id(key); + if let Some((last_table_id, removed)) = self.last_table.as_ref() { + if *last_table_id == table_id { + return *removed; } } + let removed = !self.existing_table_ids.contains(&table_id); + self.last_table = Some((table_id, removed)); + removed } } @@ -76,25 +71,20 @@ impl CompactionFilter for TtlCompactionFilter { fn should_delete(&mut self, key: &[u8]) -> bool { pub use risingwave_common::util::epoch::Epoch; let (table_id, epoch) = extract_table_id_and_epoch(key); - match table_id { - Some(table_id) => { - if let Some((last_table_id, ttl_mill)) = self.last_table_and_ttl.as_ref() { - if *last_table_id == table_id { - let min_epoch = Epoch(self.expire_epoch).subtract_ms(*ttl_mill); - return Epoch(epoch) <= min_epoch; - } - } - match self.table_id_to_ttl.get(&table_id) { - Some(ttl_second_u32) => { - assert!(*ttl_second_u32 != TABLE_OPTION_DUMMY_RETENTION_SECOND); - // default to zero. - let ttl_mill = (*ttl_second_u32 * 1000) as u64; - let min_epoch = Epoch(self.expire_epoch).subtract_ms(ttl_mill); - self.last_table_and_ttl = Some((table_id, ttl_mill)); - Epoch(epoch) <= min_epoch - } - None => false, - } + if let Some((last_table_id, ttl_mill)) = self.last_table_and_ttl.as_ref() { + if *last_table_id == table_id { + let min_epoch = Epoch(self.expire_epoch).subtract_ms(*ttl_mill); + return Epoch(epoch) <= min_epoch; + } + } + match self.table_id_to_ttl.get(&table_id) { + Some(ttl_second_u32) => { + assert!(*ttl_second_u32 != TABLE_OPTION_DUMMY_RETENTION_SECOND); + // default to zero. 
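// Convert the per-table TTL from seconds to milliseconds, cache it for subsequent
// keys of the same table, and drop any key whose epoch falls at or before
// `expire_epoch` minus that TTL.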
+ let ttl_mill = (*ttl_second_u32 * 1000) as u64; + let min_epoch = Epoch(self.expire_epoch).subtract_ms(ttl_mill); + self.last_table_and_ttl = Some((table_id, ttl_mill)); + Epoch(epoch) <= min_epoch } None => false, } diff --git a/src/storage/src/hummock/event_handler/hummock_event_handler.rs b/src/storage/src/hummock/event_handler/hummock_event_handler.rs index ba0f1bf9309f0..990f0416aef03 100644 --- a/src/storage/src/hummock/event_handler/hummock_event_handler.rs +++ b/src/storage/src/hummock/event_handler/hummock_event_handler.rs @@ -14,24 +14,28 @@ use std::collections::HashMap; use std::iter::once; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use futures::future::{select, try_join_all, Either}; use futures::FutureExt; use itertools::Itertools; use parking_lot::RwLock; +use risingwave_hummock_sdk::compaction_group::hummock_version_ext::HummockVersionExt; use risingwave_hummock_sdk::HummockEpoch; use risingwave_pb::hummock::pin_version_response::Payload; use tokio::sync::{mpsc, oneshot}; use tracing::{error, info}; +use crate::hummock::conflict_detector::ConflictDetector; use crate::hummock::event_handler::HummockEvent; use crate::hummock::local_version::local_version_manager::LocalVersionManager; +use crate::hummock::local_version::pinned_version::PinnedVersion; use crate::hummock::local_version::upload_handle_manager::UploadHandleManager; use crate::hummock::local_version::SyncUncommittedDataStage; use crate::hummock::store::memtable::ImmutableMemtable; use crate::hummock::store::version::{HummockReadVersion, VersionUpdate}; +use crate::hummock::utils::validate_table_key_range; use crate::hummock::{HummockError, HummockResult, MemoryLimiter, SstableIdManagerRef, TrackerId}; use crate::store::SyncResult; @@ -82,7 +86,6 @@ impl BufferTracker { } pub struct HummockEventHandler { - local_version_manager: Arc, buffer_tracker: BufferTracker, sstable_id_manager: SstableIdManagerRef, shared_buffer_event_receiver: mpsc::UnboundedReceiver, @@ -91,6 +94,13 @@ pub struct HummockEventHandler { // TODO: replace it with hashmap read_version: Arc>, + + version_update_notifier_tx: Arc>, + seal_epoch: Arc, + pinned_version: PinnedVersion, + write_conflict_detector: Option>, + + local_version_manager: Arc, } impl HummockEventHandler { @@ -98,15 +108,22 @@ impl HummockEventHandler { local_version_manager: Arc, shared_buffer_event_receiver: mpsc::UnboundedReceiver, read_version: Arc>, + version_update_notifier_tx: Arc>, + seal_epoch: Arc, + write_conflict_detector: Option>, ) -> Self { Self { buffer_tracker: local_version_manager.buffer_tracker().clone(), sstable_id_manager: local_version_manager.sstable_id_manager(), - local_version_manager, shared_buffer_event_receiver, upload_handle_manager: UploadHandleManager::new(), pending_sync_requests: Default::default(), read_version, + version_update_notifier_tx, + seal_epoch, + pinned_version: local_version_manager.get_pinned_version(), + write_conflict_detector, + local_version_manager, } } @@ -294,24 +311,55 @@ impl HummockEventHandler { notifier.send(()).unwrap(); } - fn handle_version_update(&self, version_payload: Payload) { - if let (Some(new_version), mce_change) = self - .local_version_manager - .try_update_pinned_version(version_payload) - { - let new_version_id = new_version.id(); - // update the read_version of hummock instance - self.read_version - .write() - .update(VersionUpdate::CommittedSnapshot(new_version)); - - if mce_change { - // only notify 
local_version_manager when MCE change - // TODO: use MCE to replace new_version_id - self.local_version_manager - .notify_version_id_to_worker_context(new_version_id); + fn handle_version_update(&mut self, version_payload: Payload) { + let prev_max_committed_epoch = self.pinned_version.max_committed_epoch(); + // TODO: after local version manager is removed, we can match version_payload directly + // instead of taking a reference + let newly_pinned_version = match &version_payload { + Payload::VersionDeltas(version_deltas) => { + let mut version_to_apply = self.pinned_version.version(); + for version_delta in &version_deltas.version_deltas { + assert_eq!(version_to_apply.id, version_delta.prev_id); + version_to_apply.apply_version_delta(version_delta); + } + version_to_apply + } + Payload::PinnedVersion(version) => version.clone(), + }; + + validate_table_key_range(&newly_pinned_version); + + self.pinned_version = self.pinned_version.new_pin_version(newly_pinned_version); + + self.read_version + .write() + .update(VersionUpdate::CommittedSnapshot( + self.pinned_version.clone(), + )); + + let max_committed_epoch = self.pinned_version.max_committed_epoch(); + + // only notify local_version_manager when MCE change + self.version_update_notifier_tx.send_if_modified(|state| { + assert_eq!(prev_max_committed_epoch, *state); + if max_committed_epoch > *state { + *state = max_committed_epoch; + true + } else { + false } + }); + + if let Some(conflict_detector) = self.write_conflict_detector.as_ref() { + conflict_detector.set_watermark(self.pinned_version.max_committed_epoch()); } + self.sstable_id_manager + .remove_watermark_sst_id(TrackerId::Epoch(self.pinned_version.max_committed_epoch())); + + // this is only for clear the committed data in local version + // TODO: remove it + self.local_version_manager + .try_update_pinned_version(version_payload); } fn handle_imm_to_uploader(&self, imm: ImmutableMemtable) { @@ -369,11 +417,14 @@ impl HummockEventHandler { HummockEvent::SealEpoch { epoch, is_checkpoint, - } => self - .local_version_manager - .local_version - .write() - .seal_epoch(epoch, is_checkpoint), + } => { + self.local_version_manager + .local_version + .write() + .seal_epoch(epoch, is_checkpoint); + + self.seal_epoch.store(epoch, Ordering::SeqCst); + } }, Either::Right(None) => { break; diff --git a/src/storage/src/hummock/iterator/backward_user.rs b/src/storage/src/hummock/iterator/backward_user.rs index f558e0331a5fc..aeeb8cbfb4ef1 100644 --- a/src/storage/src/hummock/iterator/backward_user.rs +++ b/src/storage/src/hummock/iterator/backward_user.rs @@ -333,7 +333,7 @@ mod tests { }; use crate::hummock::iterator::HummockIteratorUnion; use crate::hummock::sstable::Sstable; - use crate::hummock::test_utils::{create_small_table_cache, gen_test_sstable}; + use crate::hummock::test_utils::{create_small_table_cache, gen_test_sstable, prefixed_key}; use crate::hummock::value::HummockValue; use crate::hummock::{BackwardSstableIterator, SstableStoreRef}; @@ -883,9 +883,7 @@ mod tests { fn key_from_num(num: usize) -> Vec { let width = 20; - format!("{:0width$}", num, width = width) - .as_bytes() - .to_vec() + prefixed_key(format!("{:0width$}", num, width = width).as_bytes()).to_vec() } async fn chaos_test_case( diff --git a/src/storage/src/hummock/iterator/test_utils.rs b/src/storage/src/hummock/iterator/test_utils.rs index 1655aa2da731f..224a10be7b8b7 100644 --- a/src/storage/src/hummock/iterator/test_utils.rs +++ b/src/storage/src/hummock/iterator/test_utils.rs @@ -15,6 +15,7 @@ use 
std::iter::Iterator; use std::sync::Arc; +use bytes::BufMut; use risingwave_hummock_sdk::key::key_with_epoch; use risingwave_hummock_sdk::{HummockEpoch, HummockSstableId}; use risingwave_object_store::object::{ @@ -67,14 +68,21 @@ pub fn mock_sstable_store_with_object_store(store: ObjectStoreRef) -> SstableSto )) } -/// Generates keys like `key_test_00002` with epoch 233. +fn test_user_key_of(idx: usize) -> Vec { + let mut user_key = Vec::new(); + user_key.put_u32(0); + user_key.put_slice(format!("key_test_{:05}", idx).as_bytes()); + user_key +} + +/// Generates keys like `{table_id=0}key_test_00002` with epoch 233. pub fn iterator_test_key_of(idx: usize) -> Vec { - key_with_epoch(format!("key_test_{:05}", idx).as_bytes().to_vec(), 233) + key_with_epoch(test_user_key_of(idx), 233) } -/// Generates keys like `key_test_00002` with epoch `epoch` . +/// Generates keys like `{table_id=0}key_test_00002` with epoch `epoch` . pub fn iterator_test_key_of_epoch(idx: usize, epoch: HummockEpoch) -> Vec { - key_with_epoch(format!("key_test_{:05}", idx).as_bytes().to_vec(), epoch) + key_with_epoch(test_user_key_of(idx), epoch) } /// The value of an index, like `value_test_00002` without value meta diff --git a/src/storage/src/hummock/local_version/local_version_manager.rs b/src/storage/src/hummock/local_version/local_version_manager.rs index bbd85e57d836c..acac6bbc4c8fc 100644 --- a/src/storage/src/hummock/local_version/local_version_manager.rs +++ b/src/storage/src/hummock/local_version/local_version_manager.rs @@ -15,7 +15,6 @@ use std::collections::HashMap; use std::ops::RangeBounds; use std::sync::Arc; -use std::time::Duration; use bytes::Bytes; use parking_lot::{RwLock, RwLockWriteGuard}; @@ -24,7 +23,7 @@ use risingwave_common::config::StorageConfig; use risingwave_hummock_sdk::compaction_group::hummock_version_ext::HummockVersionExt; #[cfg(any(test, feature = "test"))] use risingwave_hummock_sdk::filter_key_extractor::FilterKeyExtractorManager; -use risingwave_hummock_sdk::{CompactionGroupId, HummockReadEpoch}; +use risingwave_hummock_sdk::CompactionGroupId; use risingwave_pb::hummock::pin_version_response; use risingwave_pb::hummock::pin_version_response::Payload; #[cfg(any(test, feature = "test"))] @@ -34,7 +33,6 @@ use tokio::sync::oneshot; use tokio::task::JoinHandle; use tracing::{error, info}; -use crate::hummock::conflict_detector::ConflictDetector; use crate::hummock::event_handler::{BufferTracker, HummockEvent}; use crate::hummock::local_version::pinned_version::PinnedVersion; use crate::hummock::local_version::{LocalVersion, ReadVersion}; @@ -46,25 +44,12 @@ use crate::hummock::shared_buffer::OrderIndex; #[cfg(any(test, feature = "test"))] use crate::hummock::sstable_store::SstableStoreRef; use crate::hummock::utils::validate_table_key_range; -use crate::hummock::{ - HummockEpoch, HummockError, HummockResult, HummockVersionId, MemoryLimiter, - SstableIdManagerRef, TrackerId, INVALID_VERSION_ID, -}; +use crate::hummock::{HummockEpoch, HummockResult, MemoryLimiter, SstableIdManagerRef, TrackerId}; #[cfg(any(test, feature = "test"))] use crate::monitor::StateStoreMetrics; use crate::storage_value::StorageValue; use crate::store::SyncResult; -struct WorkerContext { - version_update_notifier_tx: tokio::sync::watch::Sender, -} - -impl WorkerContext { - pub fn notify_version_id(&self, version_id: HummockVersionId) { - self.version_update_notifier_tx.send(version_id).ok(); - } -} - pub type LocalVersionManagerRef = Arc; /// The `LocalVersionManager` maintains a local copy of storage service's 
hummock version data. @@ -73,9 +58,7 @@ pub type LocalVersionManagerRef = Arc; /// versions in storage service. pub struct LocalVersionManager { pub(crate) local_version: RwLock, - worker_context: WorkerContext, buffer_tracker: BufferTracker, - write_conflict_detector: Option>, shared_buffer_uploader: Arc, sstable_id_manager: SstableIdManagerRef, } @@ -85,14 +68,12 @@ impl LocalVersionManager { pub fn new( options: Arc, pinned_version: PinnedVersion, - write_conflict_detector: Option>, sstable_id_manager: SstableIdManagerRef, shared_buffer_uploader: Arc, event_sender: UnboundedSender, memory_limiter: Arc, ) -> Arc { - let (version_update_notifier_tx, _) = tokio::sync::watch::channel(INVALID_VERSION_ID); - + assert!(pinned_version.is_valid()); let capacity = (options.shared_buffer_capacity_mb as usize) * (1 << 20); let buffer_tracker = BufferTracker::new( @@ -106,11 +87,7 @@ impl LocalVersionManager { Arc::new(LocalVersionManager { local_version: RwLock::new(LocalVersion::new(pinned_version)), - worker_context: WorkerContext { - version_update_notifier_tx, - }, buffer_tracker, - write_conflict_detector, shared_buffer_uploader, sstable_id_manager, }) @@ -131,7 +108,6 @@ impl LocalVersionManager { Self::new( options.clone(), pinned_version, - ConflictDetector::new_from_config(options.clone()), sstable_id_manager.clone(), Arc::new(SharedBufferUploader::new( options, @@ -156,8 +132,7 @@ impl LocalVersionManager { pub fn try_update_pinned_version( &self, pin_resp_payload: pin_version_response::Payload, - ) -> (Option, bool) { - let mut mce_change = false; + ) -> Option { let old_version = self.local_version.read(); let new_version_id = match &pin_resp_payload { Payload::VersionDeltas(version_deltas) => match version_deltas.version_deltas.last() { @@ -168,7 +143,7 @@ impl LocalVersionManager { }; if old_version.pinned_version().id() >= new_version_id { - return (None, mce_change); + return None; } let (newly_pinned_version, version_deltas) = match pin_resp_payload { @@ -183,95 +158,22 @@ impl LocalVersionManager { Payload::PinnedVersion(version) => (version, None), }; - for levels in newly_pinned_version.levels.values() { - if validate_table_key_range(&levels.levels).is_err() { - error!("invalid table key range: {:?}", levels.levels); - return (None, mce_change); - } - } + validate_table_key_range(&newly_pinned_version); drop(old_version); let mut new_version = self.local_version.write(); // check again to prevent other thread changes new_version. if new_version.pinned_version().id() >= newly_pinned_version.get_id() { - return (None, mce_change); + return None; } - { - // mce_change be used to check if need to notify - let max_committed_epoch_before_update = - new_version.pinned_version().max_committed_epoch(); - let max_committed_epoch_after_update = newly_pinned_version.max_committed_epoch; - mce_change = max_committed_epoch_before_update != max_committed_epoch_after_update; - } - - if let Some(conflict_detector) = self.write_conflict_detector.as_ref() { - conflict_detector.set_watermark(newly_pinned_version.max_committed_epoch); - } self.sstable_id_manager .remove_watermark_sst_id(TrackerId::Epoch(newly_pinned_version.max_committed_epoch)); new_version.set_pinned_version(newly_pinned_version, version_deltas); let result = new_version.pinned_version().clone(); RwLockWriteGuard::unlock_fair(new_version); - (Some(result), mce_change) - } - - /// Waits until the local hummock version contains the epoch. 
If `wait_epoch` is `Current`, - /// we will only check whether it is le `sealed_epoch` and won't wait. - pub async fn try_wait_epoch(&self, wait_epoch: HummockReadEpoch) -> HummockResult<()> { - let wait_epoch = match wait_epoch { - HummockReadEpoch::Committed(epoch) => epoch, - HummockReadEpoch::Current(epoch) => { - let sealed_epoch = self.local_version.read().get_sealed_epoch(); - assert!( - epoch <= sealed_epoch - && epoch != HummockEpoch::MAX - , - "current epoch can't read, because the epoch in storage is not updated, epoch{}, sealed epoch{}" - ,epoch - ,sealed_epoch - ); - return Ok(()); - } - HummockReadEpoch::NoWait(_) => return Ok(()), - }; - if wait_epoch == HummockEpoch::MAX { - panic!("epoch should not be u64::MAX"); - } - let mut receiver = self.worker_context.version_update_notifier_tx.subscribe(); - loop { - let (pinned_version_id, pinned_version_epoch) = { - let current_version = self.local_version.read(); - if current_version.pinned_version().max_committed_epoch() >= wait_epoch { - return Ok(()); - } - ( - current_version.pinned_version().id(), - current_version.pinned_version().max_committed_epoch(), - ) - }; - match tokio::time::timeout(Duration::from_secs(30), receiver.changed()).await { - Err(err) => { - // The reason that we need to retry here is batch scan in chain/rearrange_chain - // is waiting for an uncommitted epoch carried by the CreateMV barrier, which - // can take unbounded time to become committed and propagate - // to the CN. We should consider removing the retry as well as wait_epoch for - // chain/rearrange_chain if we enforce chain/rearrange_chain to be - // scheduled on the same CN with the same distribution as - // the upstream MV. See #3845 for more details. - tracing::warn!( - "wait_epoch {:?} timeout when waiting for version update. pinned_version_id {}, pinned_version_epoch {} err {:?} .", - wait_epoch, pinned_version_id, pinned_version_epoch, err - ); - continue; - } - Ok(Err(_)) => { - return Err(HummockError::wait_epoch("tx dropped")); - } - Ok(Ok(_)) => {} - } - } + Some(result) } pub fn write_shared_buffer_batch(&self, batch: SharedBufferBatch) { @@ -485,10 +387,6 @@ impl LocalVersionManager { pub fn sstable_id_manager(&self) -> SstableIdManagerRef { self.sstable_id_manager.clone() } - - pub fn notify_version_id_to_worker_context(&self, version_id: HummockVersionId) { - self.worker_context.notify_version_id(version_id); - } } // concurrent worker thread of `LocalVersionManager` diff --git a/src/storage/src/hummock/mod.rs b/src/storage/src/hummock/mod.rs index 87d11ebbebbdb..82ff2f348abfb 100644 --- a/src/storage/src/hummock/mod.rs +++ b/src/storage/src/hummock/mod.rs @@ -14,6 +14,7 @@ //! Hummock is the state store of the streaming system. +use std::sync::atomic::AtomicU64; use std::sync::Arc; use bytes::Bytes; @@ -130,6 +131,10 @@ pub struct HummockStorage { _shutdown_guard: Arc, storage_core: HummockStorageV2, + + version_update_notifier_tx: Arc>, + + seal_epoch: Arc, } impl HummockStorage { @@ -144,7 +149,6 @@ impl HummockStorage { ) -> HummockResult { // For conflict key detection. 
Enabled by setting `write_conflict_detection_enabled` to // true in `StorageConfig` - let write_conflict_detector = ConflictDetector::new_from_config(options.clone()); let sstable_id_manager = Arc::new(SstableIdManager::new( hummock_meta_client.clone(), options.sstable_id_remote_fetch_number, @@ -190,7 +194,6 @@ impl HummockStorage { let local_version_manager = LocalVersionManager::new( options.clone(), pinned_version, - write_conflict_detector, sstable_id_manager.clone(), shared_buffer_uploader, event_tx.clone(), @@ -201,10 +204,25 @@ impl HummockStorage { local_version_manager.get_pinned_version(), ))); + let (version_update_notifier_tx, seal_epoch) = { + let basic_max_committed_epoch = local_version_manager + .get_pinned_version() + .max_committed_epoch(); + let (version_update_notifier_tx, _rx) = + tokio::sync::watch::channel(basic_max_committed_epoch); + + ( + Arc::new(version_update_notifier_tx), + Arc::new(AtomicU64::new(basic_max_committed_epoch)), + ) + }; let hummock_event_handler = HummockEventHandler::new( local_version_manager.clone(), event_rx, read_version.clone(), + version_update_notifier_tx.clone(), + seal_epoch.clone(), + ConflictDetector::new_from_config(options.clone()), ); // Buffer size manager. @@ -233,6 +251,8 @@ impl HummockStorage { shutdown_sender: event_tx, }), storage_core, + version_update_notifier_tx, + seal_epoch, }; Ok(instance) } diff --git a/src/storage/src/hummock/sstable/backward_sstable_iterator.rs b/src/storage/src/hummock/sstable/backward_sstable_iterator.rs index 0ac36bff2bf18..e04528546dddc 100644 --- a/src/storage/src/hummock/sstable/backward_sstable_iterator.rs +++ b/src/storage/src/hummock/sstable/backward_sstable_iterator.rs @@ -182,7 +182,7 @@ mod tests { use crate::hummock::iterator::test_utils::mock_sstable_store; use crate::hummock::test_utils::{ create_small_table_cache, default_builder_opt_for_test, gen_default_test_sstable, - test_key_of, test_value_of, TEST_KEYS_COUNT, + prefixed_key, test_key_of, test_value_of, TEST_KEYS_COUNT, }; #[tokio::test] @@ -249,13 +249,19 @@ mod tests { } assert!(!sstable_iter.is_valid()); - let largest_key = key_with_epoch(format!("key_zzzz_{:05}", 0).as_bytes().to_vec(), 233); + let largest_key = key_with_epoch( + prefixed_key(format!("key_zzzz_{:05}", 0).as_bytes()).to_vec(), + 233, + ); sstable_iter.seek(largest_key.as_slice()).await.unwrap(); let key = sstable_iter.key(); assert_eq!(key, test_key_of(TEST_KEYS_COUNT - 1)); // Seek to > last key - let smallest_key = key_with_epoch(format!("key_aaaa_{:05}", 0).as_bytes().to_vec(), 233); + let smallest_key = key_with_epoch( + prefixed_key(format!("key_aaaa_{:05}", 0).as_bytes()).to_vec(), + 233, + ); sstable_iter.seek(smallest_key.as_slice()).await.unwrap(); assert!(!sstable_iter.is_valid()); @@ -268,7 +274,7 @@ mod tests { sstable_iter .seek( key_with_epoch( - format!("key_test_{:05}", idx * 2 - 1).as_bytes().to_vec(), + prefixed_key(format!("key_test_{:05}", idx * 2 - 1).as_bytes()).to_vec(), 0, ) .as_slice(), diff --git a/src/storage/src/hummock/sstable/builder.rs b/src/storage/src/hummock/sstable/builder.rs index 6fe3f802d541a..a6f3a3b1287d4 100644 --- a/src/storage/src/hummock/sstable/builder.rs +++ b/src/storage/src/hummock/sstable/builder.rs @@ -172,11 +172,10 @@ impl SstableBuilder { value.encode(&mut self.raw_value); if is_new_user_key { let mut extract_key = user_key(full_key); - if let Some(table_id) = get_table_id(full_key) { - if self.last_table_id != table_id { - self.table_ids.insert(table_id); - self.last_table_id = table_id; - } + let 
table_id = get_table_id(full_key); + if self.last_table_id != table_id { + self.table_ids.insert(table_id); + self.last_table_id = table_id; } extract_key = self.filter_key_extractor.extract(extract_key); diff --git a/src/storage/src/hummock/sstable/forward_sstable_iterator.rs b/src/storage/src/hummock/sstable/forward_sstable_iterator.rs index eb37a543f6350..3c7ff60418de4 100644 --- a/src/storage/src/hummock/sstable/forward_sstable_iterator.rs +++ b/src/storage/src/hummock/sstable/forward_sstable_iterator.rs @@ -196,7 +196,7 @@ mod tests { use crate::hummock::iterator::test_utils::mock_sstable_store; use crate::hummock::test_utils::{ create_small_table_cache, default_builder_opt_for_test, gen_default_test_sstable, - gen_test_sstable, test_key_of, test_value_of, TEST_KEYS_COUNT, + gen_test_sstable, prefixed_key, test_key_of, test_value_of, TEST_KEYS_COUNT, }; async fn inner_test_forward_iterator(sstable_store: SstableStoreRef, handle: TableHolder) { @@ -277,13 +277,19 @@ mod tests { assert!(!sstable_iter.is_valid()); // Seek to < first key - let smallest_key = key_with_epoch(format!("key_aaaa_{:05}", 0).as_bytes().to_vec(), 233); + let smallest_key = key_with_epoch( + prefixed_key(format!("key_aaaa_{:05}", 0).as_bytes()).to_vec(), + 233, + ); sstable_iter.seek(smallest_key.as_slice()).await.unwrap(); let key = sstable_iter.key(); assert_eq!(key, test_key_of(0)); // Seek to > last key - let largest_key = key_with_epoch(format!("key_zzzz_{:05}", 0).as_bytes().to_vec(), 233); + let largest_key = key_with_epoch( + prefixed_key(format!("key_zzzz_{:05}", 0).as_bytes()).to_vec(), + 233, + ); sstable_iter.seek(largest_key.as_slice()).await.unwrap(); assert!(!sstable_iter.is_valid()); @@ -296,7 +302,7 @@ mod tests { sstable_iter .seek( key_with_epoch( - format!("key_test_{:05}", idx * 2 - 1).as_bytes().to_vec(), + prefixed_key(format!("key_test_{:05}", idx * 2 - 1).as_bytes()).to_vec(), 0, ) .as_slice(), diff --git a/src/storage/src/hummock/sstable/multi_builder.rs b/src/storage/src/hummock/sstable/multi_builder.rs index 74af7edd3f6e2..0a1ba27623af3 100644 --- a/src/storage/src/hummock/sstable/multi_builder.rs +++ b/src/storage/src/hummock/sstable/multi_builder.rs @@ -257,7 +257,7 @@ mod tests { use super::*; use crate::hummock::iterator::test_utils::mock_sstable_store; use crate::hummock::sstable::utils::CompressionAlgorithm; - use crate::hummock::test_utils::default_builder_opt_for_test; + use crate::hummock::test_utils::{default_builder_opt_for_test, test_key_of}; use crate::hummock::{SstableBuilderOptions, DEFAULT_RESTART_INTERVAL}; #[tokio::test] @@ -294,7 +294,7 @@ mod tests { for i in 0..table_capacity { builder .add_user_key( - b"key".to_vec(), + test_key_of(i), HummockValue::put(b"value"), (table_capacity - i) as u64, ) @@ -320,7 +320,7 @@ mod tests { () => { epoch -= 1; builder - .add_user_key(b"k".to_vec(), HummockValue::put(b"v"), epoch) + .add_user_key(test_key_of(1), HummockValue::put(b"v"), epoch) .await .unwrap(); }; diff --git a/src/storage/src/hummock/sstable_store.rs b/src/storage/src/hummock/sstable_store.rs index 6ed9cd315468c..68edf25b1889a 100644 --- a/src/storage/src/hummock/sstable_store.rs +++ b/src/storage/src/hummock/sstable_store.rs @@ -355,7 +355,7 @@ impl SstableStore { Ok((Box::new(sst), charge)) } }) - .stack_trace("meta_cache_lookup") + .verbose_stack_trace("meta_cache_lookup") .await .map_err(|e| { HummockError::other(format!( diff --git a/src/storage/src/hummock/state_store.rs b/src/storage/src/hummock/state_store.rs index 38f4d6fba54db..060952ad93af4 100644 --- 
a/src/storage/src/hummock/state_store.rs +++ b/src/storage/src/hummock/state_store.rs @@ -16,7 +16,9 @@ use std::cmp::Ordering; use std::future::Future; use std::ops::Bound::{Excluded, Included}; use std::ops::RangeBounds; +use std::sync::atomic::Ordering as MemOrdering; use std::sync::Arc; +use std::time::Duration; use bytes::Bytes; use itertools::Itertools; @@ -38,7 +40,7 @@ use super::{ BackwardSstableIterator, HummockStorage, HummockStorageIterator, SstableIterator, SstableIteratorType, }; -use crate::error::StorageResult; +use crate::error::{StorageError, StorageResult}; use crate::hummock::iterator::{ Backward, BackwardUserIteratorType, DirectedUserIteratorBuilder, DirectionEnum, Forward, ForwardUserIteratorType, HummockIteratorDirection, @@ -48,7 +50,7 @@ use crate::hummock::shared_buffer::build_ordered_merge_iter; use crate::hummock::sstable::SstableIteratorReadOptions; use crate::hummock::store::{ReadOptions as ReadOptionsV2, StateStore as StateStoreV2}; use crate::hummock::utils::prune_ssts; -use crate::hummock::HummockResult; +use crate::hummock::{HummockEpoch, HummockError, HummockResult}; use crate::monitor::{StateStoreMetrics, StoreLocalStatistic}; use crate::storage_value::StorageValue; use crate::store::*; @@ -599,8 +601,71 @@ impl StateStore for HummockStorage { } } - fn try_wait_epoch(&self, epoch: HummockReadEpoch) -> Self::WaitEpochFuture<'_> { - async move { Ok(self.local_version_manager.try_wait_epoch(epoch).await?) } + /// Waits until the local hummock version contains the epoch. If `wait_epoch` is `Current`, + /// we will only check whether it is le `sealed_epoch` and won't wait. + fn try_wait_epoch(&self, wait_epoch: HummockReadEpoch) -> Self::WaitEpochFuture<'_> { + async move { + // Ok(self.local_version_manager.try_wait_epoch(epoch).await?) + let wait_epoch = match wait_epoch { + HummockReadEpoch::Committed(epoch) => epoch, + HummockReadEpoch::Current(epoch) => { + // let sealed_epoch = self.local_version.read().get_sealed_epoch(); + let sealed_epoch = (*self.seal_epoch).load(MemOrdering::SeqCst); + assert!( + epoch <= sealed_epoch + && epoch != HummockEpoch::MAX + , + "current epoch can't read, because the epoch in storage is not updated, epoch{}, sealed epoch{}" + ,epoch + ,sealed_epoch + ); + return Ok(()); + } + HummockReadEpoch::NoWait(_) => return Ok(()), + }; + if wait_epoch == HummockEpoch::MAX { + panic!("epoch should not be u64::MAX"); + } + + let mut receiver = self.version_update_notifier_tx.subscribe(); + // avoid unnecessary check in the loop if the value does not change + let max_committed_epoch = *receiver.borrow_and_update(); + if max_committed_epoch >= wait_epoch { + return Ok(()); + } + loop { + match tokio::time::timeout(Duration::from_secs(30), receiver.changed()).await { + Err(elapsed) => { + // The reason that we need to retry here is batch scan in + // chain/rearrange_chain is waiting for an + // uncommitted epoch carried by the CreateMV barrier, which + // can take unbounded time to become committed and propagate + // to the CN. We should consider removing the retry as well as wait_epoch + // for chain/rearrange_chain if we enforce + // chain/rearrange_chain to be scheduled on the same + // CN with the same distribution as the upstream MV. + // See #3845 for more details. 
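For orientation: this waiting side pairs with the `send_if_modified` call added to
`HummockEventHandler::handle_version_update` earlier in this diff. A minimal sketch of
the underlying pattern with a bare `tokio::sync::watch` channel (scaffolding only; apart
from the tokio API, the names are illustrative):

    use tokio::sync::watch;

    async fn wait_for_epoch(mut rx: watch::Receiver<u64>, wait_epoch: u64) {
        // Check once up front, marking the current value as seen.
        if *rx.borrow_and_update() >= wait_epoch {
            return;
        }
        // Woken only when the sender's `send_if_modified` closure returns true,
        // i.e. when the max committed epoch actually advances.
        while rx.changed().await.is_ok() {
            if *rx.borrow() >= wait_epoch {
                return;
            }
        }
    }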
+ tracing::warn!( + "wait_epoch {:?} timeout when waiting for version update elapsed {:?}s", + wait_epoch, + elapsed + ); + continue; + } + Ok(Err(_)) => { + return StorageResult::Err(StorageError::Hummock( + HummockError::wait_epoch("tx dropped"), + )); + } + Ok(Ok(_)) => { + let max_committed_epoch = *receiver.borrow(); + if max_committed_epoch >= wait_epoch { + return Ok(()); + } + } + } + } + } } fn sync(&self, epoch: u64) -> Self::SyncFuture<'_> { diff --git a/src/storage/src/hummock/test_utils.rs b/src/storage/src/hummock/test_utils.rs index 58415af9f5f66..275f59871a9e8 100644 --- a/src/storage/src/hummock/test_utils.rs +++ b/src/storage/src/hummock/test_utils.rs @@ -14,7 +14,7 @@ use std::sync::Arc; -use bytes::Bytes; +use bytes::{BufMut, Bytes}; use itertools::Itertools; use risingwave_common::config::StorageConfig; use risingwave_hummock_sdk::key::key_with_epoch; @@ -219,9 +219,21 @@ pub async fn gen_test_sstable( gen_test_sstable_inner(opts, sst_id, kv_iter, sstable_store, CachePolicy::NotFill).await } -/// The key (with epoch 0) of an index in the test table +/// Prefix the `key` with a dummy table id. +/// We use `0` because: +/// - This value is used in the code to identify unit tests and prevent some parameters that are not +/// easily constructible in tests from breaking the test. +/// - When calling state store interfaces, we normally pass `TableId::default()`, which is `0`. +pub fn prefixed_key>(key: T) -> Bytes { + let mut buf = Vec::new(); + buf.put_u32(0); + buf.put_slice(key.as_ref()); + buf.into() +} + +/// The key (with epoch 0 and table id 0) of an index in the test table pub fn test_key_of(idx: usize) -> Vec { - let user_key = format!("key_test_{:05}", idx * 2).as_bytes().to_vec(); + let user_key = prefixed_key(&format!("key_test_{:05}", idx * 2).as_bytes()).to_vec(); key_with_epoch(user_key, 233) } diff --git a/src/storage/src/hummock/utils.rs b/src/storage/src/hummock/utils.rs index 20f81d6f54813..575e28fc9deb1 100644 --- a/src/storage/src/hummock/utils.rs +++ b/src/storage/src/hummock/utils.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use risingwave_common::catalog::TableId; use risingwave_hummock_sdk::key::user_key; -use risingwave_pb::hummock::{Level, SstableInfo}; +use risingwave_pb::hummock::{HummockVersion, SstableInfo}; use tokio::sync::Notify; use super::{HummockError, HummockResult}; @@ -62,18 +62,24 @@ pub fn validate_epoch(safe_epoch: u64, epoch: u64) -> HummockResult<()> { Ok(()) } -pub fn validate_table_key_range(levels: &[Level]) -> HummockResult<()> { - for l in levels { +pub fn validate_table_key_range(version: &HummockVersion) { + for l in version.levels.values().flat_map(|levels| { + levels + .l0 + .as_ref() + .unwrap() + .sub_levels + .iter() + .chain(levels.levels.iter()) + }) { for t in &l.table_infos { - if t.key_range.is_none() { - return Err(HummockError::meta_error(format!( - "key_range in table [{}] is none", - t.id - ))); - } + assert!( + t.key_range.is_some(), + "key_range in table [{}] is none", + t.id + ); } } - Ok(()) } pub fn filter_single_sst(info: &SstableInfo, table_id: TableId, key_range: &R) -> bool diff --git a/src/storage/src/monitor/monitored_store.rs b/src/storage/src/monitor/monitored_store.rs index fdd083fe182c2..55e995cfdf07f 100644 --- a/src/storage/src/monitor/monitored_store.rs +++ b/src/storage/src/monitor/monitored_store.rs @@ -59,7 +59,7 @@ where // wait for iterator creation (e.g. 
seek) let iter = iter - .stack_trace("store_create_iter") + .verbose_stack_trace("store_create_iter") .await .inspect_err(|e| error!("Failed in iter: {:?}", e))?; @@ -106,7 +106,7 @@ where let value = self .inner .get(key, check_bloom_filter, read_options) - .stack_trace("store_get") + .verbose_stack_trace("store_get") .await .inspect_err(|e| error!("Failed in get: {:?}", e))?; timer.observe_duration(); @@ -136,7 +136,7 @@ where let result = self .inner .scan(prefix_hint, key_range, limit, read_options) - .stack_trace("store_scan") + .verbose_stack_trace("store_scan") .await .inspect_err(|e| error!("Failed in scan: {:?}", e))?; timer.observe_duration(); @@ -164,7 +164,7 @@ where let result = self .inner .scan(None, key_range, limit, read_options) - .stack_trace("store_backward_scan") + .verbose_stack_trace("store_backward_scan") .await .inspect_err(|e| error!("Failed in backward_scan: {:?}", e))?; timer.observe_duration(); @@ -194,7 +194,7 @@ where let batch_size = self .inner .ingest_batch(kv_pairs, write_options) - .stack_trace("store_ingest_batch") + .verbose_stack_trace("store_ingest_batch") .await .inspect_err(|e| error!("Failed in ingest_batch: {:?}", e))?; timer.observe_duration(); @@ -239,7 +239,7 @@ where async move { self.inner .try_wait_epoch(epoch) - .stack_trace("store_wait_epoch") + .verbose_stack_trace("store_wait_epoch") .await .inspect_err(|e| error!("Failed in wait_epoch: {:?}", e)) } @@ -253,7 +253,7 @@ where let sync_result = self .inner .sync(epoch) - .stack_trace("store_await_sync") + .verbose_stack_trace("store_await_sync") .await .inspect_err(|e| error!("Failed in sync: {:?}", e))?; timer.observe_duration(); @@ -278,7 +278,7 @@ where async move { self.inner .clear_shared_buffer() - .stack_trace("store_clear_shared_buffer") + .verbose_stack_trace("store_clear_shared_buffer") .await .inspect_err(|e| error!("Failed in clear_shared_buffer: {:?}", e)) } diff --git a/src/storage/src/storage_failpoints/test_sstable.rs b/src/storage/src/storage_failpoints/test_sstable.rs index 6dc8d733f907a..fec7c6236709a 100644 --- a/src/storage/src/storage_failpoints/test_sstable.rs +++ b/src/storage/src/storage_failpoints/test_sstable.rs @@ -22,7 +22,7 @@ use crate::hummock::iterator::HummockIterator; use crate::hummock::sstable::SstableIteratorReadOptions; use crate::hummock::test_utils::{ default_builder_opt_for_test, default_writer_opt_for_test, gen_test_sstable, - gen_test_sstable_data, put_sst, test_key_of, test_value_of, TEST_KEYS_COUNT, + gen_test_sstable_data, prefixed_key, put_sst, test_key_of, test_value_of, TEST_KEYS_COUNT, }; use crate::hummock::value::HummockValue; use crate::hummock::{SstableIterator, SstableIteratorType}; @@ -63,7 +63,7 @@ async fn test_failpoints_table_read() { fail::cfg(mem_read_err_fp, "return").unwrap(); let seek_key = key_with_epoch( - format!("key_test_{:05}", 600 * 2 - 1).as_bytes().to_vec(), + prefixed_key(&format!("key_test_{:05}", 600 * 2 - 1).as_bytes()).to_vec(), 0, ); let result = sstable_iter.seek(&seek_key).await; diff --git a/src/storage/src/table/batch_table/storage_table.rs b/src/storage/src/table/batch_table/storage_table.rs index 4a9f3ea408623..8ccd7e9247b94 100644 --- a/src/storage/src/table/batch_table/storage_table.rs +++ b/src/storage/src/table/batch_table/storage_table.rs @@ -545,7 +545,7 @@ impl StorageTableIterInner { while let Some((raw_key, value)) = self .iter .next() - .stack_trace("storage_table_iter_next") + .verbose_stack_trace("storage_table_iter_next") .await? 
{ let (_, key) = parse_raw_key_to_vnode_and_key(&raw_key); diff --git a/src/storage/src/table/streaming_table/state_table.rs b/src/storage/src/table/streaming_table/state_table.rs index 89c788fd6b857..cb15f8eb6b67b 100644 --- a/src/storage/src/table/streaming_table/state_table.rs +++ b/src/storage/src/table/streaming_table/state_table.rs @@ -1048,7 +1048,7 @@ impl StorageIterInner { while let Some((key, value)) = self .iter .next() - .stack_trace("storage_table_iter_next") + .verbose_stack_trace("storage_table_iter_next") .await? { let row = self.deserializer.deserialize(value.as_ref())?; diff --git a/src/stream/src/common/builder.rs b/src/stream/src/common/builder.rs index 17580df9bbbd6..0d4c3ffd4a125 100644 --- a/src/stream/src/common/builder.rs +++ b/src/stream/src/common/builder.rs @@ -16,6 +16,8 @@ use itertools::Itertools; use risingwave_common::array::{ArrayBuilderImpl, ArrayResult, Op, Row, RowRef, StreamChunk}; use risingwave_common::types::DataType; +type IndexMappings = Vec<(usize, usize)>; + /// Build a array and it's corresponding operations. pub struct StreamChunkBuilder { /// operations in the data chunk to build @@ -27,19 +29,11 @@ pub struct StreamChunkBuilder { /// Data types of columns data_types: Vec, - /// The start position of the columns of the side - /// stream coming from. If the coming side is the - /// left, the `update_start_pos` should be 0. - /// If the coming side is the right, the `update_start_pos` - /// is the number of columns of the left side. - update_start_pos: usize, + /// The column index mapping from update side to output. + update_to_output: IndexMappings, - /// The start position of the columns of the opposite side - /// stream coming from. If the coming side is the - /// left, the `matched_start_pos` should be the number of columns of the left side. - /// If the coming side is the right, the `matched_start_pos` - /// should be 0. - matched_start_pos: usize, + /// The column index mapping from matched side to output. + matched_to_output: IndexMappings, /// Maximum capacity of column builder capacity: usize, @@ -58,8 +52,8 @@ impl StreamChunkBuilder { pub fn new( capacity: usize, data_types: &[DataType], - update_start_pos: usize, - matched_start_pos: usize, + update_to_output: IndexMappings, + matched_to_output: IndexMappings, ) -> ArrayResult { // Leave room for paired `UpdateDelete` and `UpdateInsert`. When there are `capacity - 1` // ops in current builder and the last op is `UpdateDelete`, we delay the chunk generation @@ -77,13 +71,34 @@ impl StreamChunkBuilder { ops, column_builders, data_types: data_types.to_owned(), - update_start_pos, - matched_start_pos, + update_to_output, + matched_to_output, capacity: reduced_capacity, size: 0, }) } + /// Get the mapping from left/right input indices to the output indices. 
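A worked example for the helper that follows, using hypothetical shapes (a 2-column left
input, a 2-column right input, and an output projection of `[3, 0, 2]`, i.e. right.1,
left.0, right.0):

    let (left_to_output, right_to_output) =
        StreamChunkBuilder::get_i2o_mapping([3usize, 0, 2].into_iter(), 2, 2);
    // left column 0 lands in output column 1
    assert_eq!(left_to_output, vec![(0, 1)]);
    // right columns 1 and 0 land in output columns 0 and 2, respectively
    assert_eq!(right_to_output, vec![(1, 0), (0, 2)]);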
+ pub fn get_i2o_mapping( + output_indices: impl Iterator, + left_len: usize, + right_len: usize, + ) -> (IndexMappings, IndexMappings) { + let mut left_to_output = vec![]; + let mut right_to_output = vec![]; + + for (output_idx, idx) in output_indices.enumerate() { + if idx < left_len { + left_to_output.push((idx, output_idx)) + } else if idx >= left_len && idx < left_len + right_len { + right_to_output.push((idx - left_len, output_idx)); + } else { + unreachable!("output_indices out of bound") + } + } + (left_to_output, right_to_output) + } + /// Increase chunk size /// /// A [`StreamChunk`] will be returned when `size == capacity` @@ -109,11 +124,11 @@ impl StreamChunkBuilder { row_matched: &Row, ) -> ArrayResult> { self.ops.push(op); - for (i, d) in row_update.values().enumerate() { - self.column_builders[i + self.update_start_pos].append_datum_ref(d); + for &(update_idx, output_idx) in &self.update_to_output { + self.column_builders[output_idx].append_datum_ref(row_update.value_at(update_idx)); } - for (i, d) in row_matched.values().enumerate() { - self.column_builders[i + self.matched_start_pos].append_datum(d); + for &(matched_idx, output_idx) in &self.matched_to_output { + self.column_builders[output_idx].append_datum(&row_matched[matched_idx]); } self.inc_size() @@ -128,11 +143,11 @@ impl StreamChunkBuilder { row_update: &RowRef<'_>, ) -> ArrayResult> { self.ops.push(op); - for (i, d) in row_update.values().enumerate() { - self.column_builders[i + self.update_start_pos].append_datum_ref(d); + for &(update_idx, output_idx) in &self.update_to_output { + self.column_builders[output_idx].append_datum_ref(row_update.value_at(update_idx)); } - for i in 0..self.column_builders.len() - row_update.size() { - self.column_builders[i + self.matched_start_pos].append_datum(&None); + for &(_matched_idx, output_idx) in &self.matched_to_output { + self.column_builders[output_idx].append_datum(&None); } self.inc_size() @@ -147,11 +162,11 @@ impl StreamChunkBuilder { row_matched: &Row, ) -> ArrayResult> { self.ops.push(op); - for i in 0..self.column_builders.len() - row_matched.size() { - self.column_builders[i + self.update_start_pos].append_datum_ref(None); + for &(_update_idx, output_idx) in &self.update_to_output { + self.column_builders[output_idx].append_datum_ref(None); } - for i in 0..row_matched.size() { - self.column_builders[i + self.matched_start_pos].append_datum(&row_matched[i]); + for &(matched_idx, output_idx) in &self.matched_to_output { + self.column_builders[output_idx].append_datum(&row_matched[matched_idx]); } self.inc_size() diff --git a/src/stream/src/executor/actor.rs b/src/stream/src/executor/actor.rs index 73ed27dd9d722..02410d0383213 100644 --- a/src/stream/src/executor/actor.rs +++ b/src/stream/src/executor/actor.rs @@ -31,7 +31,6 @@ use crate::error::StreamResult; use crate::task::{ActorId, SharedContext}; /// Shared by all operators of an actor. -#[derive(Default)] pub struct ActorContext { pub id: ActorId, @@ -45,7 +44,7 @@ impl ActorContext { pub fn create(id: ActorId) -> ActorContextRef { Arc::new(Self { id, - ..Default::default() + errors: Default::default(), }) } @@ -66,10 +65,9 @@ pub struct Actor { /// The subtasks to execute concurrently. 
subtasks: Vec, - id: ActorId, context: Arc, _metrics: Arc, - _actor_context: ActorContextRef, + actor_context: ActorContextRef, } impl Actor @@ -79,7 +77,6 @@ where pub fn new( consumer: C, subtasks: Vec, - id: ActorId, context: Arc, metrics: Arc, actor_context: ActorContextRef, @@ -87,10 +84,9 @@ where Self { consumer, subtasks, - id, context, _metrics: metrics, - _actor_context: actor_context, + actor_context, } } @@ -105,11 +101,13 @@ where } async fn run_consumer(self) -> StreamResult<()> { - let span_name = format!("actor_poll_{:03}", self.id); + let id = self.actor_context.id; + + let span_name = format!("actor_poll_{:03}", id); let mut span = { let mut span = Span::enter_with_local_parent("actor_poll"); span.add_property(|| ("otel.name", span_name.to_string())); - span.add_property(|| ("next", self.id.to_string())); + span.add_property(|| ("next", id.to_string())); span.add_property(|| ("next", "Outbound".to_string())); span.add_property(|| ("epoch", (-1).to_string())); span @@ -133,14 +131,12 @@ where last_epoch = Some(barrier.epoch); // Collect barriers to local barrier manager - self.context - .lock_barrier_manager() - .collect(self.id, &barrier)?; + self.context.lock_barrier_manager().collect(id, &barrier)?; // Then stop this actor if asked - let to_stop = barrier.is_stop_or_update_drop_actor(self.id); + let to_stop = barrier.is_stop_or_update_drop_actor(id); if to_stop { - tracing::trace!(actor_id = self.id, "actor exit"); + tracing::trace!(actor_id = id, "actor exit"); return Ok(()); } @@ -148,14 +144,14 @@ where span = { let mut span = Span::enter_with_local_parent("actor_poll"); span.add_property(|| ("otel.name", span_name.to_string())); - span.add_property(|| ("next", self.id.to_string())); + span.add_property(|| ("next", id.to_string())); span.add_property(|| ("next", "Outbound".to_string())); span.add_property(|| ("epoch", barrier.epoch.curr.to_string())); span }; } - tracing::error!(actor_id = self.id, "actor exit without stop barrier"); + tracing::error!(actor_id = id, "actor exit without stop barrier"); Ok(()) } diff --git a/src/stream/src/executor/aggregation/agg_group.rs b/src/stream/src/executor/aggregation/agg_group.rs index 544fd02fd711f..07ba88a894133 100644 --- a/src/stream/src/executor/aggregation/agg_group.rs +++ b/src/stream/src/executor/aggregation/agg_group.rs @@ -133,21 +133,19 @@ impl AggGroup { } /// Apply input chunk to all managed agg states. - pub async fn apply_chunk( + /// `visibilities` contains the row visibility of the input chunk for each agg call. + pub fn apply_chunk( &mut self, storages: &mut [AggStateStorage], ops: &[Op], columns: &[Column], - visibilities: &[Option], + visibilities: Vec>, ) -> StreamExecutorResult<()> { - // TODO(yuchao): may directly pass `&[Column]` to managed states. 
- let column_refs = columns.iter().map(|col| col.array_ref()).collect_vec(); + let columns = columns.iter().map(|col| col.array_ref()).collect_vec(); for ((state, storage), visibility) in self.states.iter_mut().zip_eq(storages).zip_eq(visibilities) { - state - .apply_chunk(ops, visibility.as_ref(), &column_refs, storage) - .await?; + state.apply_chunk(ops, visibility.as_ref(), &columns, storage)?; } Ok(()) } diff --git a/src/stream/src/executor/aggregation/agg_impl/approx_count_distinct.rs b/src/stream/src/executor/aggregation/agg_impl/approx_count_distinct.rs index ad6075e48ea42..1f557db92d067 100644 --- a/src/stream/src/executor/aggregation/agg_impl/approx_count_distinct.rs +++ b/src/stream/src/executor/aggregation/agg_impl/approx_count_distinct.rs @@ -12,30 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -//! This module implements `StreamingApproxCountDistinct`. +//! This module implements `UpdatableStreamingApproxCountDistinct`. -use std::collections::hash_map::DefaultHasher; -use std::hash::{Hash, Hasher}; - -use itertools::Itertools; -use risingwave_common::array::stream_chunk::Ops; -use risingwave_common::array::*; use risingwave_common::bail; -use risingwave_common::buffer::Bitmap; -use risingwave_common::types::{Datum, DatumRef, Scalar, ScalarImpl}; -use super::StreamingAggImpl; +use super::approx_distinct_utils::{RegisterBucket, StreamingApproxCountDistinct}; use crate::executor::error::StreamExecutorResult; -const INDEX_BITS: u8 = 16; // number of bits used for finding the index of each 64-bit hash -const NUM_OF_REGISTERS: usize = 1 << INDEX_BITS; // number of registers available -const COUNT_BITS: u8 = 64 - INDEX_BITS; // number of non-index bits in each 64-bit hash - -// Approximation for bias correction for 16384 registers. See "HyperLogLog: the analysis of a -// near-optimal cardinality estimation algorithm" by Philippe Flajolet et al. -const BIAS_CORRECTION: f64 = 0.7213 / (1. + (1.079 / NUM_OF_REGISTERS as f64)); - -pub(crate) const DENSE_BITS_DEFAULT: usize = 16; // number of bits in the dense repr of the `RegisterBucket` +pub(crate) const DENSE_BITS_DEFAULT: usize = 16; // number of bits in the dense repr of the `UpdatableRegisterBucket` #[derive(Clone, Debug)] struct SparseCount { @@ -105,19 +89,12 @@ impl SparseCount { } #[derive(Clone, Debug)] -struct RegisterBucket { +pub(super) struct UpdatableRegisterBucket { dense_counts: [u64; DENSE_BITS], sparse_counts: SparseCount, } -impl RegisterBucket { - pub fn new() -> Self { - Self { - dense_counts: [0u64; DENSE_BITS], - sparse_counts: SparseCount::new(), - } - } - +impl UpdatableRegisterBucket { fn get_bucket(&self, index: usize) -> StreamExecutorResult { if index > 64 || index == 0 { bail!("HyperLogLog: Invalid bucket index"); @@ -129,9 +106,16 @@ impl RegisterBucket { Ok(self.dense_counts[index - 1]) } } +} + +impl RegisterBucket for UpdatableRegisterBucket { + fn new() -> Self { + Self { + dense_counts: [0u64; DENSE_BITS], + sparse_counts: SparseCount::new(), + } + } - /// Increments or decrements the bucket at `index` depending on the state of `is_insert`. - /// Returns an Error if `index` is invalid or if inserting will cause an overflow in the bucket. fn update_bucket(&mut self, index: usize, is_insert: bool) -> StreamExecutorResult<()> { if index > 64 || index == 0 { bail!("HyperLogLog: Invalid bucket index"); @@ -165,180 +149,63 @@ impl RegisterBucket { Ok(()) } - /// Gets the number of the maximum bucket which has a count greater than zero. 
- fn get_max(&self) -> StreamExecutorResult { + fn get_max(&self) -> u8 { if !self.sparse_counts.is_empty() { - return Ok(self.sparse_counts.last_key()); + return self.sparse_counts.last_key(); } for i in (0..DENSE_BITS).rev() { if self.dense_counts[i] > 0 { - return Ok(i as u8 + 1); + return i as u8 + 1; } } - Ok(0) + 0 } } -/// `StreamingApproxCountDistinct` approximates the count of non-null rows using a modified version -/// of the `HyperLogLog` algorithm. Each `RegisterBucket` stores a count of how many hash values -/// have x trailing zeroes for all x from 1-64. This allows the algorithm to support insertion and -/// deletion, but uses up more memory and limits the number of rows that can be counted. -/// -/// `StreamingApproxCountDistinct` can count up to a total of 2^64 unduplicated rows. -/// -/// The estimation error for `HyperLogLog` is 1.04/sqrt(num of registers). With 2^16 registers this -/// is ~1/256, or about 0.4%. The memory usage for the default choice of parameters is about -/// (1024 + 24) bits * 2^16 buckets, which is about 8.58 MB. #[derive(Clone, Debug, Default)] -pub struct StreamingApproxCountDistinct { +pub struct UpdatableStreamingApproxCountDistinct { // TODO(yuchao): The state may need to be stored in state table to allow correct recovery. - registers: Vec>, + registers: Vec>, initial_count: i64, } -impl StreamingApproxCountDistinct { - pub fn new() -> Self { - StreamingApproxCountDistinct::with_datum(None) - } - - pub fn with_datum(datum: Datum) -> Self { - let count = if let Some(c) = datum { - match c { - ScalarImpl::Int64(num) => num, - other => panic!( - "type mismatch in streaming aggregator StreamingApproxCountDistinct init: expected i64, get {}", - other.get_ident() - ), - } - } else { - 0 - }; +impl StreamingApproxCountDistinct + for UpdatableStreamingApproxCountDistinct +{ + type Bucket = UpdatableRegisterBucket; + fn with_i64(registers_num: u32, initial_count: i64) -> Self { Self { - registers: vec![RegisterBucket::new(); NUM_OF_REGISTERS], - initial_count: count, - } - } - - /// Adds the count of the datum's hash into the register, if it is greater than the existing - /// count at the register. - fn update_registers( - &mut self, - datum_ref: DatumRef<'_>, - is_insert: bool, - ) -> StreamExecutorResult<()> { - if datum_ref.is_none() { - return Ok(()); + registers: vec![UpdatableRegisterBucket::new(); registers_num as usize], + initial_count, } - - let scalar_impl = datum_ref.unwrap().into_scalar_impl(); - let hash = self.get_hash(scalar_impl); - - let index = (hash as usize) & (NUM_OF_REGISTERS - 1); // Index is based on last few bits - let count = self.count_hash(hash) as usize; - - self.registers[index].update_bucket(count, is_insert)?; - - Ok(()) } - /// Calculate the hash of the `scalar_impl`. - fn get_hash(&self, scalar_impl: ScalarImpl) -> u64 { - let mut hasher = DefaultHasher::new(); - scalar_impl.hash(&mut hasher); - hasher.finish() - } - - /// Counts the number of trailing zeroes plus 1 in the non-index bits of the hash. 
- fn count_hash(&self, mut hash: u64) -> u8 { - hash >>= INDEX_BITS; // Ignore bits used as index for the hash - hash |= 1 << COUNT_BITS; // To allow hash to terminate if it is all 0s - - (hash.trailing_zeros() + 1) as u8 - } -} - -impl StreamingAggImpl for StreamingApproxCountDistinct { - fn apply_batch( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - data: &[&ArrayImpl], - ) -> StreamExecutorResult<()> { - match visibility { - None => { - for (op, datum) in ops.iter().zip_eq(data[0].iter()) { - match op { - Op::Insert | Op::UpdateInsert => self.update_registers(datum, true)?, - Op::Delete | Op::UpdateDelete => self.update_registers(datum, false)?, - } - } - } - Some(visibility) => { - for ((visible, op), datum) in - visibility.iter().zip_eq(ops.iter()).zip_eq(data[0].iter()) - { - if visible { - match op { - Op::Insert | Op::UpdateInsert => self.update_registers(datum, true)?, - Op::Delete | Op::UpdateDelete => self.update_registers(datum, false)?, - } - } - } - } - } - Ok(()) + fn get_initial_count(&self) -> i64 { + self.initial_count } - fn get_output(&self) -> StreamExecutorResult { - let m = NUM_OF_REGISTERS as f64; - let mut mean = 0.0; - - // Get harmonic mean of all the counts in results - for register_bucket in &self.registers { - let count = register_bucket.get_max()?; - mean += 1.0 / ((1 << count) as f64); - } - - let raw_estimate = BIAS_CORRECTION * m * m / mean; - - // If raw_estimate is not much bigger than m and some registers have value 0, set answer to - // m * log(m/V) where V is the number of registers with value 0 - let answer = if raw_estimate <= 2.5 * m { - let mut zero_registers: f64 = 0.0; - for i in &self.registers { - if i.get_max()? == 0 { - zero_registers += 1.0; - } - } - - if zero_registers == 0.0 { - raw_estimate - } else { - m * (m.log2() - (zero_registers.log2())) - } - } else { - raw_estimate - }; - - Ok(Some((answer as i64 + self.initial_count).to_scalar_value())) + fn reset_buckets(&mut self, registers_num: u32) { + self.registers = vec![UpdatableRegisterBucket::new(); registers_num as usize]; } - fn new_builder(&self) -> ArrayBuilderImpl { - ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)) + fn registers(&self) -> &[UpdatableRegisterBucket] { + &self.registers } - fn reset(&mut self) { - self.registers = vec![RegisterBucket::new(); NUM_OF_REGISTERS]; + fn registers_mut(&mut self) -> &mut [UpdatableRegisterBucket] { + &mut self.registers } } #[cfg(test)] mod tests { use assert_matches::assert_matches; + use risingwave_common::array::*; use risingwave_common::array_nonnull; use super::*; + use crate::executor::aggregation::agg_impl::StreamingAggImpl; #[test] fn test_streaming_approx_count_distinct_insert_and_delete() { @@ -349,7 +216,7 @@ mod tests { } fn test_streaming_approx_count_distinct_insert_and_delete_inner() { - let mut agg = StreamingApproxCountDistinct::::new(); + let mut agg = UpdatableStreamingApproxCountDistinct::::with_no_initial(); assert_eq!(agg.get_output().unwrap().unwrap().as_int64(), &0); agg.apply_batch( @@ -392,7 +259,7 @@ mod tests { /// error. 
#[test] fn test_error_ratio() { - let mut agg = StreamingApproxCountDistinct::<16>::new(); + let mut agg = UpdatableStreamingApproxCountDistinct::<16>::with_no_initial(); assert_eq!(agg.get_output().unwrap().unwrap().as_int64(), &0); let actual_ndv = 1000000; for i in 0..1000000 { @@ -410,7 +277,7 @@ mod tests { } fn test_register_bucket_get_and_update_inner() { - let mut rb = RegisterBucket::::new(); + let mut rb = UpdatableRegisterBucket::::new(); for i in 0..20 { rb.update_bucket(i % 2 + 1, true).unwrap(); @@ -428,7 +295,7 @@ mod tests { #[test] fn test_register_bucket_invalid_register() { - let mut rb = RegisterBucket::<0>::new(); + let mut rb = UpdatableRegisterBucket::<0>::new(); assert_matches!(rb.get_bucket(0), Err(_)); assert_matches!(rb.get_bucket(65), Err(_)); diff --git a/src/stream/src/executor/aggregation/agg_impl/approx_distinct_append.rs b/src/stream/src/executor/aggregation/agg_impl/approx_distinct_append.rs new file mode 100644 index 0000000000000..9c247288201a7 --- /dev/null +++ b/src/stream/src/executor/aggregation/agg_impl/approx_distinct_append.rs @@ -0,0 +1,83 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::bail; + +use super::approx_distinct_utils::{RegisterBucket, StreamingApproxCountDistinct}; +use crate::executor::StreamExecutorResult; + +#[derive(Clone, Debug)] +pub(super) struct AppendOnlyRegisterBucket { + max: u8, +} + +impl RegisterBucket for AppendOnlyRegisterBucket { + fn new() -> Self { + Self { max: 0 } + } + + fn update_bucket(&mut self, index: usize, is_insert: bool) -> StreamExecutorResult<()> { + if index > 64 || index == 0 { + bail!("HyperLogLog: Invalid bucket index"); + } + + if !is_insert { + bail!("HyperLogLog: Deletion in append-only bucket"); + } + + if index as u8 > self.max { + self.max = index as u8; + } + + Ok(()) + } + + fn get_max(&self) -> u8 { + self.max + } +} + +#[derive(Clone, Debug, Default)] +pub struct AppendOnlyStreamingApproxCountDistinct { + // TODO(yuchao): The state may need to be stored in state table to allow correct recovery. 
+ registers: Vec, + initial_count: i64, +} + +impl StreamingApproxCountDistinct for AppendOnlyStreamingApproxCountDistinct { + type Bucket = AppendOnlyRegisterBucket; + + fn with_i64(registers_num: u32, initial_count: i64) -> Self { + Self { + registers: vec![AppendOnlyRegisterBucket::new(); registers_num as usize], + initial_count, + } + } + + fn get_initial_count(&self) -> i64 { + self.initial_count + } + + fn reset_buckets(&mut self, registers_num: u32) { + self.registers = vec![AppendOnlyRegisterBucket::new(); registers_num as usize]; + } + + fn registers(&self) -> &[AppendOnlyRegisterBucket] { + &self.registers + } + + fn registers_mut(&mut self) -> &mut [AppendOnlyRegisterBucket] { + &mut self.registers + } +} diff --git a/src/stream/src/executor/aggregation/agg_impl/approx_distinct_utils.rs b/src/stream/src/executor/aggregation/agg_impl/approx_distinct_utils.rs new file mode 100644 index 0000000000000..a70a0990dba9d --- /dev/null +++ b/src/stream/src/executor/aggregation/agg_impl/approx_distinct_utils.rs @@ -0,0 +1,221 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::hash_map::DefaultHasher; +use std::hash::{Hash, Hasher}; + +use dyn_clone::DynClone; +use itertools::Itertools; +use risingwave_common::array::stream_chunk::Ops; +use risingwave_common::array::*; +use risingwave_common::buffer::Bitmap; +use risingwave_common::types::{Datum, DatumRef, Scalar, ScalarImpl}; + +use crate::executor::aggregation::agg_impl::StreamingAggImpl; +use crate::executor::StreamExecutorResult; + +const INDEX_BITS: u8 = 16; // number of bits used for finding the index of each 64-bit hash +const NUM_OF_REGISTERS: u32 = 1 << INDEX_BITS; // number of registers available +const COUNT_BITS: u8 = 64 - INDEX_BITS; // number of non-index bits in each 64-bit hash + +// Approximation for bias correction for 16384 registers. See "HyperLogLog: the analysis of a +// near-optimal cardinality estimation algorithm" by Philippe Flajolet et al. +const BIAS_CORRECTION: f64 = 0.7213 / (1. + (1.079 / NUM_OF_REGISTERS as f64)); + +pub(super) trait RegisterBucket { + fn new() -> Self; + + /// Increments or decrements the bucket at `index` depending on the state of `is_insert`. + /// Returns an Error if `index` is invalid or if inserting will cause an overflow in the bucket. + fn update_bucket(&mut self, index: usize, is_insert: bool) -> StreamExecutorResult<()>; + + /// Gets the number of the maximum bucket which has a count greater than zero. + fn get_max(&self) -> u8; +} + +/// `StreamingApproxCountDistinct` approximates the count of non-null rows using a modified version +/// of the `HyperLogLog` algorithm. Each `RegisterBucket` stores a count of how many hash values +/// have x trailing zeroes for all x from 1-64. This allows the algorithm to support insertion and +/// deletion, but uses up more memory and limits the number of rows that can be counted. +/// +/// `StreamingApproxCountDistinct` can count up to a total of 2^64 unduplicated rows. 
+/// +/// The estimation error for `HyperLogLog` is 1.04/sqrt(num of registers). With 2^16 registers this +/// is ~1/256, or about 0.4%. The memory usage for the default choice of parameters is about +/// (1024 + 24) bits * 2^16 buckets, which is about 8.58 MB. +pub(super) trait StreamingApproxCountDistinct: Sized { + type Bucket: RegisterBucket; + + fn with_no_initial() -> Self { + Self::with_datum(None) + } + + fn with_datum(datum: Datum) -> Self { + let count = if let Some(c) = datum { + match c { + ScalarImpl::Int64(num) => num, + other => panic!( + "type mismatch in streaming aggregator StreamingApproxCountDistinct init: expected i64, get {}", + other.get_ident() + ), + } + } else { + 0 + }; + + Self::with_i64(NUM_OF_REGISTERS, count) + } + + fn with_i64(registers_num: u32, initial_count: i64) -> Self; + fn get_initial_count(&self) -> i64; + fn reset_buckets(&mut self, registers_num: u32); + fn registers(&self) -> &[Self::Bucket]; + fn registers_mut(&mut self) -> &mut [Self::Bucket]; + + /// Adds the count of the datum's hash into the register, if it is greater than the existing + /// count at the register. + fn update_registers( + &mut self, + datum_ref: DatumRef<'_>, + is_insert: bool, + ) -> StreamExecutorResult<()> { + if datum_ref.is_none() { + return Ok(()); + } + + let scalar_impl = datum_ref.unwrap().into_scalar_impl(); + let hash = self.get_hash(scalar_impl); + + let index = (hash as u32) & (NUM_OF_REGISTERS - 1); // Index is based on last few bits + let count = self.count_hash(hash) as usize; + + self.registers_mut()[index as usize].update_bucket(count, is_insert)?; + + Ok(()) + } + + /// Calculate the hash of the `scalar_impl`. + fn get_hash(&self, scalar_impl: ScalarImpl) -> u64 { + let mut hasher = DefaultHasher::new(); + scalar_impl.hash(&mut hasher); + hasher.finish() + } + + /// Counts the number of trailing zeroes plus 1 in the non-index bits of the hash. 
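+    /// For example (illustrative, not in the original patch): with `INDEX_BITS = 16`, a
+    /// hash whose bits 16..=18 are zero and whose bit 19 is one yields 4, i.e. three
+    /// trailing zeroes in the non-index bits plus one.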
+ fn count_hash(&self, mut hash: u64) -> u8 { + hash >>= INDEX_BITS; // Ignore bits used as index for the hash + hash |= 1 << COUNT_BITS; // To allow hash to terminate if it is all 0s + + (hash.trailing_zeros() + 1) as u8 + } + + fn apply_batch_inner( + &mut self, + ops: Ops<'_>, + visibility: Option<&Bitmap>, + data: &[&ArrayImpl], + ) -> StreamExecutorResult<()> { + match visibility { + None => { + for (op, datum) in ops.iter().zip_eq(data[0].iter()) { + match op { + Op::Insert | Op::UpdateInsert => self.update_registers(datum, true)?, + Op::Delete | Op::UpdateDelete => self.update_registers(datum, false)?, + } + } + } + Some(visibility) => { + for ((visible, op), datum) in + visibility.iter().zip_eq(ops.iter()).zip_eq(data[0].iter()) + { + if visible { + match op { + Op::Insert | Op::UpdateInsert => self.update_registers(datum, true)?, + Op::Delete | Op::UpdateDelete => self.update_registers(datum, false)?, + } + } + } + } + } + Ok(()) + } + + fn get_output_inner(&self) -> StreamExecutorResult { + let m = NUM_OF_REGISTERS as f64; + let mut mean = 0.0; + + // Get harmonic mean of all the counts in results + for register_bucket in self.registers() { + let count = register_bucket.get_max(); + mean += 1.0 / ((1 << count) as f64); + } + + let raw_estimate = BIAS_CORRECTION * m * m / mean; + + // If raw_estimate is not much bigger than m and some registers have value 0, set answer to + // m * log(m/V) where V is the number of registers with value 0 + let answer = if raw_estimate <= 2.5 * m { + let mut zero_registers: f64 = 0.0; + for i in self.registers() { + if i.get_max() == 0 { + zero_registers += 1.0; + } + } + + if zero_registers == 0.0 { + raw_estimate + } else { + m * (m.log2() - (zero_registers.log2())) + } + } else { + raw_estimate + }; + + Ok(Some( + (answer as i64 + self.get_initial_count()).to_scalar_value(), + )) + } +} + +impl StreamingAggImpl for T +where + B: RegisterBucket, + T: std::fmt::Debug + + DynClone + + Send + + Sync + + 'static + + StreamingApproxCountDistinct, +{ + fn apply_batch( + &mut self, + ops: Ops<'_>, + visibility: Option<&Bitmap>, + data: &[&ArrayImpl], + ) -> StreamExecutorResult<()> { + self.apply_batch_inner(ops, visibility, data) + } + + fn get_output(&self) -> StreamExecutorResult { + self.get_output_inner() + } + + fn new_builder(&self) -> ArrayBuilderImpl { + ArrayBuilderImpl::Int64(I64ArrayBuilder::new(0)) + } + + fn reset(&mut self) { + self.reset_buckets(NUM_OF_REGISTERS); + } +} diff --git a/src/stream/src/executor/aggregation/agg_impl/mod.rs b/src/stream/src/executor/aggregation/agg_impl/mod.rs index 6fa025be61569..a517ee907538b 100644 --- a/src/stream/src/executor/aggregation/agg_impl/mod.rs +++ b/src/stream/src/executor/aggregation/agg_impl/mod.rs @@ -18,6 +18,7 @@ use std::any::Any; pub use approx_count_distinct::*; +use approx_distinct_utils::StreamingApproxCountDistinct; use dyn_clone::DynClone; pub use foldable::*; use risingwave_common::array::stream_chunk::Ops; @@ -35,6 +36,8 @@ pub use row_count::*; use crate::executor::{StreamExecutorError, StreamExecutorResult}; mod approx_count_distinct; +mod approx_distinct_append; +mod approx_distinct_utils; mod foldable; mod row_count; @@ -126,10 +129,10 @@ pub fn create_streaming_agg_impl( } )* (AggKind::ApproxCountDistinct, _, DataType::Int64, Some(datum)) => { - Box::new(StreamingApproxCountDistinct::<{approx_count_distinct::DENSE_BITS_DEFAULT}>::with_datum(datum)) + Box::new(UpdatableStreamingApproxCountDistinct::<{approx_count_distinct::DENSE_BITS_DEFAULT}>::with_datum(datum)) } 
(AggKind::ApproxCountDistinct, _, DataType::Int64, None) => { - Box::new(StreamingApproxCountDistinct::<{approx_count_distinct::DENSE_BITS_DEFAULT}>::new()) + Box::new(UpdatableStreamingApproxCountDistinct::<{approx_count_distinct::DENSE_BITS_DEFAULT}>::with_no_initial()) } (other_agg, other_input, other_return, _) => panic!( "streaming agg state not implemented: {:?} {:?} {:?}", diff --git a/src/stream/src/executor/aggregation/agg_state.rs b/src/stream/src/executor/aggregation/agg_state.rs index 87c209fd8fed5..0e7b2bc687407 100644 --- a/src/stream/src/executor/aggregation/agg_state.rs +++ b/src/stream/src/executor/aggregation/agg_state.rs @@ -84,7 +84,7 @@ impl AggState { agg_call, group_key, pk_indices, - mapping.clone(), + mapping, row_count, extreme_cache_size, input_schema, @@ -94,7 +94,7 @@ impl AggState { } /// Apply input chunk to the state. - pub async fn apply_chunk( + pub fn apply_chunk( &mut self, ops: Ops<'_>, visibility: Option<&Bitmap>, @@ -108,11 +108,8 @@ impl AggState { state.apply_chunk(ops, visibility, columns) } Self::MaterializedInput(state) => { - let state_table = - must_match!(storage, AggStateStorage::MaterializedInput { table, .. } => table); - state - .apply_chunk(ops, visibility, columns, state_table) - .await + debug_assert!(matches!(storage, AggStateStorage::MaterializedInput { .. })); + state.apply_chunk(ops, visibility, columns) } } } diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index 95fd5c54bc2f3..74a3597bc00cf 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -12,20 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::marker::PhantomData; + +use futures::{pin_mut, StreamExt}; +use futures_async_stream::for_await; +use itertools::Itertools; use risingwave_common::array::stream_chunk::Ops; -use risingwave_common::array::{ArrayImpl, Row}; +use risingwave_common::array::{ArrayImpl, Op, Row}; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::Schema; -use risingwave_common::types::Datum; +use risingwave_common::types::{Datum, DatumRef, ScalarImpl}; +use risingwave_common::util::ordered::OrderedRowSerde; +use risingwave_common::util::sort_util::OrderType; use risingwave_expr::expr::AggKind; use risingwave_storage::table::streaming_table::state_table::StateTable; use risingwave_storage::StateStore; +use smallvec::SmallVec; -use super::table_state::{ - GenericExtremeState, ManagedArrayAggState, ManagedStringAggState, ManagedTableState, -}; +use super::state_cache::array_agg::ArrayAgg; +use super::state_cache::extreme::ExtremeAgg; +use super::state_cache::string_agg::StringAgg; +use super::state_cache::{CacheKey, GenericStateCache, StateCache}; use super::AggCall; -use crate::common::StateTableColumnMapping; +use crate::common::{iter_state_table, StateTableColumnMapping}; use crate::executor::{PkIndices, StreamExecutorResult}; /// Aggregation state as a materialization of input chunks. @@ -34,7 +43,29 @@ use crate::executor::{PkIndices, StreamExecutorResult}; /// stored in the state table when applying chunks, and the aggregation result is calculated /// when need to get output. pub struct MaterializedInputState { - inner: Box>, + /// Group key to aggregate with group. + /// None for simple agg, Some for group key of hash agg. + group_key: Option, + + /// Argument column indices in input chunks. 
+ arg_col_indices: Vec, + + /// Argument column indices in state table. + state_table_arg_col_indices: Vec, + + /// The columns to order by in input chunks. + order_col_indices: Vec, + + /// The columns to order by in state table. + state_table_order_col_indices: Vec, + + /// Cache of state table. + cache: Box, + + /// Serializer for cache key. + cache_key_serializer: OrderedRowSerde, + + _phantom_data: PhantomData, } impl MaterializedInputState { @@ -43,61 +74,1182 @@ impl MaterializedInputState { agg_call: &AggCall, group_key: Option<&Row>, pk_indices: &PkIndices, - col_mapping: StateTableColumnMapping, + col_mapping: &StateTableColumnMapping, row_count: usize, extreme_cache_size: usize, input_schema: &Schema, ) -> Self { - Self { - inner: match agg_call.kind { - AggKind::Max | AggKind::Min | AggKind::FirstValue => { - Box::new(GenericExtremeState::new( - agg_call, - group_key, - pk_indices, - col_mapping, - row_count, - extreme_cache_size, - input_schema, - )) - } - AggKind::StringAgg => Box::new(ManagedStringAggState::new( - agg_call, - group_key, - pk_indices, - col_mapping, - row_count, - )), - AggKind::ArrayAgg => Box::new(ManagedArrayAggState::new( - agg_call, - group_key, - pk_indices, - col_mapping, - row_count, - )), + let arg_col_indices = agg_call.args.val_indices().to_vec(); + let (mut order_col_indices, mut order_types) = + if matches!(agg_call.kind, AggKind::Min | AggKind::Max) { + // `min`/`max` need not to order by any other columns, but have to + // order by the agg value implicitly. + let order_type = if agg_call.kind == AggKind::Min { + OrderType::Ascending + } else { + OrderType::Descending + }; + (vec![arg_col_indices[0]], vec![order_type]) + } else { + agg_call + .order_pairs + .iter() + .map(|p| (p.column_idx, p.order_type)) + .unzip() + }; + + let pk_len = pk_indices.len(); + order_col_indices.extend(pk_indices.iter()); + order_types.extend(itertools::repeat_n(OrderType::Ascending, pk_len)); + + // map argument columns to state table column indices + let state_table_arg_col_indices = arg_col_indices + .iter() + .map(|i| { + col_mapping + .upstream_to_state_table(*i) + .expect("the argument columns must appear in the state table") + }) + .collect_vec(); + + // map order by columns to state table column indices + let state_table_order_col_indices = order_col_indices + .iter() + .map(|i| { + col_mapping + .upstream_to_state_table(*i) + .expect("the order columns must appear in the state table") + }) + .collect_vec(); + + let cache_key_data_types = order_col_indices + .iter() + .map(|i| input_schema[*i].data_type()) + .collect_vec(); + let cache_key_serializer = OrderedRowSerde::new(cache_key_data_types, order_types); + + let cache_capacity = if matches!(agg_call.kind, AggKind::Min | AggKind::Max) { + extreme_cache_size + } else { + usize::MAX + }; + + let cache: Box = + match agg_call.kind { + AggKind::Min | AggKind::Max | AggKind::FirstValue => Box::new( + GenericStateCache::new(ExtremeAgg, cache_capacity, row_count), + ), + AggKind::StringAgg => { + Box::new(GenericStateCache::new(StringAgg, cache_capacity, row_count)) + } + AggKind::ArrayAgg => { + Box::new(GenericStateCache::new(ArrayAgg, cache_capacity, row_count)) + } _ => panic!( "Agg kind `{}` is not expected to have materialized input state", agg_call.kind ), - }, + }; + + Self { + group_key: group_key.cloned(), + arg_col_indices, + state_table_arg_col_indices, + order_col_indices, + state_table_order_col_indices, + cache, + cache_key_serializer, + _phantom_data: PhantomData, } } - /// Apply a chunk of data 
to the state. - pub async fn apply_chunk( + /// Apply a chunk of data to the state cache. + pub fn apply_chunk( &mut self, ops: Ops<'_>, visibility: Option<&Bitmap>, columns: &[&ArrayImpl], - state_table: &mut StateTable, ) -> StreamExecutorResult<()> { - self.inner - .apply_chunk(ops, visibility, columns, state_table) - .await + self.cache.apply_batch(StateCacheInputBatch::new( + ops, + visibility, + columns, + &self.cache_key_serializer, + &self.arg_col_indices, + &self.order_col_indices, + )); + Ok(()) } /// Get the output of the state. pub async fn get_output(&mut self, state_table: &StateTable) -> StreamExecutorResult { - self.inner.get_output(state_table).await + if !self.cache.is_synced() { + let all_data_iter = iter_state_table(state_table, self.group_key.as_ref()).await?; + pin_mut!(all_data_iter); + + let mut cache_filler = self.cache.begin_syncing(); + #[for_await] + for state_row in all_data_iter.take(cache_filler.capacity()) { + let state_row = state_row?; + let cache_key = { + let mut cache_key = Vec::new(); + self.cache_key_serializer.serialize_datums( + self.state_table_order_col_indices + .iter() + .map(|col_idx| &(state_row.0)[*col_idx]), + &mut cache_key, + ); + cache_key + }; + let cache_value = self + .state_table_arg_col_indices + .iter() + .map(|i| state_row[*i].as_ref().map(ScalarImpl::as_scalar_ref_impl)) + .collect(); + cache_filler.insert(cache_key, cache_value); + } + } + assert!(self.cache.is_synced()); + Ok(self.cache.get_output()) + } +} + +// TODO(yuchao): May extract common logic here to `struct [Data/Stream]ChunkRef` if there's other +// usage in the future. https://github.com/risingwavelabs/risingwave/pull/5908#discussion_r1002896176 +pub struct StateCacheInputBatch<'a> { + idx: usize, + ops: Ops<'a>, + visibility: Option<&'a Bitmap>, + columns: &'a [&'a ArrayImpl], + cache_key_serializer: &'a OrderedRowSerde, + arg_col_indices: &'a [usize], + order_col_indices: &'a [usize], +} + +impl<'a> StateCacheInputBatch<'a> { + fn new( + ops: Ops<'a>, + visibility: Option<&'a Bitmap>, + columns: &'a [&'a ArrayImpl], + cache_key_serializer: &'a OrderedRowSerde, + arg_col_indices: &'a [usize], + order_col_indices: &'a [usize], + ) -> Self { + let first_idx = visibility.map_or(0, |v| v.next_set_bit(0).unwrap_or(ops.len())); + Self { + idx: first_idx, + ops, + visibility, + columns, + cache_key_serializer, + arg_col_indices, + order_col_indices, + } + } +} + +impl<'a> Iterator for StateCacheInputBatch<'a> { + type Item = (Op, CacheKey, SmallVec<[DatumRef<'a>; 2]>); + + fn next(&mut self) -> Option { + if self.idx >= self.ops.len() { + None + } else { + let op = self.ops[self.idx]; + let key = { + let mut key = Vec::new(); + self.cache_key_serializer.serialize_datum_refs( + self.order_col_indices + .iter() + .map(|col_idx| self.columns[*col_idx].value_at(self.idx)), + &mut key, + ); + key + }; + let value = self + .arg_col_indices + .iter() + .map(|col_idx| self.columns[*col_idx].value_at(self.idx)) + .collect(); + self.idx = self.visibility.map_or(self.idx + 1, |v| { + v.next_set_bit(self.idx + 1).unwrap_or(self.ops.len()) + }); + Some((op, key, value)) + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use itertools::Itertools; + use rand::seq::IteratorRandom; + use rand::Rng; + use risingwave_common::array::{Row, StreamChunk}; + use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, Schema, TableId}; + use risingwave_common::test_prelude::StreamChunkTestExt; + use risingwave_common::types::{DataType, ScalarImpl}; + use 
risingwave_common::util::epoch::EpochPair; + use risingwave_common::util::sort_util::{OrderPair, OrderType}; + use risingwave_expr::expr::AggKind; + use risingwave_storage::memory::MemoryStateStore; + use risingwave_storage::table::streaming_table::state_table::StateTable; + use risingwave_storage::StateStore; + + use super::MaterializedInputState; + use crate::common::StateTableColumnMapping; + use crate::executor::aggregation::{AggArgs, AggCall}; + use crate::executor::StreamExecutorResult; + + fn create_chunk( + pretty: &str, + table: &mut StateTable, + col_mapping: &StateTableColumnMapping, + ) -> StreamChunk { + let chunk = StreamChunk::from_pretty(pretty); + table.write_chunk(StreamChunk::new( + chunk.ops().to_vec(), + col_mapping + .upstream_columns() + .iter() + .map(|col_idx| chunk.columns()[*col_idx].clone()) + .collect(), + chunk.visibility().cloned(), + )); + chunk + } + + fn create_mem_state_table( + input_schema: &Schema, + upstream_columns: Vec, + order_types: Vec, + ) -> (StateTable, StateTableColumnMapping) { + // see `LogicalAgg::infer_stream_agg_state` for the construction of state table + let table_id = TableId::new(rand::thread_rng().gen()); + let columns = upstream_columns + .iter() + .map(|col_idx| input_schema[*col_idx].data_type()) + .enumerate() + .map(|(i, data_type)| ColumnDesc::unnamed(ColumnId::new(i as i32), data_type)) + .collect_vec(); + let mapping = StateTableColumnMapping::new(upstream_columns); + let pk_len = order_types.len(); + let table = StateTable::new_without_distribution( + MemoryStateStore::new(), + table_id, + columns, + order_types, + (0..pk_len).collect(), + ); + (table, mapping) + } + + fn create_extreme_agg_call(kind: AggKind, arg_type: DataType, arg_idx: usize) -> AggCall { + AggCall { + kind, + args: AggArgs::Unary(arg_type.clone(), arg_idx), + return_type: arg_type, + order_pairs: vec![], + append_only: false, + filter: None, + } + } + + #[tokio::test] + async fn test_extreme_agg_state_basic_min() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; // _row_id + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Int32); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4]); + + let agg_call = create_extreme_agg_call(AggKind::Min, DataType::Int32, 2); // min(c) + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![2, 3], + vec![ + OrderType::Ascending, // for AggKind::Min + OrderType::Ascending, + ], + ); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + 0, + usize::MAX, + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + let mut row_count = 0; + + { + let chunk = create_chunk( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 1 3 130", + &mut table, + &mapping, + ); + row_count += 2; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 3); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " T i i I + 
+ d 0 8 134 + + e 2 2 137", + &mut table, + &mapping, + ); + row_count += 2; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 2); + } + _ => panic!("unexpected output"), + } + } + + { + // test recovery (cold start) + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + row_count, + usize::MAX, + &input_schema, + ); + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 2); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_extreme_agg_state_basic_max() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; // _row_id + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Int32); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4]); + let agg_call = create_extreme_agg_call(AggKind::Max, DataType::Int32, 2); // max(c) + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![2, 3], + vec![ + OrderType::Descending, // for AggKind::Max + OrderType::Ascending, + ], + ); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + 0, + usize::MAX, + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + let mut row_count = 0; + + { + let chunk = create_chunk( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 1 3 130", + &mut table, + &mapping, + ); + row_count += 2; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 8); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " T i i I + + d 0 9 134 + + e 2 2 137", + &mut table, + &mapping, + ); + row_count += 2; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 9); + } + _ => panic!("unexpected output"), + } + } + + { + // test recovery (cold start) + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + row_count, + usize::MAX, + &input_schema, + ); + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 9); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_extreme_agg_state_with_hidden_input() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; // 
_row_id + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Int32); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4]); + let agg_call_1 = create_extreme_agg_call(AggKind::Min, DataType::Varchar, 0); // min(a) + let agg_call_2 = create_extreme_agg_call(AggKind::Max, DataType::Varchar, 1); // max(b) + + let (mut table_1, mapping_1) = create_mem_state_table( + &input_schema, + vec![0, 3], + vec![ + OrderType::Ascending, // for AggKind::Min + OrderType::Ascending, + ], + ); + let (mut table_2, mapping_2) = create_mem_state_table( + &input_schema, + vec![1, 3], + vec![ + OrderType::Descending, // for AggKind::Max + OrderType::Ascending, + ], + ); + + let epoch = EpochPair::new_test_epoch(1); + table_1.init_epoch(epoch); + table_2.init_epoch(epoch); + epoch.inc(); + + let mut state_1 = MaterializedInputState::new( + &agg_call_1, + None, + &input_pk_indices, + &mapping_1, + 0, + usize::MAX, + &input_schema, + ); + let mut state_2 = MaterializedInputState::new( + &agg_call_2, + None, + &input_pk_indices, + &mapping_2, + 0, + usize::MAX, + &input_schema, + ); + + { + let chunk_1 = create_chunk( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 1 3 130 + + . 9 4 131 D + + . 6 5 132 D + + c . 3 133", + &mut table_1, + &mapping_1, + ); + let chunk_2 = create_chunk( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 1 3 130 + + . 9 4 131 + + . 6 5 132 + + c . 3 133 D", + &mut table_2, + &mapping_2, + ); + + [chunk_1, chunk_2] + .into_iter() + .zip_eq([&mut state_1, &mut state_2]) + .try_for_each(|(chunk, state)| { + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns) + })?; + + table_1.commit_for_test(epoch).await.unwrap(); + table_2.commit_for_test(epoch).await.unwrap(); + + match state_1.get_output(&table_1).await? { + Some(ScalarImpl::Utf8(s)) => { + assert_eq!(&s, "a"); + } + _ => panic!("unexpected output"), + } + match state_2.get_output(&table_2).await? 
{ + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 9); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_extreme_agg_state_grouped() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + + let input_pk_indices = vec![3]; + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Int32); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4]); + let agg_call = create_extreme_agg_call(AggKind::Max, DataType::Int32, 1); // max(b) + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![2, 1, 3], + vec![ + OrderType::Ascending, // c ASC + OrderType::Descending, // b DESC for AggKind::Max + OrderType::Ascending, // _row_id ASC + ], + ); + let group_key = Row::new(vec![Some(8.into())]); + + let mut state = MaterializedInputState::new( + &agg_call, + Some(&group_key), + &input_pk_indices, + &mapping, + 0, + usize::MAX, + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + let mut row_count = 0; + + { + let chunk = create_chunk( + " T i i I + + a 1 8 123 + + b 5 8 128 + + c 7 3 130 D // hide this row", + &mut table, + &mapping, + ); + row_count += 2; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 5); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " T i i I + + d 9 2 134 D // hide this row + + e 8 8 137", + &mut table, + &mapping, + ); + row_count += 1; + + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 8); + } + _ => panic!("unexpected output"), + } + } + + { + // test recovery (cold start) + let mut state = MaterializedInputState::new( + &agg_call, + Some(&group_key), + &input_pk_indices, + &mapping, + row_count, + usize::MAX, + &input_schema, + ); + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 8); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_extreme_agg_state_with_random_values() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: int32, _row_id: int64) + + let input_pk_indices = vec![1]; // _row_id + let field1 = Field::unnamed(DataType::Int32); + let field2 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2]); + let agg_call = create_extreme_agg_call(AggKind::Min, DataType::Int32, 0); // min(a) + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![0, 1], + vec![ + OrderType::Ascending, // for AggKind::Min + OrderType::Ascending, + ], + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + 
&mapping, + 0, + 1024, + &input_schema, + ); + + let mut rng = rand::thread_rng(); + let insert_values: Vec = (0..10000).map(|_| rng.gen()).collect_vec(); + let delete_values: HashSet<_> = insert_values + .iter() + .choose_multiple(&mut rng, 1000) + .into_iter() + .collect(); + let mut min_value = i32::MAX; + + { + let mut pretty_lines = vec!["i I".to_string()]; + for (row_id, value) in insert_values + .iter() + .enumerate() + .take(insert_values.len() / 2) + { + pretty_lines.push(format!("+ {} {}", value, row_id)); + if delete_values.contains(&value) { + pretty_lines.push(format!("- {} {}", value, row_id)); + continue; + } + if *value < min_value { + min_value = *value; + } + } + + let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, min_value); + } + _ => panic!("unexpected output"), + } + } + + { + let mut pretty_lines = vec!["i I".to_string()]; + for (row_id, value) in insert_values + .iter() + .enumerate() + .skip(insert_values.len() / 2) + { + pretty_lines.push(format!("+ {} {}", value, row_id)); + if delete_values.contains(&value) { + pretty_lines.push(format!("- {} {}", value, row_id)); + continue; + } + if *value < min_value { + min_value = *value; + } + } + + let chunk = create_chunk(&pretty_lines.join("\n"), &mut table, &mapping); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, min_value); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_extreme_agg_state_cache_maintenance() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: int32, _row_id: int64) + + let input_pk_indices = vec![1]; // _row_id + let field1 = Field::unnamed(DataType::Int32); + let field2 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2]); + let agg_call = create_extreme_agg_call(AggKind::Min, DataType::Int32, 0); // min(a) + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![0, 1], + vec![ + OrderType::Ascending, // for AggKind::Min + OrderType::Ascending, + ], + ); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + 0, + 3, // cache capacity = 3 for easy testing + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + { + let chunk = create_chunk( + " i I + + 4 123 + + 8 128 + + 12 129", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 4); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " i I + + 9 130 // this will 
evict 12 + - 9 130 + + 13 128 + - 4 123 + - 8 128", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 12); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " i I + + 1 131 + + 2 132 + + 3 133 // evict all from cache + - 1 131 + - 2 132 + - 3 133 + + 14 134", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Int32(s)) => { + assert_eq!(s, 12); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_string_agg_state() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64) + // where `a` is the column to aggregate + + let input_pk_indices = vec![4]; + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Varchar); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int32); + let field5 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4, field5]); + + let agg_call = AggCall { + kind: AggKind::StringAgg, + args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]), + return_type: DataType::Varchar, + order_pairs: vec![ + OrderPair::new(2, OrderType::Ascending), // b ASC + OrderPair::new(0, OrderType::Descending), // a DESC + ], + append_only: false, + filter: None, + }; + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![2, 0, 4, 1], + vec![ + OrderType::Ascending, // _row_id ASC + OrderType::Descending, // a DESC + OrderType::Ascending, // b ASC + ], + ); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + 0, + usize::MAX, + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + { + let chunk = create_chunk( + " T T i i I + + a , 1 8 123 + + b / 5 2 128 + - b / 5 2 128 + + c _ 1 3 130", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Utf8(s)) => { + assert_eq!(s, "c,a".to_string()); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " T T i i I + + d - 0 8 134 + + e + 2 2 137", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::Utf8(s)) => { + assert_eq!(s, "d_c,a+e".to_string()); + } + _ => 
panic!("unexpected output"), + } + } + + Ok(()) + } + + #[tokio::test] + async fn test_array_agg_state() -> StreamExecutorResult<()> { + // Assumption of input schema: + // (a: varchar, b: int32, c: int32, _row_id: int64) + // where `a` is the column to aggregate + + let input_pk_indices = vec![3]; + let field1 = Field::unnamed(DataType::Varchar); + let field2 = Field::unnamed(DataType::Int32); + let field3 = Field::unnamed(DataType::Int32); + let field4 = Field::unnamed(DataType::Int64); + let input_schema = Schema::new(vec![field1, field2, field3, field4]); + + let agg_call = AggCall { + kind: AggKind::ArrayAgg, + args: AggArgs::Unary(DataType::Int32, 1), // array_agg(b) + return_type: DataType::Int32, + order_pairs: vec![ + OrderPair::new(2, OrderType::Ascending), // c ASC + OrderPair::new(0, OrderType::Descending), // a DESC + ], + append_only: false, + filter: None, + }; + + let (mut table, mapping) = create_mem_state_table( + &input_schema, + vec![2, 0, 3, 1], + vec![ + OrderType::Ascending, // c ASC + OrderType::Descending, // a DESC + OrderType::Ascending, // _row_id ASC + ], + ); + + let mut state = MaterializedInputState::new( + &agg_call, + None, + &input_pk_indices, + &mapping, + 0, + usize::MAX, + &input_schema, + ); + + let epoch = EpochPair::new_test_epoch(1); + table.init_epoch(epoch); + epoch.inc(); + + { + let chunk = create_chunk( + " T i i I + + a 1 8 123 + + b 5 2 128 + - b 5 2 128 + + c 2 3 130", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + epoch.inc(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) + .collect_vec(); + assert_eq!(res, vec![Some(2), Some(1)]); + } + _ => panic!("unexpected output"), + } + } + + { + let chunk = create_chunk( + " T i i I + + d 0 8 134 + + e 2 2 137", + &mut table, + &mapping, + ); + let (ops, columns, visibility) = chunk.into_inner(); + let columns: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); + state.apply_chunk(&ops, visibility.as_ref(), &columns)?; + + table.commit_for_test(epoch).await.unwrap(); + + let res = state.get_output(&table).await?; + match res { + Some(ScalarImpl::List(res)) => { + let res = res + .values() + .iter() + .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) + .collect_vec(); + assert_eq!(res, vec![Some(2), Some(2), Some(0), Some(1)]); + } + _ => panic!("unexpected output"), + } + } + + Ok(()) } } diff --git a/src/stream/src/executor/aggregation/mod.rs b/src/stream/src/executor/aggregation/mod.rs index 5a316528f1409..ead881be84a77 100644 --- a/src/stream/src/executor/aggregation/mod.rs +++ b/src/stream/src/executor/aggregation/mod.rs @@ -15,18 +15,19 @@ pub use agg_call::*; pub use agg_group::*; pub use agg_state::*; -use anyhow::anyhow; use risingwave_common::array::column::Column; use risingwave_common::array::ArrayImpl::Bool; -use risingwave_common::array::{DataChunk, Vis}; +use risingwave_common::array::DataChunk; +use risingwave_common::bail; use risingwave_common::buffer::Bitmap; use risingwave_common::catalog::{Field, Schema}; +use risingwave_expr::expr::AggKind; use risingwave_storage::table::streaming_table::state_table::StateTable; use risingwave_storage::StateStore; use super::ActorContextRef; use 
crate::common::InfallibleExpression; -use crate::executor::error::{StreamExecutorError, StreamExecutorResult}; +use crate::executor::error::StreamExecutorResult; use crate::executor::Executor; mod agg_call; @@ -34,7 +35,7 @@ mod agg_group; pub mod agg_impl; mod agg_state; mod minput; -mod table_state; +mod state_cache; mod value; /// Generate [`crate::executor::HashAggExecutor`]'s schema from `input`, `agg_calls` and @@ -66,30 +67,41 @@ pub fn agg_call_filter_res( ctx: &ActorContextRef, identity: &str, agg_call: &AggCall, - columns: &Vec, - visibility: Option<&Bitmap>, + columns: &[Column], + base_visibility: Option<&Bitmap>, capacity: usize, ) -> StreamExecutorResult> { - if let Some(ref filter) = agg_call.filter { - let vis = Vis::from( - visibility - .cloned() - .unwrap_or_else(|| Bitmap::all_high_bits(capacity)), - ); - let data_chunk = DataChunk::new(columns.to_owned(), vis); + let agg_col_vis = if matches!( + agg_call.kind, + AggKind::Min | AggKind::Max | AggKind::StringAgg + ) { + // should skip NULL value for these kinds of agg function + let agg_col_idx = agg_call.args.val_indices()[0]; // the first arg is the agg column for all these kinds + let agg_col_bitmap = columns[agg_col_idx].array_ref().null_bitmap(); + Some(agg_col_bitmap) + } else { + None + }; + + let filter_vis = if let Some(ref filter) = agg_call.filter { + let data_chunk = DataChunk::new(columns.to_vec(), capacity); if let Bool(filter_res) = filter .eval_infallible(&data_chunk, |err| ctx.on_compute_error(err, identity)) .as_ref() { - Ok(Some(filter_res.to_bitmap())) + Some(filter_res.to_bitmap()) } else { - Err(StreamExecutorError::from(anyhow!( - "Filter can only receive bool array" - ))) + bail!("Filter can only receive bool array"); } } else { - Ok(visibility.cloned()) - } + None + }; + + Ok([base_visibility, agg_col_vis, filter_vis.as_ref()] + .into_iter() + .flatten() + .cloned() + .reduce(|x, y| &x & &y)) } pub fn iter_table_storage( diff --git a/src/stream/src/executor/aggregation/state_cache/array_agg.rs b/src/stream/src/executor/aggregation/state_cache/array_agg.rs new file mode 100644 index 0000000000000..be4a3703b7fc0 --- /dev/null +++ b/src/stream/src/executor/aggregation/state_cache/array_agg.rs @@ -0,0 +1,102 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
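(For orientation on the `agg_call_filter_res` rework shown above: the effective visibility is now the intersection of up to three optional bitmaps — the chunk's base visibility, the aggregated column's non-NULL bitmap for `min`/`max`/`string_agg`, and the optional `FILTER` result — folded together with `reduce(|x, y| &x & &y)`. The sketch below mirrors that pattern with plain `Vec<bool>` stand-ins rather than RisingWave's `Bitmap`; the function name and types are illustrative only, not part of this patch.)

```rust
// Minimal stand-alone sketch of intersecting optional visibility bitmaps.
// `None` means "all rows visible" and is simply skipped, as in the patch.
fn combine_visibility(
    base: Option<&[bool]>,
    agg_col_non_null: Option<&[bool]>,
    filter_res: Option<&[bool]>,
) -> Option<Vec<bool>> {
    [base, agg_col_non_null, filter_res]
        .into_iter()
        .flatten() // drop the absent bitmaps
        .map(<[bool]>::to_vec)
        .reduce(|x, y| x.iter().zip(&y).map(|(a, b)| a & b).collect())
}

fn main() {
    let base = vec![true, true, false, true];
    let non_null = vec![true, false, true, true];
    // No FILTER clause on this agg call, so the third bitmap is absent.
    let vis = combine_visibility(Some(&base), Some(&non_null), None);
    assert_eq!(vis, Some(vec![true, false, false, true]));
}
```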
+ +use risingwave_common::array::ListValue; +use risingwave_common::types::{Datum, DatumRef, ScalarRefImpl}; +use smallvec::SmallVec; + +use super::StateCacheAggregator; + +pub struct ArrayAgg; + +impl StateCacheAggregator for ArrayAgg { + type Value = Datum; + + fn convert_cache_value(&self, value: SmallVec<[DatumRef<'_>; 2]>) -> Self::Value { + value[0].map(ScalarRefImpl::into_scalar_impl) + } + + fn aggregate<'a>(&'a self, values: impl Iterator) -> Datum { + let res_values = values.cloned().collect(); + Some(ListValue::new(res_values).into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::executor::aggregation::state_cache::cache::OrderedCache; + + #[test] + fn test_array_agg_aggregate() { + let agg = ArrayAgg; + + let mut cache = OrderedCache::new(10); + // FIXME(yuchao): the behavior is not compatible with PG, #5962 + assert_eq!( + agg.aggregate(cache.iter_values()), + Some(ListValue::new(vec![]).into()) + ); + + cache.insert(vec![1, 2, 3], Some("hello".to_string().into())); + cache.insert(vec![1, 2, 4], Some("world".to_string().into())); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some( + ListValue::new(vec![ + Some("hello".to_string().into()), + Some("world".to_string().into()), + ]) + .into() + ) + ); + + cache.insert(vec![0, 1, 2], Some("emmm".to_string().into())); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some( + ListValue::new(vec![ + Some("emmm".to_string().into()), + Some("hello".to_string().into()), + Some("world".to_string().into()), + ]) + .into() + ) + ); + + cache.insert(vec![6, 6, 6], None); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some( + ListValue::new(vec![ + Some("emmm".to_string().into()), + Some("hello".to_string().into()), + Some("world".to_string().into()), + None, + ]) + .into() + ) + ); + } + + #[test] + fn test_array_agg_convert() { + let agg = ArrayAgg; + let args = SmallVec::from_vec(vec![Some("hello".into())]); + assert_eq!( + agg.convert_cache_value(args), + Some("hello".to_string().into()) + ); + } +} diff --git a/src/stream/src/executor/aggregation/table_state/mod.rs b/src/stream/src/executor/aggregation/state_cache/cache.rs similarity index 54% rename from src/stream/src/executor/aggregation/table_state/mod.rs rename to src/stream/src/executor/aggregation/state_cache/cache.rs index 5337e2846e9f6..2363bb5947435 100644 --- a/src/stream/src/executor/aggregation/table_state/mod.rs +++ b/src/stream/src/executor/aggregation/state_cache/cache.rs @@ -14,52 +14,15 @@ use std::collections::BTreeMap; -pub use array_agg::ManagedArrayAggState; -use async_trait::async_trait; -pub use extreme::GenericExtremeState; -use risingwave_common::array::stream_chunk::Ops; -use risingwave_common::array::ArrayImpl; -use risingwave_common::buffer::Bitmap; -use risingwave_common::types::Datum; -use risingwave_storage::table::streaming_table::state_table::StateTable; -use risingwave_storage::StateStore; -pub use string_agg::ManagedStringAggState; - -use crate::executor::StreamExecutorResult; - -mod array_agg; -mod extreme; -mod string_agg; - -/// A trait over all table-structured states. -/// -/// It is true that this interface also fits to value managed state, but we won't implement -/// `ManagedTableState` for them. We want to reduce the overhead of `BoxedFuture`. For -/// `ManagedValueState`, we can directly forward its async functions to `ManagedStateImpl`, instead -/// of adding a layer of indirection caused by async traits. 
-#[async_trait] -pub trait ManagedTableState: Send + Sync + 'static { - async fn apply_chunk( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], - state_table: &mut StateTable, - ) -> StreamExecutorResult<()>; - - /// Get the output of the state. Must flush before getting output. - async fn get_output(&mut self, state_table: &StateTable) -> StreamExecutorResult; -} - -/// Common cache structure for managed table states (non-append-only `min`/`max`, `string_agg`). -struct Cache { +/// Common cache structure for [`super::StateCache`] (non-append-only `min`/`max`, `string_agg`). +pub struct OrderedCache { /// The capacity of the cache. capacity: usize, /// Ordered cache entries. entries: BTreeMap, } -impl Cache { +impl OrderedCache { /// Create a new cache with specified capacity and order requirements. /// To create a cache with unlimited capacity, use `usize::MAX` for `capacity`. pub fn new(capacity: usize) -> Self { @@ -90,8 +53,6 @@ impl Cache { } /// Insert an entry into the cache. - /// Key: `OrderedRow` composed of order by fields. - /// Value: The value fields that are to be aggregated. pub fn insert(&mut self, key: K, value: V) { self.entries.insert(key, value); // evict if capacity is reached @@ -110,13 +71,62 @@ impl Cache { self.entries.last_key_value().map(|(k, _)| k) } - /// Get the first (smallest) value in the cache. - pub fn first_value(&self) -> Option<&V> { - self.entries.first_key_value().map(|(_, v)| v) - } - /// Iterate over the values in the cache. pub fn iter_values(&self) -> impl Iterator { self.entries.values() } } + +#[cfg(test)] +mod tests { + use itertools::Itertools; + + use super::*; + + #[test] + fn test_ordered_cache() { + let mut cache = OrderedCache::new(3); + assert_eq!(cache.capacity(), 3); + assert_eq!(cache.len(), 0); + assert!(cache.is_empty()); + assert!(cache.last_key().is_none()); + assert!(cache.iter_values().collect_vec().is_empty()); + + cache.insert(5, "hello".to_string()); + assert_eq!(cache.len(), 1); + assert!(!cache.is_empty()); + assert_eq!(cache.iter_values().collect_vec(), vec!["hello"]); + + cache.insert(3, "world".to_string()); + cache.insert(1, "risingwave!".to_string()); + assert_eq!(cache.len(), 3); + assert_eq!(cache.last_key(), Some(&5)); + assert_eq!( + cache.iter_values().collect_vec(), + vec!["risingwave!", "world", "hello"] + ); + + cache.insert(0, "foo".to_string()); + assert_eq!(cache.capacity(), 3); + assert_eq!(cache.len(), 3); + assert_eq!(cache.last_key(), Some(&3)); + assert_eq!( + cache.iter_values().collect_vec(), + vec!["foo", "risingwave!", "world"] + ); + + cache.remove(0); + assert_eq!(cache.len(), 2); + assert_eq!(cache.last_key(), Some(&3)); + cache.remove(3); + assert_eq!(cache.len(), 1); + assert_eq!(cache.last_key(), Some(&1)); + cache.remove(100); // can remove non-existing key + assert_eq!(cache.len(), 1); + + cache.clear(); + assert_eq!(cache.len(), 0); + assert_eq!(cache.capacity(), 3); + assert_eq!(cache.last_key(), None); + } +} diff --git a/src/stream/src/executor/aggregation/state_cache/extreme.rs b/src/stream/src/executor/aggregation/state_cache/extreme.rs new file mode 100644 index 0000000000000..39039e1fa7f70 --- /dev/null +++ b/src/stream/src/executor/aggregation/state_cache/extreme.rs @@ -0,0 +1,78 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::types::{Datum, DatumRef, ScalarRefImpl}; +use smallvec::SmallVec; + +use super::StateCacheAggregator; + +/// Common aggregator for `min`/`max`. The behavior is simply to choose the +/// first value as aggregation result, so the value order in the given cache +/// is important and should be maintained outside. +pub struct ExtremeAgg; + +impl StateCacheAggregator for ExtremeAgg { + // TODO(yuchao): We can generate an `ExtremeAgg` for each data type to save memory. + type Value = Datum; + + fn convert_cache_value(&self, value: SmallVec<[DatumRef<'_>; 2]>) -> Self::Value { + value[0].map(ScalarRefImpl::into_scalar_impl) + } + + fn aggregate<'a>(&'a self, mut values: impl Iterator) -> Datum { + values.next().cloned().flatten() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::executor::aggregation::state_cache::cache::OrderedCache; + + #[test] + fn test_extreme_agg_aggregate() { + let agg = ExtremeAgg; + + let mut cache = OrderedCache::new(10); + assert_eq!(agg.aggregate(cache.iter_values()), None); + + cache.insert(vec![1, 2, 3], Some("hello".to_string().into())); + cache.insert(vec![1, 3, 4], Some("world".to_string().into())); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some("hello".to_string().into()) + ); + + cache.insert(vec![0, 1, 2], Some("emmm".to_string().into())); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some("emmm".to_string().into()) + ); + } + + #[test] + fn test_extreme_agg_convert() { + let agg = ExtremeAgg; + let args = SmallVec::from_vec(vec![Some("boom".into())]); + assert_eq!( + agg.convert_cache_value(args), + Some("boom".to_string().into()) + ); + let args = SmallVec::from_vec(vec![Some("hello".into()), Some("world".into())]); + assert_eq!( + agg.convert_cache_value(args), + Some("hello".to_string().into()) + ); + } +} diff --git a/src/stream/src/executor/aggregation/state_cache/mod.rs b/src/stream/src/executor/aggregation/state_cache/mod.rs new file mode 100644 index 0000000000000..e37ea28474704 --- /dev/null +++ b/src/stream/src/executor/aggregation/state_cache/mod.rs @@ -0,0 +1,197 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use itertools::Itertools; +use risingwave_common::array::Op; +use risingwave_common::types::{Datum, DatumRef}; +use smallvec::SmallVec; + +use self::cache::OrderedCache; +use super::minput::StateCacheInputBatch; + +pub mod array_agg; +mod cache; +pub mod extreme; +pub mod string_agg; + +/// Cache key type. +pub type CacheKey = Vec; + +/// Trait that defines the interface of state table cache. 
+pub trait StateCache: Send + Sync + 'static { + /// Check if the cache is synced with state table. + fn is_synced(&self) -> bool; + + /// Apply a batch of updates to the cache. + fn apply_batch(&mut self, batch: StateCacheInputBatch<'_>); + + /// Begin syncing the cache with state table. + fn begin_syncing(&mut self) -> StateCacheFiller<'_>; + + /// Get the aggregation output. + fn get_output(&self) -> Datum; +} + +/// Cache maintenance interface. +/// Note that this trait must be private, so that only [`StateCacheFiller`] can use it. +trait StateCacheMaintain: Send + Sync + 'static { + /// Insert an entry to the cache without checking row count, capacity, key order, etc. + /// Just insert into the inner cache structure, e.g. `BTreeMap`. + fn insert_unchecked(&mut self, key: CacheKey, value: SmallVec<[DatumRef<'_>; 2]>); + + /// Mark the cache as synced. + fn set_synced(&mut self); +} + +/// A temporary handle for filling the state cache. +/// The state cache will be marked as synced automatically when this handle is dropped. +pub struct StateCacheFiller<'a> { + capacity: usize, + cache: &'a mut dyn StateCacheMaintain, +} + +impl<'a> StateCacheFiller<'a> { + /// Get the capacity of the cache to be filled. + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Insert an entry to the cache. + pub fn insert(&mut self, key: CacheKey, value: SmallVec<[DatumRef<'_>; 2]>) { + self.cache.insert_unchecked(key, value) + } +} + +impl<'a> Drop for StateCacheFiller<'a> { + fn drop(&mut self) { + self.cache.set_synced(); + } +} + +/// Trait that defines aggregators that aggregate entries in an [`OrderedCache`]. +pub trait StateCacheAggregator { + /// The cache value type. + type Value: Send + Sync; + + /// Convert cache value into compact representation. + fn convert_cache_value(&self, value: SmallVec<[DatumRef<'_>; 2]>) -> Self::Value; + + /// Aggregate all entries in the ordered cache. + fn aggregate<'a>(&'a self, values: impl Iterator) -> Datum; +} + +/// A [`StateCache`] implementation that uses [`OrderedCache`] as the cache. +pub struct GenericStateCache +where + Agg: StateCacheAggregator + Send + Sync + 'static, +{ + /// Aggregator implementation. + aggregator: Agg, + + /// The inner ordered cache. + cache: OrderedCache, + + /// Number of all items in the state store. + total_count: usize, + + /// Sync status of the state cache. 
+ synced: bool, +} + +impl GenericStateCache +where + Agg: StateCacheAggregator + Send + Sync + 'static, +{ + pub fn new(aggregator: Agg, capacity: usize, total_count: usize) -> Self { + Self { + aggregator, + cache: OrderedCache::new(capacity), + total_count, + synced: total_count == 0, + } + } +} + +impl StateCache for GenericStateCache +where + Agg: StateCacheAggregator + Send + Sync + 'static, +{ + fn is_synced(&self) -> bool { + self.synced + } + + fn apply_batch(&mut self, mut batch: StateCacheInputBatch<'_>) { + if self.synced { + // only insert/delete entries if the cache is synced + for (op, key, value) in batch.by_ref() { + match op { + Op::Insert | Op::UpdateInsert => { + self.total_count += 1; + if self.cache.len() == self.total_count - 1 + || &key < self.cache.last_key().unwrap() + { + self.cache + .insert(key, self.aggregator.convert_cache_value(value)); + } + } + Op::Delete | Op::UpdateDelete => { + self.total_count -= 1; + self.cache.remove(key); + if self.total_count > 0 /* still has rows after deletion */ && self.cache.is_empty() + { + // the cache is empty, but the state table is not, so it's not synced + // any more + self.synced = false; + break; + } + } + } + } + } + + // count remaining ops + let op_counts = batch.counts_by(|(op, _, _)| op); + self.total_count += op_counts.get(&Op::Insert).unwrap_or(&0) + + op_counts.get(&Op::UpdateInsert).unwrap_or(&0); + self.total_count -= op_counts.get(&Op::Delete).unwrap_or(&0) + + op_counts.get(&Op::UpdateDelete).unwrap_or(&0); + } + + fn begin_syncing(&mut self) -> StateCacheFiller<'_> { + self.cache.clear(); // ensure the cache is clear before syncing + StateCacheFiller { + capacity: self.cache.capacity(), + cache: self, + } + } + + fn get_output(&self) -> Datum { + debug_assert!(self.synced); + self.aggregator.aggregate(self.cache.iter_values()) + } +} + +impl StateCacheMaintain for GenericStateCache +where + Agg: StateCacheAggregator + Send + Sync + 'static, +{ + fn insert_unchecked(&mut self, key: CacheKey, value: SmallVec<[DatumRef<'_>; 2]>) { + let value = self.aggregator.convert_cache_value(value); + self.cache.insert(key, value); + } + + fn set_synced(&mut self) { + self.synced = true; + } +} diff --git a/src/stream/src/executor/aggregation/state_cache/string_agg.rs b/src/stream/src/executor/aggregation/state_cache/string_agg.rs new file mode 100644 index 0000000000000..89100cfcc2f7c --- /dev/null +++ b/src/stream/src/executor/aggregation/state_cache/string_agg.rs @@ -0,0 +1,108 @@ +// Copyright 2022 Singularity Data +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
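(A note on the `GenericStateCache::apply_batch` logic above: the cache keeps a bounded top-N view of the state table — an inserted key is cached only if the cache still covers every row or the key sorts below the current cached maximum, and the cache falls out of sync when deletions empty it while rows remain in the table. The following is a self-contained sketch of that bookkeeping under simplifying assumptions, using a plain `BTreeMap<Vec<u8>, i64>` in place of the crate's `OrderedCache` and cache-value types; all names here are illustrative.)

```rust
use std::collections::BTreeMap;

/// Minimal stand-in for the synced/unsynced top-N cache kept beside a state table.
struct TopNCache {
    capacity: usize,
    entries: BTreeMap<Vec<u8>, i64>,
    /// Number of rows in the (imaginary) state table.
    total_count: usize,
    synced: bool,
}

impl TopNCache {
    fn new(capacity: usize) -> Self {
        Self { capacity, entries: BTreeMap::new(), total_count: 0, synced: true }
    }

    fn insert(&mut self, key: Vec<u8>, value: i64) {
        self.total_count += 1;
        if !self.synced {
            return; // while out of sync, only the row count is maintained
        }
        let covers_all = self.entries.len() == self.total_count - 1;
        let below_max = self.entries.keys().next_back().map_or(false, |max| key < *max);
        if covers_all || below_max {
            self.entries.insert(key, value);
            while self.entries.len() > self.capacity {
                self.entries.pop_last(); // evict the largest key
            }
        }
    }

    fn delete(&mut self, key: &[u8]) {
        self.total_count -= 1;
        if self.synced {
            self.entries.remove(key);
            if self.total_count > 0 && self.entries.is_empty() {
                // Rows remain in the table but none are cached: must re-scan before output.
                self.synced = false;
            }
        }
    }
}

fn main() {
    let mut cache = TopNCache::new(2);
    cache.insert(b"b".to_vec(), 2);
    cache.insert(b"c".to_vec(), 3);
    cache.insert(b"a".to_vec(), 1); // sorts below the cached max, so it is cached; "c" is evicted
    cache.delete(b"a");
    cache.delete(b"b");
    assert!(!cache.synced); // one row ("c") still exists, yet the cache is empty
}
```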
+ +use risingwave_common::types::{Datum, DatumRef, ScalarRefImpl}; +use smallvec::SmallVec; + +use super::StateCacheAggregator; + +pub struct StringAggData { + delim: String, + value: String, +} + +pub struct StringAgg; + +impl StateCacheAggregator for StringAgg { + type Value = StringAggData; + + fn convert_cache_value(&self, value: SmallVec<[DatumRef<'_>; 2]>) -> Self::Value { + StringAggData { + delim: value[1] + .map(ScalarRefImpl::into_utf8) + .unwrap_or_default() + .to_string(), + value: value[0] + .map(ScalarRefImpl::into_utf8) + .unwrap_or_default() + .to_string(), + } + } + + fn aggregate<'a>(&'a self, mut values: impl Iterator) -> Datum { + let mut result = match values.next() { + Some(data) => data.value.clone(), + None => return None, // return NULL if no rows to aggregate + }; + for StringAggData { value, delim } in values { + result.push_str(delim); + result.push_str(value); + } + Some(result.into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::executor::aggregation::state_cache::cache::OrderedCache; + + #[test] + fn test_string_agg_aggregate() { + let agg = StringAgg; + + let mut cache = OrderedCache::new(10); + assert_eq!(agg.aggregate(cache.iter_values()), None); + + cache.insert( + vec![1, 2, 3], + StringAggData { + delim: "_".to_string(), + value: "hello".to_string(), + }, + ); + cache.insert( + vec![1, 3, 4], + StringAggData { + delim: ",".to_string(), + value: "world".to_string(), + }, + ); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some("hello,world".to_string().into()) + ); + + cache.insert( + vec![0, 1, 2], + StringAggData { + delim: "/".to_string(), + value: "emmm".to_string(), + }, + ); + assert_eq!( + agg.aggregate(cache.iter_values()), + Some("emmm_hello,world".to_string().into()) + ); + } + + #[test] + fn test_string_agg_convert() { + let agg = StringAgg; + let args = SmallVec::from_vec(vec![Some("hello".into()), Some("world".into())]); + let value = agg.convert_cache_value(args); + assert_eq!(value.value, "hello".to_string()); + assert_eq!(value.delim, "world".to_string()); + } +} diff --git a/src/stream/src/executor/aggregation/table_state/array_agg.rs b/src/stream/src/executor/aggregation/table_state/array_agg.rs deleted file mode 100644 index 525ab363f71b0..0000000000000 --- a/src/stream/src/executor/aggregation/table_state/array_agg.rs +++ /dev/null @@ -1,525 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
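(The `StringAgg` aggregator above pairs each row's value with the delimiter that arrived on the same row, drops the first row's delimiter, and prepends every later row's own delimiter to its value — that is why the unit test yields `"emmm_hello,world"`. A tiny standalone illustration of that folding rule, using plain `(&str, &str)` pairs rather than the crate's `StringAggData`:)

```rust
/// `rows` holds (delimiter, value) pairs in output order; the first pair's delimiter is unused.
fn string_agg(rows: &[(&str, &str)]) -> Option<String> {
    let mut iter = rows.iter();
    let (_, first_value) = iter.next()?; // NULL result when there are no rows
    let mut out = first_value.to_string();
    for (delim, value) in iter {
        out.push_str(delim);
        out.push_str(value);
    }
    Some(out)
}

fn main() {
    // Mirrors the "emmm" / "hello" / "world" case from the unit test above.
    let rows = [("/", "emmm"), ("_", "hello"), (",", "world")];
    assert_eq!(string_agg(&rows), Some("emmm_hello,world".to_string()));
    assert_eq!(string_agg(&[]), None);
}
```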
- -use std::marker::PhantomData; - -use async_trait::async_trait; -use futures::pin_mut; -use futures_async_stream::for_await; -use risingwave_common::array::stream_chunk::Ops; -use risingwave_common::array::Op::{Delete, Insert, UpdateDelete, UpdateInsert}; -use risingwave_common::array::{ArrayImpl, ListValue, Row}; -use risingwave_common::buffer::Bitmap; -use risingwave_common::types::Datum; -use risingwave_common::util::ordered::OrderedRow; -use risingwave_common::util::sort_util::OrderType; -use risingwave_storage::table::streaming_table::state_table::StateTable; -use risingwave_storage::StateStore; - -use super::{Cache, ManagedTableState}; -use crate::common::{iter_state_table, StateTableColumnMapping}; -use crate::executor::aggregation::AggCall; -use crate::executor::error::StreamExecutorResult; -use crate::executor::PkIndices; - -pub struct ManagedArrayAggState { - _phantom_data: PhantomData, - - /// Group key to aggregate with group. - /// None for simple agg, Some for group key of hash agg. - group_key: Option, - - // TODO(yuchao): remove this after we move state table insertion out. - /// Contains the column mapping between upstream schema and state table. - state_table_col_mapping: StateTableColumnMapping, - - /// The column to aggregate in state table. - state_table_agg_col_idx: usize, - - /// The columns to order by in state table. - state_table_order_col_indices: Vec, - - /// The order types of `state_table_order_col_indices`. - state_table_order_types: Vec, - - /// In-memory all-or-nothing cache. - cache: Cache, - - /// Whether the cache is fully synced to state table. - cache_synced: bool, -} - -impl ManagedArrayAggState { - pub fn new( - agg_call: &AggCall, - group_key: Option<&Row>, - pk_indices: &PkIndices, - col_mapping: StateTableColumnMapping, - row_count: usize, - ) -> Self { - // map agg column to state table column index - let state_table_agg_col_idx = col_mapping - .upstream_to_state_table(agg_call.args.val_indices()[0]) - .expect("the column to be aggregate must appear in the state table"); - // map order by columns to state table column indices - let (state_table_order_col_indices, state_table_order_types) = agg_call - .order_pairs - .iter() - .map(|o| { - ( - col_mapping - .upstream_to_state_table(o.column_idx) - .expect("the column to be order by must appear in the state table"), - o.order_type, - ) - }) - .chain(pk_indices.iter().map(|idx| { - ( - col_mapping - .upstream_to_state_table(*idx) - .expect("the pk columns must appear in the state table"), - OrderType::Ascending, - ) - })) - .unzip(); - Self { - _phantom_data: PhantomData, - group_key: group_key.cloned(), - state_table_col_mapping: col_mapping, - state_table_agg_col_idx, - state_table_order_col_indices, - state_table_order_types, - cache: Cache::new(usize::MAX), - cache_synced: row_count == 0, // if there is no row, the cache is synced initially - } - } - - fn state_row_to_cache_entry(&self, state_row: &Row) -> (OrderedRow, Datum) { - let cache_key = OrderedRow::new( - state_row.by_indices(&self.state_table_order_col_indices), - &self.state_table_order_types, - ); - let cache_data = state_row[self.state_table_agg_col_idx].clone(); - (cache_key, cache_data) - } - - fn apply_chunk_inner( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - // should not skip NULL value like `string_agg` and `min`/`max` - for (i, op) in ops - .iter() - .enumerate() - .filter(|(i, _)| visibility.map(|x| 
x.is_set(*i)).unwrap_or(true)) - { - let state_row = Row::new( - self.state_table_col_mapping - .upstream_columns() - .iter() - .map(|col_idx| columns[*col_idx].datum_at(i)) - .collect(), - ); - let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); - - match op { - Insert | UpdateInsert => { - if self.cache_synced { - self.cache.insert(cache_key, cache_data); - } - state_table.insert(state_row); - } - Delete | UpdateDelete => { - if self.cache_synced { - self.cache.remove(cache_key); - } - state_table.delete(state_row); - } - } - } - - Ok(()) - } - - async fn get_output_inner( - &mut self, - state_table: &StateTable, - ) -> StreamExecutorResult { - if !self.cache_synced { - let all_data_iter = iter_state_table(state_table, self.group_key.as_ref()).await?; - pin_mut!(all_data_iter); - - self.cache.clear(); - #[for_await] - for state_row in all_data_iter { - let state_row = state_row?; - let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); - self.cache.insert(cache_key, cache_data.clone()); - } - self.cache_synced = true; - } - - let mut values = Vec::with_capacity(self.cache.len()); - for cache_data in self.cache.iter_values() { - values.push(cache_data.clone()); - } - Ok(Some(ListValue::new(values).into())) - } -} - -#[async_trait] -impl ManagedTableState for ManagedArrayAggState { - async fn apply_chunk( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], // contains all upstream columns - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - self.apply_chunk_inner(ops, visibility, columns, state_table) - } - - async fn get_output(&mut self, state_table: &StateTable) -> StreamExecutorResult { - self.get_output_inner(state_table).await - } -} - -#[cfg(test)] -mod tests { - use itertools::Itertools; - use risingwave_common::array::StreamChunk; - use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; - use risingwave_common::test_prelude::*; - use risingwave_common::types::{DataType, ScalarImpl}; - use risingwave_common::util::epoch::EpochPair; - use risingwave_common::util::sort_util::OrderPair; - use risingwave_expr::expr::AggKind; - use risingwave_storage::memory::MemoryStateStore; - - use super::*; - use crate::executor::aggregation::AggArgs; - - #[tokio::test] - async fn test_array_agg_state_simple_agg_without_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = vec![3]; - let agg_call = AggCall { - kind: AggKind::ArrayAgg, - args: AggArgs::Unary(DataType::Varchar, 0), // array_agg(a) - return_type: DataType::List { - datatype: Box::new(DataType::Varchar), - }, - order_pairs: vec![], - append_only: false, - filter: None, - }; - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![3, 0]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![OrderType::Ascending], - vec![0], // [_row_id] - ); - - let mut managed_state = ManagedArrayAggState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - let chunk = 
StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 2 128 - - b 5 2 128 - + . 7 6 129 - + c 1 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - epoch.inc(); - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::List(res)) => { - let res = res - .values() - .iter() - .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) - .collect_vec(); - assert_eq!(res.len(), 3); - assert!(res.contains(&Some("a".to_string()))); - assert!(res.contains(&Some("c".to_string()))); - assert!(res.contains(&None)); - } - _ => panic!("unexpected output"), - } - - Ok(()) - } - - #[tokio::test] - async fn test_array_agg_state_simple_agg_with_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - // where `a` is the column to aggregate - - let input_pk_indices = vec![3]; - let agg_call = AggCall { - kind: AggKind::ArrayAgg, - args: AggArgs::Unary(DataType::Int32, 1), // array_agg(b) - return_type: DataType::Int32, - order_pairs: vec![ - OrderPair::new(2, OrderType::Ascending), // c ASC - OrderPair::new(0, OrderType::Descending), // a DESC - ], - append_only: false, - filter: None, - }; - - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(3), DataType::Int32), // b - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 0, 3, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // c ASC - OrderType::Descending, // a DESC - OrderType::Ascending, // _row_id ASC - ], - vec![0, 1, 2], // [c, a, _row_id] - ); - - let mut managed_state = ManagedArrayAggState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 2 128 - - b 5 2 128 - + c 2 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::List(res)) => { - let res = res - .values() - .iter() - .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) - .collect_vec(); - assert_eq!(res, vec![Some(2), Some(1)]); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + d 0 8 134 - + e 2 2 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - 
match res { - Some(ScalarImpl::List(res)) => { - let res = res - .values() - .iter() - .map(|v| v.as_ref().map(ScalarImpl::as_int32).cloned()) - .collect_vec(); - assert_eq!(res, vec![Some(2), Some(2), Some(0), Some(1)]); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_array_agg_state_grouped_agg_with_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = vec![3]; - let agg_call = AggCall { - kind: AggKind::ArrayAgg, - args: AggArgs::Unary(DataType::Varchar, 0), - return_type: DataType::Varchar, - order_pairs: vec![ - OrderPair::new(1, OrderType::Ascending), // b ASC - ], - append_only: false, - filter: None, - }; - - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // group by c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int32), // order by b - ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(3), DataType::Varchar), // a - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 1, 3, 0]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // c ASC - OrderType::Ascending, // b ASC - OrderType::Ascending, // _row_id ASC - ], - vec![0, 1, 2], // [c, b, _row_id] - ); - - let mut managed_state = ManagedArrayAggState::new( - &agg_call, - Some(&Row::new(vec![Some(8.into())])), - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 8 128 - + c 1 3 130 D // hide this row", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::List(res)) => { - let res = res - .values() - .iter() - .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) - .collect_vec(); - assert_eq!(res, vec![Some("a".to_string()), Some("b".to_string())]); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + d 0 2 134 D // hide this row - + e 2 8 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::List(res)) => { - let res = res - .values() - .iter() - .map(|v| v.as_ref().map(ScalarImpl::as_utf8).cloned()) - .collect_vec(); - assert_eq!( - res, - vec![ - Some("a".to_string()), - Some("e".to_string()), - Some("b".to_string()) - ] - ); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } -} diff --git a/src/stream/src/executor/aggregation/table_state/extreme.rs b/src/stream/src/executor/aggregation/table_state/extreme.rs deleted file mode 100644 index 83b27e813445a..0000000000000 --- 
a/src/stream/src/executor/aggregation/table_state/extreme.rs +++ /dev/null @@ -1,1003 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::marker::PhantomData; - -use async_trait::async_trait; -use futures::{pin_mut, StreamExt}; -use futures_async_stream::for_await; -use risingwave_common::array::stream_chunk::{Op, Ops}; -use risingwave_common::array::{ArrayImpl, Row}; -use risingwave_common::buffer::Bitmap; -use risingwave_common::catalog::Schema; -use risingwave_common::types::*; -use risingwave_common::util::ordered::OrderedRowSerde; -use risingwave_common::util::sort_util::OrderType; -use risingwave_expr::expr::AggKind; -use risingwave_storage::table::streaming_table::state_table::StateTable; -use risingwave_storage::StateStore; - -use super::{Cache, ManagedTableState}; -use crate::common::{iter_state_table, StateTableColumnMapping}; -use crate::executor::aggregation::AggCall; -use crate::executor::error::StreamExecutorResult; -use crate::executor::PkIndices; - -/// Memcomparable row. -type CacheKey = Vec; - -/// Generic managed agg state for min/max. -/// It maintains a top N cache internally, using `HashSet`, and the sort key -/// is composed of (agg input value, upstream pk). -pub struct GenericExtremeState { - _phantom_data: PhantomData, - - /// Group key to aggregate with group. - /// None for simple agg, Some for group key of hash agg. - group_key: Option, - - // TODO(yuchao): remove this after we move state table insertion out. - /// Contains the column mapping between upstream schema and state table. - state_table_col_mapping: StateTableColumnMapping, - - // The column to aggregate in input chunk. - upstream_agg_col_idx: usize, - - /// The column to aggregate in state table. - state_table_agg_col_idx: usize, - - /// The columns to order by in state table. - state_table_order_col_indices: Vec, - - /// Number of all items in the state store. - total_count: usize, - - /// Cache for the top N elements in the state. Note that the cache - /// won't store group_key so the column indices should be offsetted - /// by group_key.len(), which is handled by `state_row_to_cache_row`. - cache: Cache, - - /// Whether the cache is synced to state table. The cache is synced iff: - /// - the cache is empty and `total_count` is 0, or - /// - the cache is not empty and elements in it are the top ones in the state table. - cache_synced: bool, - - /// Serializer for cache key. - cache_key_serializer: OrderedRowSerde, -} - -impl GenericExtremeState { - /// Create a managed extreme state. If `cache_capacity` is `None`, the cache will be - /// fully synced, otherwise it will only retain top entries. 
- pub fn new( - agg_call: &AggCall, - group_key: Option<&Row>, - pk_indices: &PkIndices, - col_mapping: StateTableColumnMapping, - row_count: usize, - cache_capacity: usize, - input_schema: &Schema, - ) -> Self { - let upstream_agg_col_idx = agg_call.args.val_indices()[0]; - // map agg column to state table column index - let state_table_agg_col_idx = col_mapping - .upstream_to_state_table(agg_call.args.val_indices()[0]) - .expect("the column to be aggregate must appear in the state table"); - // map order by columns to state table column indices - let (state_table_order_col_indices, state_table_order_types): (Vec<_>, Vec<_>) = - std::iter::once(( - state_table_agg_col_idx, - match agg_call.kind { - AggKind::Min => OrderType::Ascending, - AggKind::Max => OrderType::Descending, - _ => unreachable!(), - }, - )) - .chain(pk_indices.iter().map(|idx| { - ( - col_mapping - .upstream_to_state_table(*idx) - .expect("the pk columns must appear in the state table"), - OrderType::Ascending, - ) - })) - .unzip(); - - // the key written into cache is from the state table, and cache_key_serializer need to know - // its schema(data_types) - let cache_key_data_types = state_table_order_col_indices - .iter() - .map(|i| { - input_schema[col_mapping.upstream_columns()[*i]] - .data_type - .clone() - }) - .collect(); - let cache_key_serializer = - OrderedRowSerde::new(cache_key_data_types, state_table_order_types); - - Self { - _phantom_data: PhantomData, - group_key: group_key.cloned(), - state_table_col_mapping: col_mapping, - upstream_agg_col_idx, - state_table_agg_col_idx, - state_table_order_col_indices, - total_count: row_count, - cache: Cache::new(cache_capacity), - cache_synced: row_count == 0, // if there is no row, the cache is synced initially - cache_key_serializer, - } - } - - fn state_row_to_cache_entry(&self, state_row: &Row) -> (CacheKey, Datum) { - let mut cache_key = Vec::new(); - self.cache_key_serializer.serialize_datums( - self.state_table_order_col_indices - .iter() - .map(|col_idx| &(state_row.0)[*col_idx]), - &mut cache_key, - ); - let cache_data = state_row[self.state_table_agg_col_idx].clone(); - (cache_key, cache_data) - } - - /// Apply a chunk of data to the state. 
- fn apply_chunk_inner( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - for (i, op) in ops - .iter() - .enumerate() - .filter(|(i, _)| visibility.map(|x| x.is_set(*i)).unwrap_or(true)) - .filter(|(i, _)| columns[self.upstream_agg_col_idx].null_bitmap().is_set(*i)) - { - let state_row = Row::new( - self.state_table_col_mapping - .upstream_columns() - .iter() - .map(|col_idx| columns[*col_idx].datum_at(i)) - .collect(), - ); - let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); - match op { - Op::Insert | Op::UpdateInsert => { - if self.cache_synced - && (self.cache.len() == self.total_count - || &cache_key < self.cache.last_key().unwrap()) - { - self.cache.insert(cache_key, cache_data); - } - state_table.insert(state_row); - self.total_count += 1; - } - Op::Delete | Op::UpdateDelete => { - if self.cache_synced { - self.cache.remove(cache_key); - if self.total_count > 1 /* still has rows after deletion */ && self.cache.is_empty() - { - self.cache_synced = false; - } - } - state_table.delete(state_row); - self.total_count -= 1; - } - } - } - - Ok(()) - } - - fn get_output_from_cache(&self) -> Option { - if self.cache_synced { - self.cache.first_value().cloned() - } else { - None - } - } - - async fn get_output_inner( - &mut self, - state_table: &StateTable, - ) -> StreamExecutorResult { - // try to get the result from cache - if let Some(datum) = self.get_output_from_cache() { - Ok(datum) - } else { - // read from state table and fill in the cache - let all_data_iter = iter_state_table(state_table, self.group_key.as_ref()).await?; - pin_mut!(all_data_iter); - - self.cache.clear(); - #[for_await] - for state_row in all_data_iter.take(self.cache.capacity()) { - let state_row = state_row?; - let (cache_key, cache_data) = self.state_row_to_cache_entry(state_row.as_ref()); - self.cache.insert(cache_key, cache_data); - } - self.cache_synced = true; - - // try to get the result from cache again - Ok(self.get_output_from_cache().unwrap_or(None)) - } - } -} - -#[async_trait] -impl ManagedTableState for GenericExtremeState { - async fn apply_chunk( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - self.apply_chunk_inner(ops, visibility, columns, state_table) - } - - async fn get_output(&mut self, state_table: &StateTable) -> StreamExecutorResult { - self.get_output_inner(state_table).await - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use itertools::Itertools; - use rand::prelude::*; - use risingwave_common::array::StreamChunk; - use risingwave_common::catalog::{ColumnDesc, ColumnId, Field, TableId}; - use risingwave_common::test_prelude::*; - use risingwave_common::types::ScalarImpl; - use risingwave_common::util::epoch::EpochPair; - use risingwave_common::util::sort_util::OrderType; - use risingwave_storage::memory::MemoryStateStore; - - use super::*; - use crate::executor::aggregation::AggArgs; - - fn create_agg_call(kind: AggKind, arg_type: DataType, arg_idx: usize) -> AggCall { - AggCall { - kind, - args: AggArgs::Unary(arg_type.clone(), arg_idx), - return_type: arg_type, - order_pairs: vec![], - append_only: false, - filter: None, - } - } - - #[tokio::test] - async fn test_extreme_agg_state_basic_min() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = 
vec![3]; // _row_id - let field1 = Field::unnamed(DataType::Int32); - let field2 = Field::unnamed(DataType::Int32); - let field3 = Field::unnamed(DataType::Int32); - let field4 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2, field3, field4]); - - let agg_call = create_agg_call(AggKind::Min, DataType::Int32, 2); // min(c) - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(0x2333); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 3]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // for AggKind::Min - OrderType::Ascending, - ], - vec![0, 1], // [c, _row_id] - ); - - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping.clone(), - 0, - usize::MAX, - &input_schema, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 2 128 - - b 5 2 128 - + c 1 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 3); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + d 0 8 134 - + e 2 2 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 2); - } - _ => panic!("unexpected output"), - } - } - - { - // test recovery (cold start) - let row_count = managed_state.total_count; - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - row_count, - usize::MAX, - &input_schema, - ); - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 2); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_extreme_agg_state_basic_max() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = vec![3]; // _row_id - let field1 = Field::unnamed(DataType::Int32); - let field2 = Field::unnamed(DataType::Int32); - let field3 = Field::unnamed(DataType::Int32); - let field4 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2, field3, field4]); - let agg_call = create_agg_call(AggKind::Max, DataType::Int32, 2); // max(c) - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(0x2333); - let columns = vec![ - 
ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 3]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Descending, // for AggKind::Max - OrderType::Ascending, - ], - vec![0, 1], // [c, _row_id] - ); - - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping.clone(), - 0, - usize::MAX, - &input_schema, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 2 128 - - b 5 2 128 - + c 1 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 8); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + d 0 9 134 - + e 2 2 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 9); - } - _ => panic!("unexpected output"), - } - } - - { - // test recovery (cold start) - let row_count = managed_state.total_count; - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - row_count, - usize::MAX, - &input_schema, - ); - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 9); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_extreme_agg_state_with_null_value() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = vec![3]; // _row_id - let field1 = Field::unnamed(DataType::Int32); - let field2 = Field::unnamed(DataType::Int32); - let field3 = Field::unnamed(DataType::Int32); - let field4 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2, field3, field4]); - let agg_call_1 = create_agg_call(AggKind::Min, DataType::Varchar, 0); // min(a) - let agg_call_2 = create_agg_call(AggKind::Max, DataType::Varchar, 1); // max(b) - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(0x6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping_1 = StateTableColumnMapping::new(vec![0, 3]); - let mut state_table_1 = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // for AggKind::Min - OrderType::Ascending, - ], - vec![0, 1], // [b, 
_row_id] - ); - let table_id = TableId::new(0x2333); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // b - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping_2 = StateTableColumnMapping::new(vec![1, 3]); - let mut state_table_2 = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Descending, // for AggKind::Max - OrderType::Ascending, - ], - vec![0, 1], // [b, _row_id] - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table_1.init_epoch(epoch); - state_table_2.init_epoch(epoch); - epoch.inc(); - - let mut managed_state_1 = GenericExtremeState::new( - &agg_call_1, - None, - &input_pk_indices, - state_table_col_mapping_1, - 0, - usize::MAX, - &input_schema, - ); - let mut managed_state_2 = GenericExtremeState::new( - &agg_call_2, - None, - &input_pk_indices, - state_table_col_mapping_2, - 0, - usize::MAX, - &input_schema, - ); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 2 128 - - b 5 2 128 - + c 1 3 130 - + . 9 4 131 - + . 6 5 132 - + c . 3 133", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state_1 - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table_1) - .await?; - managed_state_2 - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table_2) - .await?; - - state_table_1.commit_for_test(epoch).await.unwrap(); - state_table_2.commit_for_test(epoch).await.unwrap(); - - match managed_state_1.get_output(&state_table_1).await? { - Some(ScalarImpl::Utf8(s)) => { - assert_eq!(&s, "a"); - } - _ => panic!("unexpected output"), - } - match managed_state_2.get_output(&state_table_2).await? 
{ - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 9); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_extreme_agg_state_grouped() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, b: int32, c: int32, _row_id: int64) - - let input_pk_indices = vec![3]; - let field1 = Field::unnamed(DataType::Int32); - let field2 = Field::unnamed(DataType::Int32); - let field3 = Field::unnamed(DataType::Int32); - let field4 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2, field3, field4]); - let agg_call = create_agg_call(AggKind::Max, DataType::Int32, 1); // max(b) - - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // group by c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int32), // b - ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 1, 3]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // c ASC - OrderType::Descending, // b DESC for AggKind::Max - OrderType::Ascending, // _row_id ASC - ], - vec![0, 1, 2], // [c, b, _row_id] - ); - let group_key = Row::new(vec![Some(8.into())]); - - let mut managed_state = GenericExtremeState::new( - &agg_call, - Some(&group_key), - &input_pk_indices, - state_table_col_mapping.clone(), - 0, - usize::MAX, - &input_schema, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + a 1 8 123 - + b 5 8 128 - + c 7 3 130 D // hide this row", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 5); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T i i I - + d 9 2 134 D // hide this row - + e 8 8 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 8); - } - _ => panic!("unexpected output"), - } - } - - { - // test recovery (cold start) - let row_count = managed_state.total_count; - let mut managed_state = GenericExtremeState::new( - &agg_call, - Some(&group_key), - &input_pk_indices, - state_table_col_mapping, - row_count, - usize::MAX, - &input_schema, - ); - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 8); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_extreme_agg_state_with_random_values() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: int32, _row_id: int64) - - let input_pk_indices = vec![1]; // _row_id - let field1 = Field::unnamed(DataType::Int32); - 
let field2 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2]); - let agg_call = create_agg_call(AggKind::Min, DataType::Int32, 0); // min(a) - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(0x2333); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // a - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![0, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // for AggKind::Min - OrderType::Ascending, - ], - vec![0, 1], // [a, _row_id] - ); - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - 1024, - &input_schema, - ); - - let mut rng = thread_rng(); - let insert_values: Vec = (0..10000).map(|_| rng.gen()).collect_vec(); - let delete_values: HashSet<_> = insert_values - .iter() - .choose_multiple(&mut rng, 1000) - .into_iter() - .collect(); - let mut min_value = i32::MAX; - - { - let mut pretty_lines = vec!["i I".to_string()]; - for (row_id, value) in insert_values - .iter() - .enumerate() - .take(insert_values.len() / 2) - { - pretty_lines.push(format!("+ {} {}", value, row_id)); - if delete_values.contains(&value) { - pretty_lines.push(format!("- {} {}", value, row_id)); - continue; - } - if *value < min_value { - min_value = *value; - } - } - - let chunk = StreamChunk::from_pretty(&pretty_lines.join("\n")); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, min_value); - } - _ => panic!("unexpected output"), - } - } - - { - let mut pretty_lines = vec!["i I".to_string()]; - for (row_id, value) in insert_values - .iter() - .enumerate() - .skip(insert_values.len() / 2) - { - pretty_lines.push(format!("+ {} {}", value, row_id)); - if delete_values.contains(&value) { - pretty_lines.push(format!("- {} {}", value, row_id)); - continue; - } - if *value < min_value { - min_value = *value; - } - } - - let chunk = StreamChunk::from_pretty(&pretty_lines.join("\n")); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, min_value); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_extreme_agg_state_cache_maintenance() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: int32, _row_id: int64) - - let input_pk_indices = vec![1]; // _row_id - let field1 = Field::unnamed(DataType::Int32); - let field2 = Field::unnamed(DataType::Int64); - let input_schema = Schema::new(vec![field1, field2]); - let agg_call = 
create_agg_call(AggKind::Min, DataType::Int32, 0); // min(a) - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(0x2333); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // a - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int64), // _row_id - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![0, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // for AggKind::Min - OrderType::Ascending, - ], - vec![0, 1], // [a, _row_id] - ); - - let mut managed_state = GenericExtremeState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - 3, // cache capacity = 3 for easy testing - &input_schema, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - " i I - + 4 123 - + 8 128 - + 12 129", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 4); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " i I - + 9 130 // this will evict 12 - - 9 130 - + 13 128 - - 4 123 - - 8 128", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 12); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " i I - + 1 131 - + 2 132 - + 3 133 // evict all from cache - - 1 131 - - 2 132 - - 3 133 - + 14 134", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Int32(s)) => { - assert_eq!(s, 12); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } -} diff --git a/src/stream/src/executor/aggregation/table_state/string_agg.rs b/src/stream/src/executor/aggregation/table_state/string_agg.rs deleted file mode 100644 index 0db3635587a41..0000000000000 --- a/src/stream/src/executor/aggregation/table_state/string_agg.rs +++ /dev/null @@ -1,603 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::marker::PhantomData; - -use async_trait::async_trait; -use futures::pin_mut; -use futures_async_stream::for_await; -use risingwave_common::array::stream_chunk::Ops; -use risingwave_common::array::Op::{Delete, Insert, UpdateDelete, UpdateInsert}; -use risingwave_common::array::{ArrayImpl, Row}; -use risingwave_common::buffer::Bitmap; -use risingwave_common::types::{Datum, ScalarImpl}; -use risingwave_common::util::ordered::OrderedRow; -use risingwave_common::util::sort_util::OrderType; -use risingwave_storage::table::streaming_table::state_table::StateTable; -use risingwave_storage::StateStore; - -use super::{Cache, ManagedTableState}; -use crate::common::{iter_state_table, StateTableColumnMapping}; -use crate::executor::aggregation::AggCall; -use crate::executor::error::StreamExecutorResult; -use crate::executor::PkIndices; - -#[derive(Clone)] -struct StringAggData { - delim: String, - value: String, -} - -pub struct ManagedStringAggState { - _phantom_data: PhantomData, - - /// Group key to aggregate with group. - /// None for simple agg, Some for group key of hash agg. - group_key: Option, - - // TODO(yuchao): remove this after we move state table insertion out. - /// Contains the column mapping between upstream schema and state table. - state_table_col_mapping: StateTableColumnMapping, - - // The column to aggregate in input chunk. - upstream_agg_col_idx: usize, - - /// The column to aggregate in state table. - state_table_agg_col_idx: usize, - - /// The column as delimiter in state table. - state_table_delim_col_idx: usize, - - /// The columns to order by in state table. - state_table_order_col_indices: Vec, - - /// The order types of `state_table_order_col_indices`. - state_table_order_types: Vec, - - /// In-memory all-or-nothing cache. - cache: Cache, - - /// Whether the cache is fully synced to state table. 
- cache_synced: bool, -} - -impl ManagedStringAggState { - pub fn new( - agg_call: &AggCall, - group_key: Option<&Row>, - pk_indices: &PkIndices, - col_mapping: StateTableColumnMapping, - row_count: usize, - ) -> Self { - let upstream_agg_col_idx = agg_call.args.val_indices()[0]; - // map agg column to state table column index - let state_table_agg_col_idx = col_mapping - .upstream_to_state_table(agg_call.args.val_indices()[0]) - .expect("the column to be aggregate must appear in the state table"); - let state_table_delim_col_idx = col_mapping - .upstream_to_state_table(agg_call.args.val_indices()[1]) - .expect("the column as delimiter must appear in the state table"); - // map order by columns to state table column indices - let (state_table_order_col_indices, state_table_order_types) = agg_call - .order_pairs - .iter() - .map(|o| { - ( - col_mapping - .upstream_to_state_table(o.column_idx) - .expect("the column to be order by must appear in the state table"), - o.order_type, - ) - }) - .chain(pk_indices.iter().map(|idx| { - ( - col_mapping - .upstream_to_state_table(*idx) - .expect("the pk columns must appear in the state table"), - OrderType::Ascending, - ) - })) - .unzip(); - Self { - _phantom_data: PhantomData, - group_key: group_key.cloned(), - state_table_col_mapping: col_mapping, - upstream_agg_col_idx, - state_table_agg_col_idx, - state_table_delim_col_idx, - state_table_order_col_indices, - state_table_order_types, - cache: Cache::new(usize::MAX), - cache_synced: row_count == 0, // if there is no row, the cache is synced initially - } - } - - fn state_row_to_cache_entry(&self, state_row: &Row) -> (OrderedRow, StringAggData) { - let cache_key = OrderedRow::new( - state_row.by_indices(&self.state_table_order_col_indices), - &self.state_table_order_types, - ); - let cache_data = StringAggData { - delim: state_row[self.state_table_delim_col_idx] - .clone() - .map(ScalarImpl::into_utf8) - .unwrap_or_default(), - value: state_row[self.state_table_agg_col_idx] - .clone() - .map(ScalarImpl::into_utf8) - .expect("NULL values should be filtered out"), - }; - (cache_key, cache_data) - } - - fn apply_chunk_inner( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - for (i, op) in ops - .iter() - .enumerate() - // skip invisible - .filter(|(i, _)| visibility.map(|x| x.is_set(*i)).unwrap_or(true)) - // skip null input - .filter(|(i, _)| columns[self.upstream_agg_col_idx].datum_at(*i).is_some()) - { - let state_row = Row::new( - self.state_table_col_mapping - .upstream_columns() - .iter() - .map(|col_idx| columns[*col_idx].datum_at(i)) - .collect(), - ); - let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); - - match op { - Insert | UpdateInsert => { - if self.cache_synced { - self.cache.insert(cache_key, cache_data); - } - state_table.insert(state_row); - } - Delete | UpdateDelete => { - if self.cache_synced { - self.cache.remove(cache_key); - } - state_table.delete(state_row); - } - } - } - - Ok(()) - } - - async fn get_output_inner( - &mut self, - state_table: &StateTable, - ) -> StreamExecutorResult { - if !self.cache_synced { - let all_data_iter = iter_state_table(state_table, self.group_key.as_ref()).await?; - pin_mut!(all_data_iter); - - self.cache.clear(); - #[for_await] - for state_row in all_data_iter { - let state_row = state_row?; - let (cache_key, cache_data) = self.state_row_to_cache_entry(&state_row); - self.cache.insert(cache_key, cache_data.clone()); - } - 
self.cache_synced = true; - } - - let mut result = match self.cache.first_value() { - Some(data) => data.value.clone(), - None => return Ok(None), // return NULL if no rows to aggregate - }; - for StringAggData { value, delim } in self.cache.iter_values().skip(1) { - result.push_str(delim); - result.push_str(value); - } - Ok(Some(result.into())) - } -} - -#[async_trait] -impl ManagedTableState for ManagedStringAggState { - async fn apply_chunk( - &mut self, - ops: Ops<'_>, - visibility: Option<&Bitmap>, - columns: &[&ArrayImpl], // contains all upstream columns - state_table: &mut StateTable, - ) -> StreamExecutorResult<()> { - self.apply_chunk_inner(ops, visibility, columns, state_table) - } - - async fn get_output(&mut self, state_table: &StateTable) -> StreamExecutorResult { - self.get_output_inner(state_table).await - } -} - -#[cfg(test)] -mod tests { - use risingwave_common::array::StreamChunk; - use risingwave_common::catalog::{ColumnDesc, ColumnId, TableId}; - use risingwave_common::test_prelude::*; - use risingwave_common::types::DataType; - use risingwave_common::util::epoch::EpochPair; - use risingwave_common::util::sort_util::OrderPair; - use risingwave_expr::expr::AggKind; - use risingwave_storage::memory::MemoryStateStore; - - use super::*; - use crate::executor::aggregation::AggArgs; - - #[tokio::test] - async fn test_string_agg_state_simple_agg_without_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64) - // where `a` is the column to aggregate - - let input_pk_indices = vec![4]; - let agg_call = AggCall { - kind: AggKind::StringAgg, - args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]), - return_type: DataType::Varchar, - order_pairs: vec![], - append_only: false, - filter: None, - }; - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(2), DataType::Varchar), // _delim - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![4, 0, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![OrderType::Ascending], - vec![0], // [_row_id] - ); - - let mut managed_state = ManagedStringAggState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - let chunk = StreamChunk::from_pretty( - " T T i i I - + a , 1 8 123 - + b , 5 2 128 - - b , 5 2 128 - + c , 1 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - epoch.inc(); - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - // should be "a,c" or "c,a" - assert_eq!(s.len(), 3); - assert!(s.contains('a')); - assert!(s.contains('c')); - assert_eq!(&s[1..2], ","); - } - _ => panic!("unexpected output"), - } - - Ok(()) - } - - #[tokio::test] - async fn test_string_agg_state_simple_agg_null_value() -> StreamExecutorResult<()> { - // Assumption of input schema: - // 
(a: varchar, _delim: varchar, _row_id: int64) - // where `a` is the column to aggregate - - let input_pk_indices = vec![2]; - let agg_call = AggCall { - kind: AggKind::StringAgg, - args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]), - return_type: DataType::Varchar, - order_pairs: vec![], - append_only: false, - filter: None, - }; - - // see `LogicalAgg::infer_stream_agg_state` for the construction of state table - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(2), DataType::Varchar), // _delim - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 0, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![OrderType::Ascending], - vec![0], // [_row_id] - ); - - let mut managed_state = ManagedStringAggState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - let chunk = StreamChunk::from_pretty( - " T T I - + a 1 123 - + . 2 128 - + c . 129 - + d 4 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - epoch.inc(); - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - // should be something like "ac4d" - assert_eq!(s.len(), 4); - assert!(s.contains('a')); - assert!(s.contains('c')); - assert!(s.contains('d')); - assert!(s.contains('4')); - assert!(!s.contains('2')); - } - _ => panic!("unexpected output"), - } - - Ok(()) - } - - #[tokio::test] - async fn test_string_agg_state_simple_agg_with_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64) - // where `a` is the column to aggregate - - let input_pk_indices = vec![4]; - let agg_call = AggCall { - kind: AggKind::StringAgg, - args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]), - return_type: DataType::Varchar, - order_pairs: vec![ - OrderPair::new(2, OrderType::Ascending), // b ASC - OrderPair::new(0, OrderType::Descending), // a DESC - ], - append_only: false, - filter: None, - }; - - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // b - ColumnDesc::unnamed(ColumnId::new(1), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(3), DataType::Varchar), // _delim - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![2, 0, 4, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // b ASC - OrderType::Descending, // a DESC - OrderType::Ascending, // _row_id ASC - ], - vec![0, 1, 2], // [b, a, _row_id] - ); - - let mut managed_state = ManagedStringAggState::new( - &agg_call, - None, - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - - { - let chunk = StreamChunk::from_pretty( - 
" T T i i I - + a , 1 8 123 - + b / 5 2 128 - - b / 5 2 128 - + c _ 1 3 130", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - assert_eq!(s, "c,a".to_string()); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T T i i I - + d - 0 8 134 - + e + 2 2 137", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - assert_eq!(s, "d_c,a+e".to_string()); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } - - #[tokio::test] - async fn test_string_agg_state_grouped_agg_with_order() -> StreamExecutorResult<()> { - // Assumption of input schema: - // (a: varchar, _delim: varchar, b: int32, c: int32, _row_id: int64) - // where `a` is the column to aggregate - - let input_pk_indices = vec![4]; - let agg_call = AggCall { - kind: AggKind::StringAgg, - args: AggArgs::Binary([DataType::Varchar, DataType::Varchar], [0, 1]), - return_type: DataType::Varchar, - order_pairs: vec![ - OrderPair::new(2, OrderType::Ascending), // b ASC - ], - append_only: false, - filter: None, - }; - - let table_id = TableId::new(6666); - let columns = vec![ - ColumnDesc::unnamed(ColumnId::new(0), DataType::Int32), // group by c - ColumnDesc::unnamed(ColumnId::new(1), DataType::Int32), // order by b - ColumnDesc::unnamed(ColumnId::new(2), DataType::Int64), // _row_id - ColumnDesc::unnamed(ColumnId::new(3), DataType::Varchar), // a - ColumnDesc::unnamed(ColumnId::new(4), DataType::Varchar), // _delim - ]; - let state_table_col_mapping = StateTableColumnMapping::new(vec![3, 2, 4, 0, 1]); - let mut state_table = StateTable::new_without_distribution( - MemoryStateStore::new(), - table_id, - columns, - vec![ - OrderType::Ascending, // c ASC - OrderType::Ascending, // b ASC - OrderType::Ascending, // _row_id ASC - ], - vec![0, 1, 2], // [c, b, _row_id] - ); - - let mut managed_state = ManagedStringAggState::new( - &agg_call, - Some(&Row::new(vec![Some(8.into())])), - &input_pk_indices, - state_table_col_mapping, - 0, - ); - - let epoch = EpochPair::new_test_epoch(1); - state_table.init_epoch(epoch); - epoch.inc(); - { - let chunk = StreamChunk::from_pretty( - " T T i i I - + a _ 1 8 123 - + b _ 5 8 128 - + c _ 1 3 130 D // hide this row", - ); - let (ops, columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - epoch.inc(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - assert_eq!(s, "a_b".to_string()); - } - _ => panic!("unexpected output"), - } - } - - { - let chunk = StreamChunk::from_pretty( - " T T i i I - + d , 0 2 134 D // hide this row - + e , 2 8 137", - ); - let (ops, 
columns, visibility) = chunk.into_inner(); - let column_refs: Vec<_> = columns.iter().map(|col| col.array_ref()).collect(); - managed_state - .apply_chunk(&ops, visibility.as_ref(), &column_refs, &mut state_table) - .await?; - - state_table.commit_for_test(epoch).await.unwrap(); - - let res = managed_state.get_output(&state_table).await?; - match res { - Some(ScalarImpl::Utf8(s)) => { - assert_eq!(s, "a,e_b".to_string()); - } - _ => panic!("unexpected output"), - } - } - - Ok(()) - } -} diff --git a/src/stream/src/executor/dispatch.rs b/src/stream/src/executor/dispatch.rs index b8b99409d88f4..607b2afb49e8e 100644 --- a/src/stream/src/executor/dispatch.rs +++ b/src/stream/src/executor/dispatch.rs @@ -266,7 +266,7 @@ impl StreamConsumer for DispatchExecutor { let barrier = msg.as_barrier().cloned(); self.inner .dispatch(msg) - .stack_trace(if barrier.is_some() { + .verbose_stack_trace(if barrier.is_some() { "dispatch_barrier" } else { "dispatch_chunk" diff --git a/src/stream/src/executor/dynamic_filter.rs b/src/stream/src/executor/dynamic_filter.rs index 9eaa19112bfb3..5a254cdcfa1d9 100644 --- a/src/stream/src/executor/dynamic_filter.rs +++ b/src/stream/src/executor/dynamic_filter.rs @@ -237,6 +237,8 @@ impl DynamicFilterExecutor { async fn into_stream(mut self) { let input_l = self.source_l.take().unwrap(); let input_r = self.source_r.take().unwrap(); + + let left_len = input_l.schema().len(); // Derive the dynamic expression let l_data_type = input_l.schema().data_types()[self.key_l].clone(); let r_data_type = input_r.schema().data_types()[0].clone(); @@ -271,8 +273,14 @@ impl DynamicFilterExecutor { // The first barrier message should be propagated. yield Message::Barrier(barrier); - let mut stream_chunk_builder = - StreamChunkBuilder::new(self.chunk_size, &self.schema.data_types(), 0, 0)?; + let (left_to_output, _) = + StreamChunkBuilder::get_i2o_mapping(0..self.schema.len(), left_len, 0); + let mut stream_chunk_builder = StreamChunkBuilder::new( + self.chunk_size, + &self.schema.data_types(), + vec![], + left_to_output, + )?; #[for_await] for msg in aligned_stream { diff --git a/src/stream/src/executor/exchange/input.rs b/src/stream/src/executor/exchange/input.rs index 6970e217a4a92..40a7170ba5853 100644 --- a/src/stream/src/executor/exchange/input.rs +++ b/src/stream/src/executor/exchange/input.rs @@ -82,7 +82,7 @@ impl LocalInput { #[try_stream(ok = Message, error = StreamExecutorError)] async fn run(mut channel: Receiver, actor_id: ActorId) { let span: SpanValue = format!("LocalInput (actor {actor_id})").into(); - while let Some(msg) = channel.recv().stack_trace(span.clone()).await { + while let Some(msg) = channel.recv().verbose_stack_trace(span.clone()).await { yield msg; } } @@ -161,7 +161,7 @@ impl RemoteInput { let span: SpanValue = format!("RemoteInput (actor {up_actor_id})").into(); pin_mut!(stream); - while let Some(data_res) = stream.next().stack_trace(span.clone()).await { + while let Some(data_res) = stream.next().verbose_stack_trace(span.clone()).await { match data_res { Ok(stream_msg) => { let bytes = Message::get_encoded_len(&stream_msg); diff --git a/src/stream/src/executor/exchange/output.rs b/src/stream/src/executor/exchange/output.rs index 3368505967a47..5e1f445ecca1c 100644 --- a/src/stream/src/executor/exchange/output.rs +++ b/src/stream/src/executor/exchange/output.rs @@ -75,7 +75,7 @@ impl Output for LocalOutput { async fn send(&mut self, message: Message) -> StreamResult<()> { self.ch .send(message) - .stack_trace(self.span.clone()) + 
.verbose_stack_trace(self.span.clone()) .await .map_err(|SendError(message)| { anyhow!( @@ -133,7 +133,7 @@ impl Output for RemoteOutput { self.ch .send(message) - .stack_trace(self.span.clone()) + .verbose_stack_trace(self.span.clone()) .await .map_err(|SendError(message)| { anyhow!( diff --git a/src/stream/src/executor/global_simple_agg.rs b/src/stream/src/executor/global_simple_agg.rs index 810020a07f08b..0c42a1b3e261c 100644 --- a/src/stream/src/executor/global_simple_agg.rs +++ b/src/stream/src/executor/global_simple_agg.rs @@ -160,7 +160,7 @@ impl GlobalSimpleAggExecutor { let capacity = chunk.capacity(); let (ops, columns, visibility) = chunk.into_inner(); - // Apply chunk to each of the state (per agg_call) + // Calculate the row visibility for every agg call. let visibilities: Vec<_> = agg_calls .iter() .map(|agg_call| { @@ -174,9 +174,28 @@ impl GlobalSimpleAggExecutor { ) }) .try_collect()?; - agg_group - .apply_chunk(storages, &ops, &columns, &visibilities) - .await?; + + // Materialize input chunk if needed. + storages + .iter_mut() + .zip_eq(visibilities.iter().map(Option::as_ref)) + .for_each(|(storage, visibility)| { + if let AggStateStorage::MaterializedInput { table, mapping } = storage { + let needed_columns = mapping + .upstream_columns() + .iter() + .map(|col_idx| columns[*col_idx].clone()) + .collect(); + table.write_chunk(StreamChunk::new( + ops.clone(), + needed_columns, + visibility.cloned(), + )); + } + }); + + // Apply chunk to each of the state (per agg_call) + agg_group.apply_chunk(storages, &ops, &columns, visibilities)?; Ok(()) } diff --git a/src/stream/src/executor/hash_agg.rs b/src/stream/src/executor/hash_agg.rs index 6de33d502bd25..1e5ff49d1858a 100644 --- a/src/stream/src/executor/hash_agg.rs +++ b/src/stream/src/executor/hash_agg.rs @@ -315,20 +315,52 @@ impl HashAggExecutor { // Decompose the input chunk. let capacity = chunk.capacity(); - let (ops, columns, _) = chunk.into_inner(); + let (ops, columns, visibility) = chunk.into_inner(); + + // Calculate the row visibility for every agg call. + let visibilities: Vec<_> = agg_calls + .iter() + .map(|agg_call| { + agg_call_filter_res( + ctx, + identity, + agg_call, + &columns, + visibility.as_ref(), + capacity, + ) + }) + .try_collect()?; + + // Materialize input chunk if needed. + storages + .iter_mut() + .zip_eq(visibilities.iter().map(Option::as_ref)) + .for_each(|(storage, visibility)| { + if let AggStateStorage::MaterializedInput { table, mapping } = storage { + let needed_columns = mapping + .upstream_columns() + .iter() + .map(|col_idx| columns[*col_idx].clone()) + .collect(); + table.write_chunk(StreamChunk::new( + ops.clone(), + needed_columns, + visibility.cloned(), + )); + } + }); // Apply chunk to each of the state (per agg_call), for each group. 
for (key, _, vis_map) in &unique_keys { let agg_group = agg_groups.get_mut(key).unwrap().as_mut().unwrap(); - let visibilities: Vec<_> = agg_calls + let visibilities = visibilities .iter() - .map(|agg_call| { - agg_call_filter_res(ctx, identity, agg_call, &columns, Some(vis_map), capacity) - }) - .try_collect()?; - agg_group - .apply_chunk(storages, &ops, &columns, &visibilities) - .await?; + .map(Option::as_ref) + .map(|v| v.map_or_else(|| vis_map.clone(), |v| v & vis_map)) + .map(Some) + .collect(); + agg_group.apply_chunk(storages, &ops, &columns, visibilities)?; } Ok(()) @@ -405,6 +437,11 @@ impl HashAggExecutor { storages, ) .await?; + + if n_appended_ops == 0 { + continue; + } + for _ in 0..n_appended_ops { key.clone().deserialize_to_builders( &mut builders[..group_key_indices.len()], diff --git a/src/stream/src/executor/hash_join.rs b/src/stream/src/executor/hash_join.rs index 1c6c82d6a2137..d8e7ef7d7b67f 100644 --- a/src/stream/src/executor/hash_join.rs +++ b/src/stream/src/executor/hash_join.rs @@ -107,8 +107,12 @@ const fn is_anti(join_type: JoinTypePrimitive) -> bool { join_type == JoinType::LeftAnti || join_type == JoinType::RightAnti } -const fn is_semi_or_anti(join_type: JoinTypePrimitive) -> bool { - is_semi(join_type) || is_anti(join_type) +const fn is_left_semi_or_anti(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::LeftSemi || join_type == JoinType::LeftAnti +} + +const fn is_right_semi_or_anti(join_type: JoinTypePrimitive) -> bool { + join_type == JoinType::RightSemi || join_type == JoinType::RightAnti } const fn need_update_side_matched_degree( @@ -166,6 +170,8 @@ struct JoinSide { all_data_types: Vec, /// The start position for the side in output new columns start_pos: usize, + /// The mapping from input indices of a side to output columes. + i2o_mapping: Vec<(usize, usize)>, } impl std::fmt::Debug for JoinSide { @@ -175,6 +181,7 @@ impl std::fmt::Debug for JoinSide { .field("pk_indices", &self.pk_indices) .field("col_types", &self.all_data_types) .field("start_pos", &self.start_pos) + .field("i2o_mapping", &self.i2o_mapping) .finish() } } @@ -212,9 +219,7 @@ pub struct HashJoinExecutor, /// The data types of the formed new columns - output_data_types: Vec, - /// The output indices of the join executor - output_indices: Vec, + actual_output_data_types: Vec, /// The schema of the hash join executor schema: Schema, /// The primary key indices of the schema @@ -252,7 +257,7 @@ impl std::fmt::Debug .field("side_r", &self.side_r) .field("pk_indices", &self.pk_indices) .field("schema", &self.schema) - .field("output_data_types", &self.output_data_types) + .field("actual_output_data_types", &self.actual_output_data_types) .finish() } } @@ -446,10 +451,14 @@ impl HashJoinExecutor = schema_fields + let original_output_data_types = schema_fields .iter() .map(|field| field.data_type.clone()) - .collect(); + .collect_vec(); + let actual_output_data_types = output_indices + .iter() + .map(|&idx| original_output_data_types[idx].clone()) + .collect_vec(); // Data types of of hash join state. 
let state_all_data_types_l = input_l @@ -540,11 +549,22 @@ impl HashJoinExecutor HashJoinExecutor HashJoinExecutor HashJoinExecutor { - Message::Chunk(chunk.reorder_columns(&self.output_indices)) - } + Message::Chunk(chunk) => Message::Chunk(chunk), barrier @ Message::Barrier(_) => barrier, })?; } @@ -660,16 +679,14 @@ impl HashJoinExecutor { - Message::Chunk(chunk.reorder_columns(&self.output_indices)) - } + Message::Chunk(chunk) => Message::Chunk(chunk), barrier @ Message::Barrier(_) => barrier, })?; } @@ -764,7 +781,7 @@ impl HashJoinExecutor, mut side_r: &'a mut JoinSide, - output_data_types: &'a [DataType], + actual_output_data_types: &'a [DataType], cond: &'a mut Option, chunk: StreamChunk, append_only_optimize: bool, @@ -778,20 +795,13 @@ impl HashJoinExecutor { - stream_chunk_builder: StreamChunkBuilder::new( - chunk_size, - output_data_types, - update_start_pos, - matched_start_pos, - )?, - } + let mut hashjoin_chunk_builder = HashJoinChunkBuilder:: { + stream_chunk_builder: StreamChunkBuilder::new( + chunk_size, + actual_output_data_types, + side_update.i2o_mapping.clone(), + side_match.i2o_mapping.clone(), + )?, }; let mut check_join_condition = diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs index a8aeaf90a706c..90f06608d5de6 100644 --- a/src/stream/src/executor/integration_tests.rs +++ b/src/stream/src/executor/integration_tests.rs @@ -83,7 +83,6 @@ async fn test_merger_sum_aggr() { let actor = Actor::new( consumer, vec![], - 0, context, StreamingMetrics::unused().into(), actor_ctx.clone(), @@ -129,7 +128,6 @@ async fn test_merger_sum_aggr() { let actor = Actor::new( dispatcher, vec![], - 0, context, StreamingMetrics::unused().into(), actor_ctx.clone(), @@ -187,7 +185,6 @@ async fn test_merger_sum_aggr() { let actor = Actor::new( consumer, vec![], - 0, context, StreamingMetrics::unused().into(), actor_ctx.clone(), diff --git a/src/stream/src/executor/lookup/impl_.rs b/src/stream/src/executor/lookup/impl_.rs index b6b7e5f2b769e..f8454959843e6 100644 --- a/src/stream/src/executor/lookup/impl_.rs +++ b/src/stream/src/executor/lookup/impl_.rs @@ -248,6 +248,12 @@ impl LookupExecutor { .boxed() }; + let (stream_to_output, arrange_to_output) = StreamChunkBuilder::get_i2o_mapping( + self.column_mapping.iter().cloned(), + self.stream.col_types.len(), + self.arrangement.col_types.len(), + ); + #[for_await] for msg in input { let msg = msg?; @@ -297,8 +303,8 @@ impl LookupExecutor { let mut builder = StreamChunkBuilder::new( self.chunk_size, &self.chunk_data_types, - 0, - self.stream.col_types.len(), + stream_to_output.clone(), + arrange_to_output.clone(), )?; for (op, row) in ops.iter().zip_eq(chunk.rows()) { @@ -306,14 +312,14 @@ impl LookupExecutor { tracing::trace!(target: "events::stream::lookup::put", "{:?} {:?}", row, matched_row); if let Some(chunk) = builder.append_row(*op, &row, &matched_row)? { - yield Message::Chunk(chunk.reorder_columns(&self.column_mapping)); + yield Message::Chunk(chunk); } } // TODO: support outer join (return null if no rows are matched) } if let Some(chunk) = builder.take()? 
{ - yield Message::Chunk(chunk.reorder_columns(&self.column_mapping)); + yield Message::Chunk(chunk); } } } diff --git a/src/stream/src/executor/mview/materialize.rs b/src/stream/src/executor/mview/materialize.rs index 613e2bc9674c6..351a2386056bb 100644 --- a/src/stream/src/executor/mview/materialize.rs +++ b/src/stream/src/executor/mview/materialize.rs @@ -26,8 +26,8 @@ use risingwave_storage::StateStore; use crate::executor::error::StreamExecutorError; use crate::executor::{ - expect_first_barrier, ActorContextRef, BoxedExecutor, BoxedMessageStream, Executor, - ExecutorInfo, Message, PkIndicesRef, + expect_first_barrier, ActorContext, ActorContextRef, BoxedExecutor, BoxedMessageStream, + Executor, ExecutorInfo, Message, PkIndicesRef, }; /// `MaterializeExecutor` materializes changes in stream into a materialized view on storage. @@ -42,12 +42,15 @@ pub struct MaterializeExecutor { actor_context: ActorContextRef, info: ExecutorInfo, + + _ignore_on_conflict: bool, } impl MaterializeExecutor { /// Create a new `MaterializeExecutor` with distribution specified with `distribution_keys` and /// `vnodes`. For singleton distribution, `distribution_keys` should be empty and `vnodes` /// should be `None`. + #[allow(clippy::too_many_arguments)] pub fn new( input: BoxedExecutor, store: S, @@ -56,6 +59,7 @@ impl MaterializeExecutor { actor_context: ActorContextRef, vnodes: Option>, table_catalog: &Table, + _ignore_on_conflict: bool, ) -> Self { let arrange_columns: Vec = key.iter().map(|k| k.column_idx).collect(); @@ -73,6 +77,7 @@ impl MaterializeExecutor { pk_indices: arrange_columns, identity: format!("MaterializeExecutor {:X}", executor_id), }, + _ignore_on_conflict, } } @@ -105,12 +110,13 @@ impl MaterializeExecutor { input, state_table, arrange_columns: arrange_columns.clone(), - actor_context: Default::default(), + actor_context: ActorContext::create(0), info: ExecutorInfo { schema, pk_indices: arrange_columns, identity: format!("MaterializeExecutor {:X}", executor_id), }, + _ignore_on_conflict: true, } } diff --git a/src/stream/src/executor/source/source_executor.rs b/src/stream/src/executor/source/source_executor.rs index 695ebef04424a..b9170fce950ec 100644 --- a/src/stream/src/executor/source/source_executor.rs +++ b/src/stream/src/executor/source/source_executor.rs @@ -43,7 +43,7 @@ pub struct SourceExecutor { ctx: ActorContextRef, source_id: TableId, - source_builder: SourceDescBuilder, + source_desc_builder: SourceDescBuilder, /// Row id generator for this source executor. row_id_generator: RowIdGenerator, @@ -79,7 +79,7 @@ impl SourceExecutor { #[allow(clippy::too_many_arguments)] pub fn new( ctx: ActorContextRef, - source_builder: SourceDescBuilder, + source_desc_builder: SourceDescBuilder, source_id: TableId, vnodes: Bitmap, state_table: SourceStateTableHandler, @@ -98,7 +98,7 @@ impl SourceExecutor { Ok(Self { ctx, source_id, - source_builder, + source_desc_builder, row_id_generator: RowIdGenerator::with_epoch( vnode_id as u32, *UNIX_SINGULARITY_DATE_EPOCH, @@ -254,10 +254,10 @@ impl SourceExecutor { .unwrap(); let source_desc = self - .source_builder + .source_desc_builder .build() .await - .context("build source desc failed")?; + .map_err(StreamExecutorError::connector_error)?; // source_desc's row_id_index is based on its columns, and it is possible // that we prune some columns when generating column_ids. So this index // can not be directly used. 
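The comment at the end of this hunk notes that `source_desc`'s `row_id_index` is relative to the full set of source columns, so it cannot be used directly once some columns are pruned when generating `column_ids`. A minimal sketch of that remapping idea; the helper name and parameters (`remap_row_id_index`, `selected_column_ids`) are hypothetical and not part of this patch:

```rust
/// Illustrative sketch only: after column pruning, the usable row-id index is the
/// position of the row-id column among the columns that were actually kept.
fn remap_row_id_index(
    original_row_id_index: Option<usize>,
    // Column ids kept after pruning, in output order (hypothetical parameter).
    selected_column_ids: &[i32],
    // Column id of the row-id column in the full source schema.
    row_id_column_id: i32,
) -> Option<usize> {
    // Bail out early if the source has no row-id column at all.
    original_row_id_index?;
    // Otherwise, find the row-id column's new position among the kept columns.
    selected_column_ids
        .iter()
        .position(|&id| id == row_id_column_id)
}
```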
@@ -503,7 +503,7 @@ mod tests { RowFormatType as ProstRowFormatType, }; use risingwave_pb::stream_plan::source_node::Info as ProstSourceInfo; - use risingwave_source::table_test_utils::create_table_info; + use risingwave_source::table_test_utils::create_table_source_desc_builder; use risingwave_source::*; use risingwave_storage::memory::MemoryStateStore; use tokio::sync::mpsc::unbounded_channel; @@ -523,9 +523,14 @@ mod tests { }; let row_id_index = Some(0); let pk_column_ids = vec![0]; - let info = create_table_info(&schema, row_id_index, pk_column_ids); let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); - let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); + let source_builder = create_table_source_desc_builder( + &schema, + table_id, + row_id_index, + pk_column_ids, + source_manager, + ); let source_desc = source_builder.build().await.unwrap(); let chunk1 = StreamChunk::from_pretty( @@ -625,9 +630,14 @@ mod tests { }; let row_id_index = Some(0); let pk_column_ids = vec![0]; - let info = create_table_info(&schema, row_id_index, pk_column_ids); let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); - let source_builder = SourceDescBuilder::new(table_id, &info, &source_manager); + let source_builder = create_table_source_desc_builder( + &schema, + table_id, + row_id_index, + pk_column_ids, + source_manager, + ); let source_desc = source_builder.build().await.unwrap(); // Prepare test data chunks @@ -682,14 +692,25 @@ mod tests { write_chunk(chunk); } - fn mock_stream_source_info() -> StreamSourceInfo { + trait StreamChunkExt { + fn drop_row_id(self) -> Self; + } + + impl StreamChunkExt for StreamChunk { + fn drop_row_id(self) -> StreamChunk { + let (ops, mut columns, bitmap) = self.into_inner(); + columns.remove(0); + StreamChunk::new(ops, columns, bitmap) + } + } + + fn mock_source_desc_builder(source_id: TableId) -> SourceDescBuilder { let properties = convert_args!(hashmap!( "connector" => "datagen", "fields.v1.min" => "1", "fields.v1.max" => "1000", "fields.v1.seed" => "12345", )); - let columns = vec![ ProstColumnCatalog { column_desc: Some(ProstColumnDesc { @@ -715,34 +736,27 @@ mod tests { is_hidden: false, }, ]; - - StreamSourceInfo { - properties, + let row_id_index = Some(ProstColumnIndex { index: 0 }); + let pk_column_ids = vec![0]; + let stream_source_info = StreamSourceInfo { row_format: ProstRowFormatType::Json as i32, row_schema_location: "".to_string(), - row_id_index: Some(ProstColumnIndex { index: 0 }), + }; + let source_manager = Arc::new(TableSourceManager::default()); + SourceDescBuilder::new( + source_id, + row_id_index, columns, - pk_column_ids: vec![0], - } - } - - trait StreamChunkExt { - fn drop_row_id(self) -> Self; - } - - impl StreamChunkExt for StreamChunk { - fn drop_row_id(self) -> StreamChunk { - let (ops, mut columns, bitmap) = self.into_inner(); - columns.remove(0); - StreamChunk::new(ops, columns, bitmap) - } + pk_column_ids, + properties, + ProstSourceInfo::StreamSource(stream_source_info), + source_manager, + ) } #[tokio::test] async fn test_split_change_mutation() { - let stream_source_info = mock_stream_source_info(); let source_table_id = TableId::default(); - let source_manager: TableSourceManagerRef = Arc::new(TableSourceManager::default()); let get_schema = |column_ids: &[ColumnId], source_desc: &SourceDescRef| { let mut fields = Vec::with_capacity(column_ids.len()); @@ -757,12 +771,8 @@ mod tests { Schema::new(fields) }; - let source_builder = 
SourceDescBuilder::new( - source_table_id, - &ProstSourceInfo::StreamSource(stream_source_info), - &source_manager, - ); - let source_desc = source_builder.clone().build().await.unwrap(); + let source_builder = mock_source_desc_builder(source_table_id); + let source_desc = source_builder.build().await.unwrap(); let mem_state_store = MemoryStateStore::new(); let column_ids = vec![ColumnId::from(0), ColumnId::from(1)]; diff --git a/src/stream/src/executor/subtask.rs b/src/stream/src/executor/subtask.rs index 710ab40666881..a0ac4a8146d6c 100644 --- a/src/stream/src/executor/subtask.rs +++ b/src/stream/src/executor/subtask.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use async_stack_trace::StackTrace; use futures::{Future, StreamExt}; use tokio::sync::mpsc; use tokio::sync::mpsc::error::SendError; @@ -80,7 +81,8 @@ pub fn wrap(input: BoxedExecutor) -> (SubtaskHandle, SubtaskRxExecutor) { break; } } - }; + } + .stack_trace("Subtask"); (handle, rx_executor) } diff --git a/src/stream/src/executor/wrapper/trace.rs b/src/stream/src/executor/wrapper/trace.rs index 9727af1561677..fa5ecfb18f09f 100644 --- a/src/stream/src/executor/wrapper/trace.rs +++ b/src/stream/src/executor/wrapper/trace.rs @@ -96,6 +96,15 @@ pub async fn metrics( } } +fn pretty_identity(identity: &str, actor_id: ActorId, executor_id: u64) -> String { + format!( + "{} (actor {}, executor {})", + identity, + actor_id, + executor_id as u32 // Use the lower 32 bit to match the dashboard. + ) +} + /// Streams wrapped by `stack_trace` will print the async stack trace of the executors. #[try_stream(ok = Message, error = StreamExecutorError)] pub async fn stack_trace( @@ -106,13 +115,7 @@ pub async fn stack_trace( ) { pin_mut!(input); - let span: SpanValue = format!( - "{} (actor {}, executor {})", - info.identity, - actor_id, - executor_id as u32 // Use the lower 32 bit to match the dashboard. - ) - .into(); + let span: SpanValue = pretty_identity(&info.identity, actor_id, executor_id).into(); while let Some(message) = input.next().stack_trace(span.clone()).await.transpose()? { yield message; diff --git a/src/stream/src/from_proto/mview.rs b/src/stream/src/from_proto/mview.rs index dec711416d41a..fcaac0fa64313 100644 --- a/src/stream/src/from_proto/mview.rs +++ b/src/stream/src/from_proto/mview.rs @@ -38,6 +38,7 @@ impl ExecutorBuilder for MaterializeExecutorBuilder { .collect(); let table = node.get_table()?; + let do_sanity_check = node.get_ignore_on_conflict(); let executor = MaterializeExecutor::new( input, store, @@ -46,6 +47,7 @@ impl ExecutorBuilder for MaterializeExecutorBuilder { params.actor_context, params.vnode_bitmap.map(Arc::new), table, + do_sanity_check, ); Ok(executor.boxed()) @@ -76,7 +78,7 @@ impl ExecutorBuilder for ArrangeExecutorBuilder { // FIXME: Lookup is now implemented without cell-based table API and relies on all vnodes // being `DEFAULT_VNODE`, so we need to make the Arrange a singleton. 
let vnodes = params.vnode_bitmap.map(Arc::new); - + let ignore_on_conflict = arrange_node.get_ignore_on_conflict(); let executor = MaterializeExecutor::new( input, store, @@ -85,6 +87,7 @@ impl ExecutorBuilder for ArrangeExecutorBuilder { params.actor_context, vnodes, table, + ignore_on_conflict, ); Ok(executor.boxed()) diff --git a/src/stream/src/from_proto/source.rs b/src/stream/src/from_proto/source.rs index 200c06f3674ad..a8d9d9bb8f8c9 100644 --- a/src/stream/src/from_proto/source.rs +++ b/src/stream/src/from_proto/source.rs @@ -12,10 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use anyhow::anyhow; use risingwave_common::catalog::{ColumnId, Field, Schema, TableId}; use risingwave_common::types::DataType; -use risingwave_pb::stream_plan::source_node::Info as SourceNodeInfo; use risingwave_source::SourceDescBuilder; use tokio::sync::mpsc::unbounded_channel; @@ -42,22 +40,19 @@ impl ExecutorBuilder for SourceExecutorBuilder { let source_id = TableId::new(node.source_id); let source_builder = SourceDescBuilder::new( source_id, - node.get_info()?, - ¶ms.env.source_manager_ref(), + node.row_id_index.clone(), + node.columns.clone(), + node.pk_column_ids.clone(), + node.properties.clone(), + node.get_info()?.clone(), + params.env.source_manager_ref(), ); - let column_ids: Vec<_> = node - .get_column_ids() + let columns = node.columns.clone(); + let column_ids: Vec<_> = columns .iter() - .map(|i| ColumnId::from(*i)) + .map(|column| ColumnId::from(column.get_column_desc().unwrap().column_id)) .collect(); - let columns = node - .get_info() - .map(|info| match info { - SourceNodeInfo::StreamSource(stream) => &stream.columns, - SourceNodeInfo::TableSource(table) => &table.columns, - }) - .map_err(|_| anyhow!("source_info not found"))?; let fields = columns .iter() .map(|prost| { diff --git a/src/stream/src/task/stream_manager.rs b/src/stream/src/task/stream_manager.rs index 33600c4ba6b56..11ed7138dbf01 100644 --- a/src/stream/src/task/stream_manager.rs +++ b/src/stream/src/task/stream_manager.rs @@ -18,7 +18,7 @@ use std::fmt::Debug; use std::sync::Arc; use anyhow::{anyhow, Context}; -use async_stack_trace::{StackTraceManager, StackTraceReport}; +use async_stack_trace::{StackTraceManager, StackTraceReport, TraceConfig}; use itertools::Itertools; use parking_lot::Mutex; use risingwave_common::bail; @@ -78,7 +78,7 @@ pub struct LocalStreamManagerCore { pub(crate) config: StreamingConfig, /// Manages the stack traces of all actors. - stack_trace_manager: Option>, + stack_trace_manager: Option<(StackTraceManager, TraceConfig)>, } /// `LocalStreamManager` manages all stream executors in this project. 
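The `LocalStreamManager` hunks below replace the `enable_async_stack_trace: bool` flag with an optional `TraceConfig`. A sketch of how a caller might construct that option, using only the `TraceConfig` fields that appear later in this patch (`report_detached`, `verbose`, `interval`); the helper name and the chosen defaults are illustrative:

```rust
use std::time::Duration;

use async_stack_trace::TraceConfig;

/// Illustrative only: build the optional trace config that `LocalStreamManager::new`
/// now takes instead of a boolean flag. `None` disables async stack traces.
fn async_stack_trace_config(enable: bool) -> Option<TraceConfig> {
    enable.then(|| TraceConfig {
        // Keep reporting traces of detached spans, matching the usage in this patch.
        report_detached: true,
        // Whether verbose spans are captured (assumption, based on the
        // `stack_trace` -> `verbose_stack_trace` renames in this patch).
        verbose: true,
        // How often the trace report is refreshed.
        interval: Duration::from_secs(1),
    })
}
```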
@@ -142,7 +142,7 @@ impl LocalStreamManager { state_store: StateStoreImpl, streaming_metrics: Arc, config: StreamingConfig, - enable_async_stack_trace: bool, + async_stack_trace_config: Option, enable_managed_cache: bool, ) -> Self { Self::with_core(LocalStreamManagerCore::new( @@ -150,7 +150,7 @@ impl LocalStreamManager { state_store, streaming_metrics, config, - enable_async_stack_trace, + async_stack_trace_config, enable_managed_cache, )) } @@ -171,6 +171,7 @@ impl LocalStreamManager { .stack_trace_manager .as_mut() .expect("async stack trace not enabled") + .0 .get_all() { println!(">> Actor {}\n\n{}", k, &*trace); @@ -183,7 +184,7 @@ impl LocalStreamManager { pub fn get_actor_traces(&self) -> HashMap { let mut core = self.core.lock(); match &mut core.stack_trace_manager { - Some(mgr) => mgr.get_all().map(|(k, v)| (*k, v.clone())).collect(), + Some((mgr, _)) => mgr.get_all().map(|(k, v)| (*k, v.clone())).collect(), None => Default::default(), } } @@ -360,7 +361,7 @@ impl LocalStreamManagerCore { state_store: StateStoreImpl, streaming_metrics: Arc, config: StreamingConfig, - enable_async_stack_trace: bool, + async_stack_trace_config: Option, enable_managed_cache: bool, ) -> Self { let context = SharedContext::new(addr, state_store.clone(), &config, enable_managed_cache); @@ -369,7 +370,7 @@ impl LocalStreamManagerCore { context, streaming_metrics, config, - enable_async_stack_trace, + async_stack_trace_config, ) } @@ -378,7 +379,7 @@ impl LocalStreamManagerCore { context: SharedContext, streaming_metrics: Arc, config: StreamingConfig, - enable_async_stack_trace: bool, + async_stack_trace_config: Option, ) -> Self { let mut builder = tokio::runtime::Builder::new_multi_thread(); if let Some(worker_threads_num) = config.actor_runtime_worker_threads_num { @@ -402,7 +403,8 @@ impl LocalStreamManagerCore { state_store, streaming_metrics, config, - stack_trace_manager: enable_async_stack_trace.then(Default::default), + stack_trace_manager: async_stack_trace_config + .map(|c| (StackTraceManager::default(), c)), } } @@ -417,7 +419,7 @@ impl LocalStreamManagerCore { SharedContext::for_test(), streaming_metrics, StreamingConfig::default(), - false, + None, ) } @@ -573,6 +575,7 @@ impl LocalStreamManagerCore { fn build_actors(&mut self, actors: &[ActorId], env: StreamEnvironment) -> StreamResult<()> { for &actor_id in actors { let actor = self.actors.remove(&actor_id).unwrap(); + let mview_definition = &actor.mview_definition; let actor_context = ActorContext::create(actor_id); let vnode_bitmap = actor .vnode_bitmap @@ -593,7 +596,6 @@ impl LocalStreamManagerCore { let actor = Actor::new( dispatcher, subtasks, - actor_id, self.context.clone(), self.streaming_metrics.clone(), actor_context, @@ -603,7 +605,7 @@ impl LocalStreamManagerCore { let trace_reporter = self .stack_trace_manager .as_mut() - .map(|m| m.register(actor_id)); + .map(|(m, _)| m.register(actor_id)); let handle = { let actor = async move { @@ -616,9 +618,12 @@ impl LocalStreamManagerCore { let traced = match trace_reporter { Some(trace_reporter) => trace_reporter.trace( actor, - format!("Actor {actor_id}"), - true, - Duration::from_millis(1000), + format!("Actor {actor_id}: `{}`", mview_definition), + TraceConfig { + report_detached: true, + verbose: true, + interval: Duration::from_secs(1), + }, ), None => actor, }; @@ -739,7 +744,7 @@ impl LocalStreamManagerCore { } self.actors.clear(); self.context.clear_channels(); - if let Some(stack_trace_manager) = self.stack_trace_manager.as_mut() { + if let Some((stack_trace_manager, _)) = 
self.stack_trace_manager.as_mut() { std::mem::take(stack_trace_manager); } self.actor_monitor_tasks.clear(); diff --git a/src/test_runner/Cargo.toml b/src/test_runner/Cargo.toml index 7b7e285be9419..816437734a575 100644 --- a/src/test_runner/Cargo.toml +++ b/src/test_runner/Cargo.toml @@ -6,4 +6,5 @@ edition = "2021" [dependencies] fail = "0.5" +sync-point = { path = "../utils/sync-point" } workspace-hack = { version = "0.2.0-alpha", path = "../workspace-hack" } diff --git a/src/test_runner/src/test_runner.rs b/src/test_runner/src/test_runner.rs index c54d0ae3c1880..d74218231dee8 100644 --- a/src/test_runner/src/test_runner.rs +++ b/src/test_runner/src/test_runner.rs @@ -100,21 +100,43 @@ impl TestHook for FailPointHook { }) } } + +#[derive(Clone)] +struct SyncPointHook; + +impl TestHook for SyncPointHook { + fn setup(&mut self) { + sync_point::reset(); + } + + fn teardown(&mut self) { + sync_point::reset(); + } +} + // End Copyright 2016 TiKV Project Authors. Licensed under Apache-2.0. pub fn run_failpont_tests(cases: &[&TestDescAndFn]) { let mut cases1 = vec![]; let mut cases2 = vec![]; + let mut cases3 = vec![]; cases.iter().for_each(|case| { - if case.desc.name.as_slice().contains("test_failpoints") { + if case.desc.name.as_slice().contains("test_syncpoints") { + // sync_point tests should specify #[serial], because sync_point lib doesn't implement + // an implicit global lock to order tests like fail-rs. cases1.push(*case); - } else { + } else if case.desc.name.as_slice().contains("test_failpoints") { cases2.push(*case); + } else { + cases3.push(*case); } }); if !cases1.is_empty() { - run_test_inner(cases1.as_slice(), FailPointHook); + run_test_inner(cases1.as_slice(), SyncPointHook); } if !cases2.is_empty() { - run_general_test(cases2.as_slice()); + run_test_inner(cases2.as_slice(), FailPointHook); + } + if !cases3.is_empty() { + run_general_test(cases3.as_slice()); } } diff --git a/src/tests/sync_point/Cargo.toml b/src/tests/sync_point/Cargo.toml deleted file mode 100644 index e0b0f41c8a019..0000000000000 --- a/src/tests/sync_point/Cargo.toml +++ /dev/null @@ -1,28 +0,0 @@ -[package] -name = "risingwave_sync_point_test" -version = "0.1.0" -edition = "2021" -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -bytes = "1" -itertools = "0.10" -risingwave_cmd_all = { path = "../../cmd_all" } -risingwave_common = { path = "../../common" } -risingwave_object_store = { path = "../../object_store" } -risingwave_pb = { path = "../../prost" } -risingwave_rpc_client = { path = "../../rpc_client" } -serial_test = "0.9" -sync-point = { path = "../../utils/sync-point" } -tokio = { version = "0.2", package = "madsim-tokio", features = [ - "rt", - "rt-multi-thread", - "sync", - "macros", - "time", - "signal", - "fs", -] } - -[features] -sync_point = ["sync-point/sync_point"] diff --git a/src/tests/sync_point/README.md b/src/tests/sync_point/README.md deleted file mode 100644 index 4ba8752b69431..0000000000000 --- a/src/tests/sync_point/README.md +++ /dev/null @@ -1,9 +0,0 @@ -Integration tests based on risedev playground and sync point support. - -# How to run - -bash run_sync_point_test.sh - -# How to write a test - -Refer to test_gc_watermark in src/tests/sync/src/tests.rs. 
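With the standalone sync-point test crate removed, sync-point tests are expected to run through `risingwave_test_runner`: the runner applies `SyncPointHook` to tests whose names contain `test_syncpoints`, and such tests should be marked `#[serial]` because sync points are process-global. A minimal sketch of that shape, reusing only the `sync_point` calls that appear elsewhere in this patch (`hook`, `on`); the point name and assertion are illustrative:

```rust
use std::sync::atomic::{AtomicBool, Ordering};

use serial_test::serial;

// Illustrative only: the name must contain `test_syncpoints` so the test runner
// wraps it with `SyncPointHook` (which resets sync points around the test), and
// `#[serial]` is required because sync points are global state.
#[tokio::test]
#[serial]
async fn test_syncpoints_hook_is_triggered() {
    static HOOK_FIRED: AtomicBool = AtomicBool::new(false);

    // Register an async hook on a hypothetical sync point.
    sync_point::hook("HYPOTHETICAL_POINT", || async {
        HOOK_FIRED.store(true, Ordering::SeqCst);
    });

    // Reaching the sync point runs the hook registered above.
    sync_point::on("HYPOTHETICAL_POINT").await;

    assert!(HOOK_FIRED.load(Ordering::SeqCst));
}
```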
\ No newline at end of file diff --git a/src/tests/sync_point/run_sync_point_test.sh b/src/tests/sync_point/run_sync_point_test.sh deleted file mode 100644 index 09bcb6112bf0c..0000000000000 --- a/src/tests/sync_point/run_sync_point_test.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -# On macOS: brew install coreutils - -set -e - -SCRIPT_PATH="$( cd -- "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )" -RW_WORKSPACE=$(realpath "${SCRIPT_PATH}"/../../../) -export RW_WORKSPACE=${RW_WORKSPACE} -RUSTFLAGS='--cfg tokio_unstable --cfg sync_point_integration_test' -export RUSTFLAGS=${RUSTFLAGS} - -cp -R -n "${RW_WORKSPACE}"/e2e_test "${SCRIPT_PATH}"/slt/ || : - -cargo test -p risingwave_sync_point_test --features sync_point diff --git a/src/tests/sync_point/slt/tpch_snapshot_no_drop.slt b/src/tests/sync_point/slt/tpch_snapshot_no_drop.slt deleted file mode 100644 index 4b09585087603..0000000000000 --- a/src/tests/sync_point/slt/tpch_snapshot_no_drop.slt +++ /dev/null @@ -1,33 +0,0 @@ -include e2e_test/tpch/create_tables.slt.part - -include e2e_test/tpch/insert_customer.slt.part -include e2e_test/tpch/insert_lineitem.slt.part -include e2e_test/tpch/insert_nation.slt.part -include e2e_test/tpch/insert_orders.slt.part -include e2e_test/tpch/insert_part.slt.part -include e2e_test/tpch/insert_partsupp.slt.part -include e2e_test/tpch/insert_region.slt.part -include e2e_test/tpch/insert_supplier.slt.part - -include e2e_test/streaming/tpch/create_views.slt.part - -include e2e_test/streaming/tpch/q1.slt.part -include e2e_test/streaming/tpch/q2.slt.part -include e2e_test/streaming/tpch/q3.slt.part -include e2e_test/streaming/tpch/q4.slt.part -include e2e_test/streaming/tpch/q5.slt.part -include e2e_test/streaming/tpch/q6.slt.part -include e2e_test/streaming/tpch/q7.slt.part -include e2e_test/streaming/tpch/q8.slt.part -include e2e_test/streaming/tpch/q9.slt.part -include e2e_test/streaming/tpch/q10.slt.part -include e2e_test/streaming/tpch/q11.slt.part -include e2e_test/streaming/tpch/q12.slt.part -include e2e_test/streaming/tpch/q13.slt.part -include e2e_test/streaming/tpch/q14.slt.part -include e2e_test/streaming/tpch/q17.slt.part -include e2e_test/streaming/tpch/q18.slt.part -include e2e_test/streaming/tpch/q19.slt.part -include e2e_test/streaming/tpch/q20.slt.part -include e2e_test/streaming/tpch/q22.slt.part - diff --git a/src/tests/sync_point/src/lib.rs b/src/tests/sync_point/src/lib.rs deleted file mode 100644 index 4b601e6b924ea..0000000000000 --- a/src/tests/sync_point/src/lib.rs +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#[cfg(all(test, sync_point_integration_test, feature = "sync_point"))] -mod test_utils; -#[cfg(all(test, sync_point_integration_test, feature = "sync_point"))] -mod tests; diff --git a/src/tests/sync_point/src/test_utils.rs b/src/tests/sync_point/src/test_utils.rs deleted file mode 100644 index 98d1bc088baf8..0000000000000 --- a/src/tests/sync_point/src/test_utils.rs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::process::{Command, Output}; -use std::sync::Arc; -use std::thread::JoinHandle; -use std::time::Duration; - -use risingwave_cmd_all::playground; -use risingwave_object_store::object::object_metrics::ObjectStoreMetrics; -use risingwave_object_store::object::{parse_remote_object_store, ObjectStoreImpl}; -use risingwave_pb::common::WorkerType; -use risingwave_rpc_client::MetaClient; -use serial_test::serial; - -pub fn setup_env() { - sync_point::reset(); - let current_dir = - std::env::var("RW_WORKSPACE").expect("set env RW_WORKSPACE to project root path"); - std::env::set_current_dir(current_dir).unwrap(); - std::env::set_var("RW_META_ADDR", "http://127.0.0.1:5690"); - std::env::set_var("FORCE_SHARED_HUMMOCK_IN_MEM", "1"); - std::env::set_var("PLAYGROUND_PROFILE", "playground-test"); - std::env::set_var("OBJECT_STORE_URL", "memory-shared"); - std::env::set_var("OBJECT_STORE_BUCKET", "hummock_001"); -} - -pub async fn start_cluster() -> (JoinHandle<()>, std::sync::mpsc::Sender<()>) { - wipe_object_store().await; - - let (tx, rx) = std::sync::mpsc::channel(); - let join_handle = std::thread::spawn(move || { - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(1) - .enable_all() - .build() - .unwrap(); - runtime.block_on(async { - tokio::spawn(async { playground().await }); - rx.recv().unwrap(); - }); - }); - // It will find "CLUSTER_READY" even when it is reached after "CLUSTER_READY" has been emitted. 
- sync_point::wait_timeout("CLUSTER_READY", Duration::from_secs(30)) - .await - .unwrap(); - (join_handle, tx) -} - -pub async fn stop_cluster(join_handle: JoinHandle<()>, shutdown_tx: std::sync::mpsc::Sender<()>) { - shutdown_tx.send(()).unwrap(); - join_handle.join().unwrap(); - wipe_object_store().await; -} - -pub fn run_slt() -> Output { - Command::new("./risedev") - .args([ - "slt", - "-d", - "dev", - "-p", - "4566", - "src/tests/sync_point/slt/tpch_snapshot_no_drop.slt", - ]) - .spawn() - .unwrap() - .wait_with_output() - .unwrap() -} - -pub async fn get_object_store_client() -> ObjectStoreImpl { - let url = std::env::var("OBJECT_STORE_URL").unwrap(); - parse_remote_object_store(&url, Arc::new(ObjectStoreMetrics::unused())).await -} - -pub fn get_object_store_bucket() -> String { - std::env::var("OBJECT_STORE_BUCKET").unwrap() -} - -async fn wipe_object_store() { - let url = std::env::var("OBJECT_STORE_URL").unwrap(); - assert_eq!(url, "memory-shared"); - risingwave_object_store::object::InMemObjectStore::reset_shared(); -} - -pub async fn get_meta_client() -> MetaClient { - let meta_addr = std::env::var("RW_META_ADDR").unwrap(); - MetaClient::register_new( - &meta_addr, - WorkerType::RiseCtl, - &"127.0.0.1:2333".parse().unwrap(), - 0, - ) - .await - .unwrap() -} - -#[tokio::test] -#[serial] -#[should_panic] -async fn test_wait_timeout() { - sync_point::hook("TEST_SETUP_TIMEOUT", || async { - sync_point::wait_timeout("SIG_NEVER_EMIT", Duration::from_secs(1)) - .await - .unwrap(); - }); - - // timeout - sync_point::on("TEST_SETUP_TIMEOUT").await; -} - -#[tokio::test] -#[serial] -async fn test_launch_cluster() { - setup_env(); - let (join_handle, tx) = start_cluster().await; - stop_cluster(join_handle, tx).await; -} diff --git a/src/tests/sync_point/src/tests.rs b/src/tests/sync_point/src/tests.rs deleted file mode 100644 index 7d588b4ff2857..0000000000000 --- a/src/tests/sync_point/src/tests.rs +++ /dev/null @@ -1,208 +0,0 @@ -// Copyright 2022 Singularity Data -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::time::Duration; - -use itertools::Itertools; -use risingwave_rpc_client::HummockMetaClient; -use serial_test::serial; - -use crate::test_utils::*; - -/// With support of sync point, this test is executed in the following sequential order: -/// 1. Block compaction scheduler, thus no compaction will be automatically scheduled. -/// 2. Import data with risedev slt. The slt is modified so that it will not drop MVs in the end. -/// 3. Schedule exactly one compaction. -/// 4. Wait until compactor has uploaded its output to object store. It doesn't report task result -/// to meta, until we tell it to do so in step 6. -/// 5. Verify GC logic. -/// 6. Compactor reports task result to meta. -/// 7. Verify GC logic. 
-#[tokio::test] -#[serial] -async fn test_gc_watermark() { - setup_env(); - - let (join_handle, tx) = start_cluster().await; - let object_store_client = get_object_store_client().await; - let meta_client = get_meta_client().await; - - sync_point::hook("BEFORE_COMPACT_REPORT", || async { - sync_point::on("SIG_DONE_COMPACT_UPLOAD").await; - sync_point::wait_timeout("START_COMPACT_REPORT", Duration::from_secs(3600)) - .await - .unwrap(); - }); - // Block compaction scheduler so that we can control scheduling explicitly - sync_point::hook("BEFORE_SCHEDULE_COMPACTION_TASK", || async { - sync_point::wait_timeout("SIG_SCHEDULE_COMPACTION_TASK", Duration::from_secs(3600)) - .await - .unwrap(); - }); - - // Import data - let run_slt = run_slt(); - assert!(run_slt.status.success()); - - let before_compaction = object_store_client.list("").await.unwrap(); - assert!(!before_compaction.is_empty()); - - // Schedule a compaction task - sync_point::on("SIG_SCHEDULE_COMPACTION_TASK").await; - - // Wait until SSTs have been written to object store - sync_point::wait_timeout("SIG_DONE_COMPACT_UPLOAD", Duration::from_secs(10)) - .await - .unwrap(); - - let after_compaction_upload = object_store_client.list("").await.unwrap(); - let new_objects = after_compaction_upload - .iter() - .filter(|after| { - !before_compaction - .iter() - .any(|before| before.key == after.key) - }) - .cloned() - .collect_vec(); - assert!(!new_objects.is_empty()); - - // Upload a garbage object - let sst_id = meta_client.get_new_sst_ids(1).await.unwrap().start_id; - object_store_client - .upload( - &format!("{}/{}.data", get_object_store_bucket(), sst_id), - bytes::Bytes::from(vec![1, 2, 3]), - ) - .await - .unwrap(); - let after_garbage_upload = object_store_client.list("").await.unwrap(); - assert_eq!( - after_garbage_upload.len(), - after_compaction_upload.len() + 1 - ); - - meta_client.trigger_full_gc(0).await.unwrap(); - // Wait until VACUUM is scheduled and reported - for _ in 0..2 { - sync_point::wait_timeout("AFTER_SCHEDULE_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - } - // Expect timeout aka no SST is deleted, because the garbage SST has greater id than watermark, - // which is held by the on-going compaction. - sync_point::wait_timeout("AFTER_REPORT_VACUUM", Duration::from_secs(10)) - .await - .unwrap_err(); - let after_gc = object_store_client.list("").await.unwrap(); - assert_eq!(after_gc.len(), after_compaction_upload.len() + 1); - - // Signal to continue compaction report - sync_point::on("START_COMPACT_REPORT").await; - - // Wait until SSts have been written to hummock version - sync_point::wait_timeout("AFTER_COMPACT_REPORT", Duration::from_secs(10)) - .await - .unwrap(); - - // Wait until VACUUM is scheduled and reported - for _ in 0..2 { - sync_point::wait_timeout("AFTER_SCHEDULE_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - } - // Expect some stale SSTs as the result of compaction are deleted. - sync_point::wait_timeout("AFTER_REPORT_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - let after_gc = object_store_client.list("").await.unwrap(); - assert!(after_gc.len() < after_compaction_upload.len()); - - meta_client.trigger_full_gc(0).await.unwrap(); - // Wait until VACUUM is scheduled and reported - for _ in 0..2 { - sync_point::wait_timeout("AFTER_SCHEDULE_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - } - // Expect the garbage SST is deleted. 
- sync_point::wait_timeout("AFTER_REPORT_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - let after_gc_2 = object_store_client.list("").await.unwrap(); - assert_eq!(after_gc.len(), after_gc_2.len() + 1); - - stop_cluster(join_handle, tx).await; -} - -#[tokio::test] -#[serial] -async fn test_gc_sst_retention_time() { - setup_env(); - - let (join_handle, tx) = start_cluster().await; - let object_store_client = get_object_store_client().await; - let meta_client = get_meta_client().await; - - let before_garbage_upload = object_store_client.list("").await.unwrap(); - assert_eq!(before_garbage_upload.len(), 0); - - // Upload a garbage object - let sst_id = meta_client.get_new_sst_ids(1).await.unwrap().start_id; - object_store_client - .upload( - &format!("{}/{}.data", get_object_store_bucket(), sst_id), - bytes::Bytes::from(vec![1, 2, 3]), - ) - .await - .unwrap(); - let after_garbage_upload = object_store_client.list("").await.unwrap(); - assert_eq!(after_garbage_upload.len(), 1); - - // With large sst_retention_time - meta_client.trigger_full_gc(3600).await.unwrap(); - // Wait until VACUUM is scheduled and reported - for _ in 0..2 { - sync_point::wait_timeout("AFTER_SCHEDULE_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - } - // Expect timeout aka no SST is deleted, because all SSTs are within sst_retention_time, even - // the garbage one. - sync_point::wait_timeout("AFTER_REPORT_VACUUM", Duration::from_secs(10)) - .await - .unwrap_err(); - let after_gc = object_store_client.list("").await.unwrap(); - // Garbage is not deleted. - assert_eq!(after_gc, after_garbage_upload); - - // Ensure SST's last modified is less than now - tokio::time::sleep(Duration::from_secs(1)).await; - // With 0 sst_retention_time - meta_client.trigger_full_gc(0).await.unwrap(); - // Wait until VACUUM is scheduled and reported - for _ in 0..2 { - sync_point::wait_timeout("AFTER_SCHEDULE_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - } - sync_point::wait_timeout("AFTER_REPORT_VACUUM", Duration::from_secs(10)) - .await - .unwrap(); - let after_gc = object_store_client.list("").await.unwrap(); - // Garbage is deleted. - assert_eq!(after_gc, before_garbage_upload); - - stop_cluster(join_handle, tx).await; -} diff --git a/src/utils/async_stack_trace/src/context.rs b/src/utils/async_stack_trace/src/context.rs index abdb238d3c38a..252607aaa5100 100644 --- a/src/utils/async_stack_trace/src/context.rs +++ b/src/utils/async_stack_trace/src/context.rs @@ -61,6 +61,9 @@ pub(crate) struct TraceContext { /// Whether to report the detached spans, that is, spans that are not able to be polled now. report_detached: bool, + /// Whether to report the "verbose" stack trace. + verbose: bool, + /// The arena for allocating span nodes in this context. arena: Arena, @@ -78,6 +81,7 @@ impl std::fmt::Display for TraceContext { arena: &Arena, node: NodeId, depth: usize, + current: NodeId, ) -> std::fmt::Result { f.write_str(&" ".repeat(depth * 2))?; @@ -85,7 +89,8 @@ impl std::fmt::Display for TraceContext { f.write_str(inner.span.as_ref())?; let elapsed: Duration = inner.start_time.elapsed().into(); - f.write_fmt(format_args!( + write!( + f, " [{}{:?}]", if depth > 0 && elapsed.as_secs() >= 1 { "!!! 
" @@ -93,20 +98,24 @@ impl std::fmt::Display for TraceContext { "" }, elapsed - ))?; + )?; + + if depth > 0 && node == current { + f.write_str(" <== current")?; + } f.write_char('\n')?; for child in node .children(arena) .sorted_by(|&a, &b| arena[a].get().span.cmp(&arena[b].get().span)) { - fmt_node(f, arena, child, depth + 1)?; + fmt_node(f, arena, child, depth + 1, current)?; } Ok(()) } - fmt_node(f, &self.arena, self.root, 0)?; + fmt_node(f, &self.arena, self.root, 0, self.current)?; // Print all detached spans. May hurt the performance so make it optional. if self.report_detached { @@ -119,8 +128,8 @@ impl std::fmt::Display for TraceContext { && node.next_sibling().is_none() && node.previous_sibling().is_none() { - f.write_str("[??? Detached]\n")?; - fmt_node(f, &self.arena, id, 1)?; + writeln!(f, "[Detached {}]", id)?; + fmt_node(f, &self.arena, id, 1, self.current)?; } } } @@ -131,7 +140,7 @@ impl std::fmt::Display for TraceContext { impl TraceContext { /// Create a new stack trace context with the given root span. - pub fn new(root_span: SpanValue, report_detached: bool) -> Self { + pub fn new(root_span: SpanValue, report_detached: bool, verbose: bool) -> Self { static ID: AtomicU64 = AtomicU64::new(0); let id = ID.fetch_add(1, Ordering::SeqCst); @@ -141,6 +150,7 @@ impl TraceContext { Self { id, report_detached, + verbose, arena, root, current: root, @@ -148,7 +158,7 @@ impl TraceContext { } /// Get the count of active span nodes in this context. - #[cfg_attr(not(test), expect(dead_code))] + #[cfg(test)] pub fn active_node_count(&self) -> usize { self.arena.iter().filter(|n| !n.is_removed()).count() } @@ -225,6 +235,11 @@ impl TraceContext { pub fn current(&self) -> NodeId { self.current } + + /// Whether the verbose span should be traced. + pub fn verbose(&self) -> bool { + self.verbose + } } tokio::task_local! { diff --git a/src/utils/async_stack_trace/src/lib.rs b/src/utils/async_stack_trace/src/lib.rs index 6916fffe114b4..af41d7cf17437 100644 --- a/src/utils/async_stack_trace/src/lib.rs +++ b/src/utils/async_stack_trace/src/lib.rs @@ -20,8 +20,7 @@ use std::pin::Pin; use std::task::Poll; use context::ContextId; -use futures::future::Fuse; -use futures::{Future, FutureExt}; +use futures::Future; use indextree::NodeId; use pin_project::{pin_project, pinned_drop}; use triomphe::Arc; @@ -32,7 +31,7 @@ mod context; mod manager; pub use context::current_context; -pub use manager::{StackTraceManager, StackTraceReport, TraceReporter}; +pub use manager::{StackTraceManager, StackTraceReport, TraceConfig, TraceReporter}; /// A cheaply-cloneable span string. #[derive(Debug, Clone)] @@ -91,11 +90,13 @@ enum StackTracedState { this_context: ContextId, }, Ready, + /// The stack trace is disabled due to `verbose` configuration. + Disabled, } /// The future for [`StackTrace::stack_trace`]. #[pin_project(PinnedDrop)] -pub struct StackTraced { +pub struct StackTraced { #[pin] inner: F, @@ -103,7 +104,7 @@ pub struct StackTraced { state: StackTracedState, } -impl StackTraced { +impl StackTraced { fn new(inner: F, span: impl Into) -> Self { Self { inner, @@ -112,22 +113,26 @@ impl StackTraced { } } -impl Future for StackTraced { +impl Future for StackTraced { type Output = F::Output; // TODO: may optionally enable based on the features fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { let this = self.project(); - let current_context = try_with_context(|c| c.id()); // For assertion. 
let old_current = try_with_context(|c| c.current()); let this_node = match this.state { StackTracedState::Initial(span) => { - match current_context { + match try_with_context(|c| (c.id(), c.verbose() >= VERBOSE)) { + // The tracing for this span is disabled according to the verbose configuration. + Some((_, false)) => { + *this.state = StackTracedState::Disabled; + return this.inner.poll(cx); + } // First polled - Some(current_context) => { + Some((current_context, true)) => { // First polled, push a new span to the context. let node = with_context(|c| c.push(std::mem::take(span))); *this.state = StackTracedState::Polled { @@ -144,7 +149,7 @@ impl Future for StackTraced { this_node, this_context, } => { - match current_context { + match try_with_context(|c| c.id()) { // Context correct Some(current_context) if current_context == *this_context => { // Polled before, just step in. @@ -164,6 +169,7 @@ impl Future for StackTraced { } } StackTracedState::Ready => unreachable!("the traced future should always be fused"), + StackTracedState::Disabled => return this.inner.poll(cx), }; // The current node must be the this_node. @@ -191,16 +197,16 @@ impl Future for StackTraced { } #[pinned_drop] -impl PinnedDrop for StackTraced { +impl PinnedDrop for StackTraced { fn drop(self: Pin<&mut Self>) { let this = self.project(); - let current_context = try_with_context(|c| c.id()); + let current_context = || try_with_context(|c| c.id()); match this.state { StackTracedState::Polled { this_node, this_context, - } => match current_context { + } => match current_context() { // Context correct Some(current_context) if current_context == *this_context => { with_context(|c| c.remove_and_detach(*this_node)); @@ -214,7 +220,8 @@ impl PinnedDrop for StackTraced { tracing::warn!("stack traced future is not in a traced context, while it was when first polled, cannot clean up!"); } }, - StackTracedState::Initial(_) | StackTracedState::Ready => {} + StackTracedState::Initial(_) | StackTracedState::Ready | StackTracedState::Disabled => { + } } } } @@ -222,8 +229,25 @@ impl PinnedDrop for StackTraced { pub trait StackTrace: Future + Sized { /// Wrap this future, so that we're able to check the stack trace and find where and why this /// future is pending, with [`StackTraceReport`] and [`StackTraceManager`]. - fn stack_trace(self, span: impl Into) -> Fuse> { - StackTraced::new(self, span).fuse() + fn stack_trace(self, span: impl Into) -> StackTraced { + StackTraced::new(self, span) + } + + /// Similar to [`stack_trace`], but the span is a verbose one, which means it will be traced + /// only if the verbose configuration is enabled. + #[cfg(not(debug_assertions))] + fn verbose_stack_trace(self, span: impl Into) -> StackTraced { + StackTraced::new(self, span) + } + + /// Similar to [`stack_trace`], but the span is a verbose one, which means it will be traced + /// only if the verbose configuration is enabled. + /// + /// With `debug_assertions` on, this span will be disabled statically to avoid affecting + /// performance too much. Therefore, `verbose` mode in [`TraceConfig`] is ignored. 
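// Note on the verbose gate above: `c.verbose() >= VERBOSE` relies on bool
// ordering (`false < true`). Assuming `VERBOSE` is a const bool parameter of
// `StackTraced` (true for `verbose_stack_trace`, false for `stack_trace`),
// regular spans are always traced, while verbose spans are traced only when the
// context was created with `TraceConfig { verbose: true, .. }`; otherwise the
// future enters `StackTracedState::Disabled` and simply polls its inner future.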
+ #[cfg(debug_assertions)] + fn verbose_stack_trace(self, _span: impl Into) -> Self { + self } } impl StackTrace for F where F: Future {} diff --git a/src/utils/async_stack_trace/src/manager.rs b/src/utils/async_stack_trace/src/manager.rs index c07fc46bfe96c..6d30a13bf3810 100644 --- a/src/utils/async_stack_trace/src/manager.rs +++ b/src/utils/async_stack_trace/src/manager.rs @@ -54,6 +54,19 @@ impl std::fmt::Display for StackTraceReport { } } +/// Configuration for a traced context. +#[derive(Debug, Clone)] +pub struct TraceConfig { + /// Whether to report the futures that are not able to be polled now. + pub report_detached: bool, + + /// Whether to report the "verbose" stack trace. + pub verbose: bool, + + /// The interval to report the stack trace. + pub interval: Duration, +} + /// Used to start a reporter along with the traced future. pub struct TraceReporter { /// Used to send the report periodically to the manager. @@ -63,23 +76,24 @@ pub struct TraceReporter { impl TraceReporter { /// Provide a stack tracing context with the `root_span` for the given future. The reporter will /// be started along with this future in the current task and update the captured stack trace - /// report every `interval` time. - /// - /// If `report_detached` is true, the reporter will also report the futures that are not able to - /// be polled now. + /// report periodically. pub async fn trace( self, future: F, root_span: impl Into, - report_detached: bool, - interval: Duration, + TraceConfig { + report_detached, + verbose, + interval, + }: TraceConfig, ) -> F::Output { TRACE_CONTEXT .scope( - TraceContext::new(root_span.into(), report_detached).into(), + TraceContext::new(root_span.into(), report_detached, verbose).into(), async move { let reporter = async move { let mut interval = tokio::time::interval(interval); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay); loop { interval.tick().await; let new_trace = with_context(|c| c.to_report()); @@ -105,24 +119,6 @@ impl TraceReporter { ) .await } - - /// Optionally provide a stack tracing context. Check [`TraceReporter::trace`] for more details. - pub async fn optional_trace( - self, - future: F, - root_span: impl Into, - report_detached: bool, - interval: Duration, - enabled: bool, - ) -> F::Output { - if enabled { - self.trace(future, root_span, report_detached, interval) - .await - } else { - drop(self); // drop self so that the manager will find that the reporter is closed. - future.await - } - } } /// Manages the stack traces of multiple tasks. 
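// A minimal sketch of how the pieces above compose, mirroring the updated test
// below: a `TraceReporter` (typically handed out by a `StackTraceManager`, not
// shown in this hunk) drives a traced root future with the new `TraceConfig`,
// and `verbose_stack_trace` spans inside it are reported only when `verbose` is
// true. The span names and helper functions are invented for illustration.
use std::time::Duration;

use async_stack_trace::{StackTrace, TraceConfig, TraceReporter};

async fn run_actor() {
    async {
        // Reported only when `TraceConfig::verbose` is set (and, per the cfg
        // above, compiled out entirely under `debug_assertions`).
        async { /* per-row work */ }.verbose_stack_trace("row op").await;
    }
    .stack_trace("process chunk")
    .await;
}

async fn run_traced_actor(reporter: TraceReporter) {
    reporter
        .trace(
            run_actor(),
            "actor 233",
            TraceConfig {
                report_detached: true,
                verbose: true,
                interval: Duration::from_millis(100),
            },
        )
        .await;
}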
diff --git a/src/utils/async_stack_trace/src/tests.rs b/src/utils/async_stack_trace/src/tests.rs index e162f3667efc9..87f71f9dbeb19 100644 --- a/src/utils/async_stack_trace/src/tests.rs +++ b/src/utils/async_stack_trace/src/tests.rs @@ -15,11 +15,13 @@ use std::time::Duration; use futures::future::{join_all, select_all}; -use futures::StreamExt; +use futures::{FutureExt, StreamExt}; use futures_async_stream::stream; use tokio::sync::watch; -use super::*; +use crate::context::with_context; +use crate::manager::TraceConfig; +use crate::{StackTrace, TraceReporter}; async fn sleep(time: u64) { tokio::time::sleep(std::time::Duration::from_millis(time)).await; @@ -132,7 +134,15 @@ async fn test_stack_trace_display() { }); TraceReporter { tx: watch_tx } - .trace(hello(), "actor 233", true, Duration::from_millis(50)) + .trace( + hello(), + "actor 233", + TraceConfig { + report_detached: true, + verbose: true, + interval: Duration::from_millis(50), + }, + ) .await; collector.await.unwrap(); diff --git a/src/utils/memcomparable/src/de.rs b/src/utils/memcomparable/src/de.rs index 4a5caf51805bc..20583d0057fdf 100644 --- a/src/utils/memcomparable/src/de.rs +++ b/src/utils/memcomparable/src/de.rs @@ -161,22 +161,6 @@ impl Deserializer { Ok(byte_array) } - /// Read u8 from Bytes input in decimal form (Do not include null tag). Used by value encoding - /// (`serialize_cell`). TODO: It is a temporal solution For value encoding. Will moved to - /// value encoding serializer in future. - pub fn read_decimal_v2(&mut self) -> Result> { - let flag = self.input.get_u8(); - let mut byte_array = vec![flag]; - loop { - let byte = self.input.get_u8(); - if byte == 100 { - break; - } - byte_array.push(byte); - } - Ok(byte_array) - } - /// Read bytes_len without copy, it will consume offset pub fn read_bytes_len(&mut self) -> Result { use core::cmp; diff --git a/src/utils/runtime/Cargo.toml b/src/utils/runtime/Cargo.toml index dadb767d3ed5e..c1d0f4e9a90a8 100644 --- a/src/utils/runtime/Cargo.toml +++ b/src/utils/runtime/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] async-trait = "0.1" +async_stack_trace = { path = "../async_stack_trace" } console = "0.15" console-subscriber = "0.1.8" futures = { version = "0.3", default-features = false, features = ["alloc"] } diff --git a/src/utils/runtime/src/lib.rs b/src/utils/runtime/src/lib.rs index ead42d7495f73..d20d19a07ffeb 100644 --- a/src/utils/runtime/src/lib.rs +++ b/src/utils/runtime/src/lib.rs @@ -14,6 +14,8 @@ //! Configures the RisingWave binary, including logging, locks, panic handler, etc. +#![feature(panic_update_hook)] + use std::path::PathBuf; use std::time::Duration; @@ -74,15 +76,17 @@ impl LoggerSettings { } /// Set panic hook to abort the process (without losing debug info and stack trace). -pub fn set_panic_abort() { - use std::panic; +pub fn set_panic_hook() { + std::panic::update_hook(|default_hook, info| { + default_hook(info); - let default_hook = panic::take_hook(); + if let Some(context) = async_stack_trace::current_context() { + println!("\n\n*** async stack trace context of current task ***\n"); + println!("{}\n", context); + } - panic::set_hook(Box::new(move |info| { - default_hook(info); std::process::abort(); - })); + }); } /// Init logger for RisingWave binaries. @@ -228,7 +232,7 @@ pub fn main_okk(f: F) -> F::Output where F: Future + Send + 'static, { - set_panic_abort(); + set_panic_hook(); let mut builder = tokio::runtime::Builder::new_multi_thread();
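// For reference, `std::panic::update_hook` (nightly, enabled by the
// `panic_update_hook` feature above) is roughly equivalent to the stable
// take/set pattern that the removed code used. A sketch of that equivalent,
// including the new behavior of dumping the async stack trace of the panicking
// task before aborting:
fn set_panic_hook_stable_equivalent() {
    let default_hook = std::panic::take_hook();
    std::panic::set_hook(Box::new(move |info| {
        // Print the panic message and (if enabled) the backtrace first.
        default_hook(info);

        // Then dump the async stack trace of the current task, if it is traced.
        if let Some(context) = async_stack_trace::current_context() {
            println!("\n\n*** async stack trace context of current task ***\n");
            println!("{}\n", context);
        }

        std::process::abort();
    }));
}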