diff --git a/.github/workflows/build-engines.yml b/.github/workflows/build-engines.yml index 6147eda68d1..db1c1a42ce6 100644 --- a/.github/workflows/build-engines.yml +++ b/.github/workflows/build-engines.yml @@ -50,10 +50,13 @@ jobs: if: ${{ github.event_name == 'pull_request' }} run: | echo "Pull Request: ${{ github.event.pull_request.number }}" - echo "Repository Owner: $${{ github.repository_owner }}" + echo "Repository Owner: ${{ github.repository_owner }}" echo "Pull Request Author: ${{ github.actor }}" echo "Pull Request Author Association: ${{ github.event.pull_request.author_association }}" - echo "Commit message:${{ steps.commit-msg.outputs.commit-msg }}" + cat <;trustServerCertificate=true;socket_timeout=60;isolationLevel=READ UNCOMMITTED" + ubuntu: "20.04" - name: mssql_2019 url: "sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true;socket_timeout=60;isolationLevel=READ UNCOMMITTED" - name: mssql_2022 @@ -104,7 +105,9 @@ jobs: is_vitess: true single_threaded: true - runs-on: ubuntu-latest + runs-on: "ubuntu-20.04" + # TODO: Replace with the following once `prisma@5.20.0` is released. + # runs-on: "ubuntu-${{ matrix.database.ubuntu || 'latest' }}" steps: - uses: actions/checkout@v4 - uses: actions-rust-lang/setup-rust-toolchain@v1 diff --git a/.github/workflows/utils/constructDockerBuildCommand.sh b/.github/workflows/utils/constructDockerBuildCommand.sh index 1d7adf80577..6ec9e307266 100644 --- a/.github/workflows/utils/constructDockerBuildCommand.sh +++ b/.github/workflows/utils/constructDockerBuildCommand.sh @@ -1,13 +1,18 @@ #!/bin/bash set -eux; -# full command +DOCKER_WORKSPACE="/root/build" + +# Full command, Docker + Bash. +# In Bash, we use `git config` to avoid "fatal: detected dubious ownership in repository at /root/build" panic messages +# that can occur when Prisma Engines' `build.rs` scripts run `git rev-parse HEAD` to extract the current commit hash. +# See: https://www.kenmuse.com/blog/avoiding-dubious-ownership-in-dev-containers/. 
command="docker run \ -e SQLITE_MAX_VARIABLE_NUMBER=250000 \ -e SQLITE_MAX_EXPR_DEPTH=10000 \ -e LIBZ_SYS_STATIC=1 \ --w /root/build \ --v \"$(pwd)\":/root/build \ +-w ${DOCKER_WORKSPACE} \ +-v \"$(pwd)\":${DOCKER_WORKSPACE} \ -v \"$HOME\"/.cargo/bin:/root/cargo/bin \ -v \"$HOME\"/.cargo/registry/index:/root/cargo/registry/index \ -v \"$HOME\"/.cargo/registry/cache:/root/cargo/registry/cache \ @@ -15,7 +20,8 @@ command="docker run \ $IMAGE \ bash -c \ \" \ - cargo clean \ + git config --global --add safe.directory ${DOCKER_WORKSPACE} \ + && cargo clean \ && cargo build --release -p query-engine --manifest-path query-engine/query-engine/Cargo.toml $TARGET_STRING $FEATURES_STRING \ && cargo build --release -p query-engine-node-api --manifest-path query-engine/query-engine-node-api/Cargo.toml $TARGET_STRING $FEATURES_STRING \ && cargo build --release -p schema-engine-cli --manifest-path schema-engine/cli/Cargo.toml $TARGET_STRING $FEATURES_STRING \ diff --git a/Cargo.lock b/Cargo.lock index 92f9486bfea..b915dbf08a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,9 +30,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.7" +version = "0.8.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77c3a9648d43b9cd48db467b3f87fdd6e146bcc88ab0180006cef2179fe11d01" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", "getrandom 0.2.11", @@ -41,15 +41,6 @@ dependencies = [ "zerocopy", ] -[[package]] -name = "aho-corasick" -version = "0.7.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" -dependencies = [ - "memchr", -] - [[package]] name = "aho-corasick" version = "1.0.3" @@ -164,6 +155,34 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "async-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90e661b6cb0a6eb34d02c520b052daa3aa9ac0cc02495c9d066bbce13ead132b" +dependencies = [ + "futures-io", + "futures-util", + "log", + "native-tls", + "pin-project-lite", + "tokio", + "tokio-native-tls", + "tungstenite", +] + +[[package]] +name = "async_io_stream" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d7b9decdf35d8908a7e3ef02f64c5e9b1695e230154c0e8de3969142d9b94c" +dependencies = [ + "futures", + "pharos", + "rustc_version", + "tokio", +] + [[package]] name = "asynchronous-codec" version = "0.6.2" @@ -186,15 +205,6 @@ dependencies = [ "num-traits", ] -[[package]] -name = "atomic-shim" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67cd4b51d303cf3501c301e8125df442128d3c6d7c69f71b27833d253de47e77" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "atty" version = "0.2.14" @@ -340,7 +350,7 @@ dependencies = [ "enumflags2", "indoc 2.0.3", "insta", - "query-engine-metrics", + "prisma-metrics", "query-engine-tests", "query-tests-setup", "regex", @@ -419,7 +429,7 @@ version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8a88e82b9106923b5c4d6edfca9e7db958d4e98a478ec115022e81b9b38e2c8" dependencies = [ - "ahash 0.8.7", + "ahash 0.8.11", "base64 0.13.1", "bitvec", "chrono", @@ -444,6 +454,10 @@ dependencies = [ "memchr", ] +[[package]] +name = "build-utils" +version = "0.1.0" + [[package]] name = "bumpalo" version = "3.13.0" @@ -480,9 +494,9 @@ checksum = 
"14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "cast" @@ -511,11 +525,11 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.1.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "b16803a61b81d9eabb7eae2588776c4c1e584b738ede45fdbb4c972cec1e9945" dependencies = [ - "libc", + "shlex", ] [[package]] @@ -891,9 +905,12 @@ dependencies = [ name = "crosstarget-utils" version = "0.1.0" dependencies = [ + "derive_more", + "enumflags2", "futures", "js-sys", "pin-project", + "regex", "tokio", "wasm-bindgen", "wasm-bindgen-futures", @@ -1066,7 +1083,7 @@ dependencies = [ "hashbrown 0.14.5", "lock_api", "once_cell", - "parking_lot_core 0.9.8", + "parking_lot_core", ] [[package]] @@ -1191,11 +1208,11 @@ dependencies = [ "expect-test", "futures", "js-sys", - "metrics 0.18.1", "napi", "napi-derive", "once_cell", "pin-project", + "prisma-metrics", "quaint", "serde", "serde-wasm-bindgen", @@ -1223,70 +1240,6 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" -[[package]] -name = "encoding" -version = "0.2.33" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b0d943856b990d12d3b55b359144ff341533e516d94098b1d3fc1ac666d36ec" -dependencies = [ - "encoding-index-japanese", - "encoding-index-korean", - "encoding-index-simpchinese", - "encoding-index-singlebyte", - "encoding-index-tradchinese", -] - -[[package]] -name = "encoding-index-japanese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8b2ff42e9a05335dbf8b5c6f7567e5591d0d916ccef4e0b1710d32a0d0c91" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-korean" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dc33fb8e6bcba213fe2f14275f0963fd16f0a02c878e3095ecfdf5bee529d81" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-simpchinese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87a7194909b9118fc707194baa434a4e3b0fb6a5a757c73c3adb07aa25031f7" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-singlebyte" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3351d5acffb224af9ca265f435b859c7c01537c0849754d3db3fdf2bfe2ae84a" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding-index-tradchinese" -version = "1.20141219.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd0e20d5688ce3cab59eb3ef3a2083a5c77bf496cb798dc6fcdb75f323890c18" -dependencies = [ - "encoding_index_tests", -] - -[[package]] -name = "encoding_index_tests" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569" - [[package]] name = "encoding_rs" version = "0.8.32" @@ -1596,7 +1549,7 @@ checksum = 
"1d930c203dd0b6ff06e0201a4a2fe9149b43c684fd4420555b26d21b1a02956f" dependencies = [ "futures-core", "lock_api", - "parking_lot 0.12.1", + "parking_lot", ] [[package]] @@ -1732,7 +1685,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 2.2.2", "slab", "tokio", @@ -1746,15 +1699,6 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" -dependencies = [ - "ahash 0.7.8", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -1770,7 +1714,7 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ - "ahash 0.8.7", + "ahash 0.8.11", "allocator-api2", ] @@ -1855,7 +1799,7 @@ dependencies = [ "ipconfig", "lru-cache", "once_cell", - "parking_lot 0.12.1", + "parking_lot", "rand 0.8.5", "resolv-conf", "smallvec", @@ -1904,6 +1848,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1911,7 +1866,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -1938,7 +1893,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -2092,15 +2047,6 @@ dependencies = [ "yaml-rust", ] -[[package]] -name = "instant" -version = "0.1.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" -dependencies = [ - "cfg-if", -] - [[package]] name = "ipconfig" version = "0.3.2" @@ -2398,15 +2344,6 @@ dependencies = [ "url", ] -[[package]] -name = "mach" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa" -dependencies = [ - "libc", -] - [[package]] name = "match_cfg" version = "0.1.0" @@ -2454,91 +2391,47 @@ dependencies = [ [[package]] name = "metrics" -version = "0.18.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e52eb6380b6d2a10eb3434aec0885374490f5b82c8aaf5cd487a183c98be834" -dependencies = [ - "ahash 0.7.8", - "metrics-macros", -] - -[[package]] -name = "metrics" -version = "0.19.0" +version = "0.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "142c53885123b68d94108295a09d4afe1a1388ed95b54d5dacd9a454753030f2" +checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" dependencies = [ - "ahash 0.7.8", - "metrics-macros", + "ahash 0.8.11", + "portable-atomic", ] [[package]] name = "metrics-exporter-prometheus" -version = "0.10.0" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953cbbb6f9ba4b9304f4df79b98cdc9d14071ed93065a9fca11c00c5d9181b66" +checksum = 
"b4f0c8427b39666bf970460908b213ec09b3b350f20c0c2eabcbba51704a08e6" dependencies = [ - "hyper", - "indexmap 1.9.3", - "ipnet", - "metrics 0.19.0", - "metrics-util 0.13.0", - "parking_lot 0.11.2", + "base64 0.22.1", + "indexmap 2.2.2", + "metrics", + "metrics-util", "quanta", "thiserror", - "tokio", - "tracing", -] - -[[package]] -name = "metrics-macros" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49e30813093f757be5cf21e50389a24dc7dbb22c49f23b7e8f51d69b508a5ffa" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", ] [[package]] name = "metrics-util" -version = "0.12.1" +version = "0.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65a9e83b833e1d2e07010a386b197c13aa199bbd0fca5cf69bfa147972db890a" +checksum = "4259040465c955f9f2f1a4a8a16dc46726169bca0f88e8fb2dbeced487c3e828" dependencies = [ - "aho-corasick 0.7.20", - "atomic-shim", + "aho-corasick", "crossbeam-epoch", "crossbeam-utils", - "hashbrown 0.11.2", - "indexmap 1.9.3", - "metrics 0.18.1", + "hashbrown 0.14.5", + "indexmap 2.2.2", + "metrics", "num_cpus", "ordered-float", - "parking_lot 0.11.2", "quanta", "radix_trie", "sketches-ddsketch", ] -[[package]] -name = "metrics-util" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd1f4b69bef1e2b392b2d4a12902f2af90bb438ba4a66aa222d1023fa6561b50" -dependencies = [ - "atomic-shim", - "crossbeam-epoch", - "crossbeam-utils", - "hashbrown 0.11.2", - "metrics 0.19.0", - "num_cpus", - "parking_lot 0.11.2", - "quanta", - "sketches-ddsketch", -] - [[package]] name = "mime" version = "0.3.17" @@ -2574,9 +2467,9 @@ dependencies = [ [[package]] name = "mobc" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90eb49dc5d193287ff80e72a86f34cfb27aae562299d22fea215e06ea1059dd3" +checksum = "316a7d198b51958a0ab57248bf5f42d8409551203cb3c821d5925819a8d5415f" dependencies = [ "async-trait", "futures-channel", @@ -2584,7 +2477,7 @@ dependencies = [ "futures-timer", "futures-util", "log", - "metrics 0.18.1", + "metrics", "thiserror", "tokio", "tracing", @@ -2675,15 +2568,16 @@ dependencies = [ "mongodb", "mongodb-client", "pretty_assertions", + "prisma-metrics", "prisma-value", "psl", "query-connector", - "query-engine-metrics", "query-structure", "rand 0.8.5", "regex", "serde", "serde_json", + "telemetry", "thiserror", "tokio", "tracing", @@ -2837,9 +2731,9 @@ dependencies = [ [[package]] name = "napi" -version = "2.15.1" +version = "2.16.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43792514b0c95c5beec42996da0c1b39265b02b75c97baa82d163d3ef55cbfa7" +checksum = "214f07a80874bb96a8433b3cdfc84980d56c7b02e1a0d7ba4ba0db5cef785e2b" dependencies = [ "bitflags 2.4.0", "ctor", @@ -2859,9 +2753,9 @@ checksum = "ebd4419172727423cf30351406c54f6cc1b354a2cfb4f1dba3e6cd07f6d5522b" [[package]] name = "napi-derive" -version = "2.15.0" +version = "2.16.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7622f0dbe0968af2dacdd64870eee6dee94f93c989c841f1ad8f300cf1abd514" +checksum = "17435f7a00bfdab20b0c27d9c56f58f6499e418252253081bfff448099da31d1" dependencies = [ "cfg-if", "convert_case 0.6.0", @@ -2873,9 +2767,9 @@ dependencies = [ [[package]] name = "napi-derive-backend" -version = "1.0.59" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ec514d65fce18a959be55e7f683ac89c6cb850fb59b09e25ab777fd5a4a8d9e" 
+checksum = "967c485e00f0bf3b1bdbe510a38a4606919cf1d34d9a37ad41f25a81aa077abe" dependencies = [ "convert_case 0.6.0", "once_cell", @@ -2888,9 +2782,9 @@ dependencies = [ [[package]] name = "napi-sys" -version = "2.3.0" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2503fa6af34dc83fb74888df8b22afe933b58d37daf7d80424b1c60c68196b8b" +checksum = "427802e8ec3a734331fec1035594a210ce1ff4dc5bc1950530920ab717964ea3" dependencies = [ "libloading 0.8.1", ] @@ -3063,9 +2957,9 @@ checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" [[package]] name = "openssl" -version = "0.10.66" +version = "0.10.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800" dependencies = [ "bitflags 2.4.0", "cfg-if", @@ -3095,18 +2989,18 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-src" -version = "300.3.1+3.3.1" +version = "300.1.6+3.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7259953d42a81bf137fbbd73bd30a8e1914d6dce43c2b90ed575783a22608b91" +checksum = "439fac53e092cd7442a3660c85dde4643ab3b5bd39040912388dcdabf6b88085" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.103" +version = "0.9.96" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f" dependencies = [ "cc", "libc", @@ -3146,7 +3040,7 @@ dependencies = [ "async-trait", "futures", "futures-util", - "http", + "http 0.2.9", "opentelemetry", "prost", "thiserror", @@ -3171,9 +3065,9 @@ dependencies = [ [[package]] name = "ordered-float" -version = "2.10.0" +version = "4.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7940cf2ca942593318d07fcf2596cdca60a85c9e7fab408a5e21a4f9dcd40d87" +checksum = "44d501f1a72f71d3c063a6bbc8f7271fa73aa09fe5d6283b6571e2ed176a2537" dependencies = [ "num-traits", ] @@ -3202,17 +3096,6 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" -[[package]] -name = "parking_lot" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d17b78036a60663b797adeaee46f5c9dfebb86948d1255007a1d6be0271ff99" -dependencies = [ - "instant", - "lock_api", - "parking_lot_core 0.8.6", -] - [[package]] name = "parking_lot" version = "0.12.1" @@ -3220,21 +3103,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f" dependencies = [ "lock_api", - "parking_lot_core 0.9.8", -] - -[[package]] -name = "parking_lot_core" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60a2cfe6f0ad2bfc16aefa463b497d5c7a5ecd44a23efa72aa342d90177356dc" -dependencies = [ - "cfg-if", - "instant", - "libc", - "redox_syscall 0.2.16", - "smallvec", - "winapi", + "parking_lot_core", ] [[package]] @@ -3375,6 +3244,16 @@ dependencies = [ "indexmap 1.9.3", ] +[[package]] +name = "pharos" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"e9567389417feee6ce15dd6527a8a1ecac205ef62c2932bcf3d9f6fc5b78b414" +dependencies = [ + "futures", + "rustc_version", +] + [[package]] name = "phf" version = "0.11.2" @@ -3459,10 +3338,16 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2" + [[package]] name = "postgres-native-tls" version = "0.5.0" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "native-tls", "tokio", @@ -3472,10 +3357,10 @@ dependencies = [ [[package]] name = "postgres-protocol" -version = "0.6.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" +version = "0.6.7" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ - "base64 0.13.1", + "base64 0.22.1", "byteorder", "bytes", "fallible-iterator 0.2.0", @@ -3489,8 +3374,8 @@ dependencies = [ [[package]] name = "postgres-types" -version = "0.2.4" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" +version = "0.2.8" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "bit-vec", "bytes", @@ -3544,6 +3429,7 @@ dependencies = [ name = "prisma-fmt" version = "0.1.0" dependencies = [ + "build-utils", "colored", "dissimilar", "dmmf", @@ -3559,6 +3445,27 @@ dependencies = [ "structopt", ] +[[package]] +name = "prisma-metrics" +version = "0.1.0" +dependencies = [ + "derive_more", + "expect-test", + "futures", + "metrics", + "metrics-exporter-prometheus", + "metrics-util", + "once_cell", + "parking_lot", + "pin-project", + "serde", + "serde_json", + "tokio", + "tracing", + "tracing-futures", + "tracing-subscriber", +] + [[package]] name = "prisma-schema-build" version = "0.1.0" @@ -3765,6 +3672,7 @@ name = "quaint" version = "0.2.0-alpha.13" dependencies = [ "async-trait", + "async-tungstenite", "base64 0.12.3", "bigdecimal", "bit-vec", @@ -3776,6 +3684,7 @@ dependencies = [ "connection-string", "crosstarget-utils", "either", + "enumflags2", "expect-test", "futures", "getrandom 0.2.11", @@ -3783,7 +3692,6 @@ dependencies = [ "indoc 0.3.6", "itertools 0.12.0", "lru-cache", - "metrics 0.18.1", "mobc", "mysql_async", "names 0.11.0", @@ -3794,8 +3702,10 @@ dependencies = [ "percent-encoding", "postgres-native-tls", "postgres-types", + "prisma-metrics", "quaint-test-macros", "quaint-test-setup", + "regex", "rusqlite", "serde", "serde_json", @@ -3804,11 +3714,12 @@ dependencies = [ "tiberius", "tokio", "tokio-postgres", - "tokio-util 0.6.10", + "tokio-util 0.7.8", "tracing", - "tracing-core", + "tracing-futures", "url", "uuid", + "ws_stream_tungstenite", ] [[package]] @@ -3837,16 +3748,15 @@ dependencies = [ [[package]] name = "quanta" -version = "0.9.3" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20afe714292d5e879d8b12740aa223c6a88f118af41870e8b6196e39a02238a8" +checksum = "8e5167a477619228a0b284fac2674e3c388cba90631d7b7de620e6f1fcd08da5" dependencies = [ "crossbeam-utils", "libc", - "mach", "once_cell", "raw-cpuid", - "wasi 
0.10.2+wasi-snapshot-preview1", + "wasi 0.11.0+wasi-snapshot-preview1", "web-sys", "winapi", ] @@ -3865,6 +3775,7 @@ dependencies = [ "query-structure", "serde", "serde_json", + "telemetry", "thiserror", "user-facing-errors", "uuid", @@ -3881,6 +3792,7 @@ dependencies = [ "crossbeam-channel", "crosstarget-utils", "cuid", + "derive_more", "enumflags2", "futures", "indexmap 2.2.2", @@ -3889,13 +3801,14 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "prisma-metrics", "psl", "query-connector", - "query-engine-metrics", "query-structure", "schema", "serde", "serde_json", + "telemetry", "thiserror", "tokio", "tracing", @@ -3913,6 +3826,7 @@ dependencies = [ "anyhow", "async-trait", "base64 0.13.1", + "build-utils", "connection-string", "enumflags2", "graphql-parser", @@ -3921,17 +3835,18 @@ dependencies = [ "mongodb-query-connector", "opentelemetry", "opentelemetry-otlp", + "prisma-metrics", "psl", "quaint", "query-connector", "query-core", - "query-engine-metrics", "request-handlers", "serde", "serde_json", "serial_test", "sql-query-connector", "structopt", + "telemetry", "thiserror", "tokio", "tracing", @@ -3947,6 +3862,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "cbindgen", "chrono", "connection-string", @@ -3965,6 +3881,7 @@ dependencies = [ "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -3984,12 +3901,13 @@ dependencies = [ "connection-string", "napi", "opentelemetry", + "prisma-metrics", "psl", "query-connector", "query-core", - "query-engine-metrics", "serde", "serde_json", + "telemetry", "thiserror", "tracing", "tracing-futures", @@ -4001,30 +3919,13 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "query-engine-metrics" -version = "0.1.0" -dependencies = [ - "expect-test", - "metrics 0.18.1", - "metrics-exporter-prometheus", - "metrics-util 0.12.1", - "once_cell", - "parking_lot 0.12.1", - "serde", - "serde_json", - "tokio", - "tracing", - "tracing-futures", - "tracing-subscriber", -] - [[package]] name = "query-engine-node-api" version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "connection-string", "driver-adapters", "futures", @@ -4032,17 +3933,18 @@ dependencies = [ "napi-build", "napi-derive", "opentelemetry", + "prisma-metrics", "psl", "quaint", "query-connector", "query-core", "query-engine-common", - "query-engine-metrics", "query-structure", "request-handlers", "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -4068,9 +3970,9 @@ dependencies = [ "itertools 0.12.0", "once_cell", "paste", + "prisma-metrics", "prisma-value", "psl", - "query-engine-metrics", "query-test-macros", "query-tests-setup", "serde_json", @@ -4087,6 +3989,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "connection-string", "driver-adapters", "futures", @@ -4103,6 +4006,7 @@ dependencies = [ "serde-wasm-bindgen", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tokio", "tracing", @@ -4159,12 +4063,12 @@ dependencies = [ "nom", "once_cell", "parse-hyperlinks", + "prisma-metrics", "psl", "qe-setup", "quaint", "query-core", "query-engine", - "query-engine-metrics", "query-structure", "regex", "request-handlers", @@ -4172,6 +4076,7 @@ dependencies = [ "serde_json", "sql-query-connector", "strip-ansi-escapes", + "telemetry", "thiserror", "tokio", "tracing", @@ -4323,11 +4228,11 @@ dependencies = [ [[package]] name = "raw-cpuid" -version = "10.7.0" +version = "11.2.0" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +checksum = "1ab240315c661615f2ee9f0f2cd32d5a7343a84d5ebcccb99d46e6637565e7b0" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", ] [[package]] @@ -4363,29 +4268,29 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.2.16" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" dependencies = [ "bitflags 1.3.2", ] [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.4.0", ] [[package]] name = "regex" -version = "1.10.3" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b62dbe01f0b06f9d8dc7d49e05a0785f153b00b2c227856282f671e0318c9b15" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ - "aho-corasick 1.0.3", + "aho-corasick", "memchr", "regex-automata 0.4.5", "regex-syntax 0.8.2", @@ -4406,7 +4311,7 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5bb987efffd3c6d0d8f5f89510bb458559eab11e4f869acb20bf845e016259cd" dependencies = [ - "aho-corasick 1.0.3", + "aho-corasick", "memchr", "regex-syntax 0.8.2", ] @@ -4456,6 +4361,7 @@ dependencies = [ "serde", "serde_json", "sql-query-connector", + "telemetry", "thiserror", "tracing", "url", @@ -4474,7 +4380,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-tls", @@ -4795,6 +4701,7 @@ version = "0.1.0" dependencies = [ "backtrace", "base64 0.13.1", + "build-utils", "connection-string", "expect-test", "indoc 2.0.3", @@ -5003,7 +4910,7 @@ dependencies = [ "futures", "lazy_static", "log", - "parking_lot 0.12.1", + "parking_lot", "serial_test_derive", ] @@ -5121,9 +5028,9 @@ checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" [[package]] name = "sketches-ddsketch" -version = "0.1.3" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04d2ecae5fcf33b122e2e6bd520a57ccf152d2dde3b38c71039df1a6867264ee" +checksum = "85636c14b73d81f541e525f585c0a2109e6744e1565b5c1668e31c70c10ed65c" [[package]] name = "slab" @@ -5136,9 +5043,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.11.0" +version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62bb4feee49fdd9f707ef802e22365a35de4b7b299de4763d44bfea899442ff9" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" @@ -5262,6 +5169,7 @@ dependencies = [ "rand 0.8.5", "serde", "serde_json", + "telemetry", "thiserror", "tokio", "tracing", @@ -5542,6 +5450,36 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "telemetry" +version = "0.1.0" +dependencies = [ + "async-trait", + "crossbeam-channel", + "crosstarget-utils", + "cuid", + "derive_more", + "enumflags2", + "futures", + 
"indexmap 2.2.2", + "itertools 0.12.0", + "lru 0.7.8", + "once_cell", + "opentelemetry", + "prisma-metrics", + "psl", + "rand 0.8.5", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-futures", + "tracing-opentelemetry", + "tracing-subscriber", + "uuid", +] + [[package]] name = "tempfile" version = "3.7.1" @@ -5570,6 +5508,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "build-utils", "colored", "dmmf", "enumflags2", @@ -5656,9 +5595,9 @@ dependencies = [ [[package]] name = "tiberius" -version = "0.11.8" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "091052ba8f20c1e14f85913a5242a663a09d17ff4c0137b9b1f0735cb3c5dabc" +checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" dependencies = [ "async-native-tls", "async-trait", @@ -5668,10 +5607,8 @@ dependencies = [ "bytes", "chrono", "connection-string", - "encoding", + "encoding_rs", "enumflags2", - "futures", - "futures-sink", "futures-util", "num-traits", "once_cell", @@ -5753,7 +5690,7 @@ dependencies = [ "libc", "mio", "num_cpus", - "parking_lot 0.12.1", + "parking_lot", "pin-project-lite", "signal-hook-registry", "socket2 0.5.7", @@ -5794,8 +5731,8 @@ dependencies = [ [[package]] name = "tokio-postgres" -version = "0.7.7" -source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#54a490bc6afa315abb9867304fb67e8b12a8fbf3" +version = "0.7.12" +source = "git+https://github.com/prisma/rust-postgres?branch=pgbouncer-mode#c62b9928d402685e152161907e8480603c29ef65" dependencies = [ "async-trait", "byteorder", @@ -5804,15 +5741,17 @@ dependencies = [ "futures-channel", "futures-util", "log", - "parking_lot 0.12.1", + "parking_lot", "percent-encoding", "phf", "pin-project-lite", "postgres-protocol", "postgres-types", + "rand 0.8.5", "socket2 0.5.7", "tokio", "tokio-util 0.7.8", + "whoami", ] [[package]] @@ -5855,7 +5794,6 @@ checksum = "36943ee01a6d67977dd3f84a5a1d2efeb4ada3a1ae771cadfaa535d9d9fc6507" dependencies = [ "bytes", "futures-core", - "futures-io", "futures-sink", "log", "pin-project-lite", @@ -5899,7 +5837,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-timeout", @@ -6105,6 +6043,25 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "native-tls", + "rand 0.8.5", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "twox-hash" version = "1.6.3" @@ -6112,7 +6069,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.3.23", + "rand 0.8.5", "static_assertions", ] @@ -6246,6 +6203,12 @@ dependencies = [ "user-facing-error-macros", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8-width" version = "0.1.6" @@ -6357,31 +6320,32 @@ checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" [[package]] name = "wasi" -version = "0.10.2+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +name = "wasite" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", @@ -6406,9 +6370,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6416,9 +6380,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", @@ -6429,9 +6393,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "wasm-logger" @@ -6490,6 +6454,17 @@ dependencies = [ "once_cell", ] +[[package]] +name = "whoami" +version = "1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "372d5b87f58ec45c384ba03563b03544dc5fadc3983e434b286913f5b4a9bb6d" +dependencies = [ + "redox_syscall 0.5.7", + "wasite", + "web-sys", +] + [[package]] name = "widestring" version = "1.0.2" @@ -6766,6 +6741,26 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ws_stream_tungstenite" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed39ff9f8b2eda91bf6390f9f49eee93d655489e15708e3bb638c1c4f07cecb4" +dependencies = [ + "async-tungstenite", + "async_io_stream", + "bitflags 2.4.0", + "futures-core", + "futures-io", + "futures-sink", + "futures-util", + "pharos", + "rustc_version", + "tokio", + "tracing", + "tungstenite", +] + [[package]] name = "wyz" version = "0.5.1" diff --git a/Cargo.toml b/Cargo.toml index 97c5162a3aa..df9f44eee29 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,7 +20,6 @@ members = [ "query-engine/black-box-tests", "query-engine/dmmf", "query-engine/driver-adapters", - 
"query-engine/metrics", "query-engine/query-structure", "query-engine/query-engine", "query-engine/query-engine-node-api", @@ -38,6 +37,7 @@ members = [ [workspace.dependencies] async-trait = { version = "0.1.77" } enumflags2 = { version = "0.7", features = ["serde"] } +futures = "0.3" psl = { path = "./psl/psl" } serde_json = { version = "1", features = ["float_roundtrip", "preserve_order", "raw_value"] } serde = { version = "1", features = ["derive"] } @@ -51,25 +51,29 @@ tokio = { version = "1", features = [ "time", ] } chrono = { version = "0.4.38", features = ["serde"] } +derive_more = "0.99.17" user-facing-errors = { path = "./libs/user-facing-errors" } uuid = { version = "1", features = ["serde", "v4", "v7", "js"] } indoc = "2.0.1" indexmap = { version = "2.2.2", features = ["serde"] } itertools = "0.12" connection-string = "0.2" -napi = { version = "2.15.1", default-features = false, features = [ - "napi8", +napi = { version = "2.16.13", default-features = false, features = [ + "napi9", "tokio_rt", "serde-json", ] } -napi-derive = "2.15.0" +napi-derive = "2.16.12" js-sys = { version = "0.3" } +pin-project = "1" rand = { version = "0.8" } +regex = { version = "1", features = ["std"] } serde_repr = { version = "0.1.17" } serde-wasm-bindgen = { version = "0.5" } tracing = { version = "0.1" } +tracing-futures = "0.2" tsify = { version = "0.4.5" } -wasm-bindgen = { version = "0.2.92" } +wasm-bindgen = { version = "0.2.93" } wasm-bindgen-futures = { version = "0.4" } wasm-rs-dbg = { version = "0.1.2", default-features = false, features = ["console-error"] } wasm-bindgen-test = { version = "0.3.0" } diff --git a/Makefile b/Makefile index ec16c50b9dc..7407f7d41fe 100644 --- a/Makefile +++ b/Makefile @@ -67,6 +67,10 @@ build-qe-wasm-gz: build-qe-wasm gzip -knc $$provider/query_engine_bg.wasm > $$provider.gz; \ done; +integrate-qe-wasm: + cd query-engine/query-engine-wasm && \ + ./build.sh $(QE_WASM_VERSION) ../prisma/packages/client/node_modules/@prisma/query-engine-wasm + build-schema-wasm: @printf '%s\n' "🛠️ Building the Rust crate" cargo build --profile $(PROFILE) --target=wasm32-unknown-unknown -p prisma-schema-build diff --git a/flake.lock b/flake.lock index 7c914585d8a..c20225ca22e 100644 --- a/flake.lock +++ b/flake.lock @@ -1,17 +1,12 @@ { "nodes": { "crane": { - "inputs": { - "nixpkgs": [ - "nixpkgs" - ] - }, "locked": { - "lastModified": 1722960479, - "narHash": "sha256-NhCkJJQhD5GUib8zN9JrmYGMwt4lCRp6ZVNzIiYCl0Y=", + "lastModified": 1728776144, + "narHash": "sha256-fROVjMcKRoGHofDm8dY3uDUtCMwUICh/KjBFQnuBzfg=", "owner": "ipetkov", "repo": "crane", - "rev": "4c6c77920b8d44cd6660c1621dea6b3fc4b4c4f4", + "rev": "f876e3d905b922502f031aeec1a84490122254b7", "type": "github" }, "original": { @@ -27,11 +22,11 @@ ] }, "locked": { - "lastModified": 1722555600, - "narHash": "sha256-XOQkdLafnb/p9ij77byFQjDf5m5QYl9b2REiVClC+x4=", + "lastModified": 1727826117, + "narHash": "sha256-K5ZLCyfO/Zj9mPFldf3iwS6oZStJcU4tSpiXTMYaaL0=", "owner": "hercules-ci", "repo": "flake-parts", - "rev": "8471fe90ad337a8074e957b69ca4d0089218391d", + "rev": "3d04084d54bedc3d6b8b736c70ef449225c361b1", "type": "github" }, "original": { @@ -62,11 +57,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1722813957, - "narHash": "sha256-IAoYyYnED7P8zrBFMnmp7ydaJfwTnwcnqxUElC1I26Y=", + "lastModified": 1728888510, + "narHash": "sha256-nsNdSldaAyu6PE3YUA+YQLqUDJh+gRbBooMMekZJwvI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "cb9a96f23c491c081b38eab96d22fa958043c9fa", + "rev": "a3c0b3b21515f74fd2665903d4ce6bc4dc81c77c", 
"type": "github" }, "original": { @@ -92,11 +87,11 @@ ] }, "locked": { - "lastModified": 1723170066, - "narHash": "sha256-SFkQfOA+8AIYJsPlQtxNP+z5jRLfz91z/aOrV94pPmw=", + "lastModified": 1729184663, + "narHash": "sha256-uNyi5vQrzaLkt4jj6ZEOs4+4UqOAwP6jFG2s7LIDwIk=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "fecfe4d7c96fea2982c7907997b387a6b52c1093", + "rev": "16fb78d443c1970dda9a0bbb93070c9d8598a925", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index bda6ec6a565..f55ed31a29d 100644 --- a/flake.nix +++ b/flake.nix @@ -1,9 +1,6 @@ { inputs = { - crane = { - url = "github:ipetkov/crane"; - inputs.nixpkgs.follows = "nixpkgs"; - }; + crane.url = "github:ipetkov/crane"; flake-parts = { url = "github:hercules-ci/flake-parts"; inputs.nixpkgs-lib.follows = "nixpkgs"; diff --git a/libs/build-utils/Cargo.toml b/libs/build-utils/Cargo.toml new file mode 100644 index 00000000000..715b650505d --- /dev/null +++ b/libs/build-utils/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "build-utils" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/libs/build-utils/src/lib.rs b/libs/build-utils/src/lib.rs new file mode 100644 index 00000000000..03294a997a3 --- /dev/null +++ b/libs/build-utils/src/lib.rs @@ -0,0 +1,23 @@ +use std::process::Command; + +/// Store the current git commit hash in the `GIT_HASH` variable in rustc env. +/// If the `GIT_HASH` environment variable is already set, this function does nothing. +pub fn store_git_commit_hash_in_env() { + if std::env::var("GIT_HASH").is_ok() { + return; + } + + let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); + + // Sanity check on the output. + if !output.status.success() { + panic!( + "Failed to get git commit hash.\nstderr: \n{}\nstdout {}\n", + String::from_utf8(output.stderr).unwrap_or_default(), + String::from_utf8(output.stdout).unwrap_or_default(), + ); + } + + let git_hash = String::from_utf8(output.stdout).unwrap(); + println!("cargo:rustc-env=GIT_HASH={git_hash}"); +} diff --git a/libs/crosstarget-utils/Cargo.toml b/libs/crosstarget-utils/Cargo.toml index 609832a99ef..78d52dade2b 100644 --- a/libs/crosstarget-utils/Cargo.toml +++ b/libs/crosstarget-utils/Cargo.toml @@ -6,14 +6,17 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -futures = "0.3" +derive_more.workspace = true +enumflags2.workspace = true +futures.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] js-sys.workspace = true wasm-bindgen.workspace = true wasm-bindgen-futures.workspace = true tokio = { version = "1", features = ["macros", "sync"] } -pin-project = "1" +pin-project.workspace = true [target.'cfg(not(target_arch = "wasm32"))'.dependencies] tokio.workspace = true +regex.workspace = true diff --git a/libs/crosstarget-utils/src/common.rs b/libs/crosstarget-utils/src/common.rs deleted file mode 100644 index 92a1d5094e8..00000000000 --- a/libs/crosstarget-utils/src/common.rs +++ /dev/null @@ -1,23 +0,0 @@ -use std::fmt::Display; - -#[derive(Debug)] -pub struct SpawnError; - -impl Display for SpawnError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Failed to spawn a future") - } -} - -impl std::error::Error for SpawnError {} - -#[derive(Debug)] -pub struct TimeoutError; - -impl Display for TimeoutError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Operation timed out") - } -} - -impl std::error::Error for TimeoutError {} diff 
--git a/libs/crosstarget-utils/src/common/mod.rs b/libs/crosstarget-utils/src/common/mod.rs
new file mode 100644
index 00000000000..d1701cc3408
--- /dev/null
+++ b/libs/crosstarget-utils/src/common/mod.rs
@@ -0,0 +1,3 @@
+pub mod regex;
+pub mod spawn;
+pub mod timeout;
diff --git a/libs/crosstarget-utils/src/common/regex.rs b/libs/crosstarget-utils/src/common/regex.rs
new file mode 100644
index 00000000000..825d3e85d6c
--- /dev/null
+++ b/libs/crosstarget-utils/src/common/regex.rs
@@ -0,0 +1,37 @@
+use derive_more::Display;
+
+#[derive(Debug, Display)]
+#[display(fmt = "Regular expression error: {message}")]
+pub struct RegExpError {
+    pub message: String,
+}
+
+impl std::error::Error for RegExpError {}
+
+/// Flag modifiers for regular expressions.
+#[enumflags2::bitflags]
+#[derive(Copy, Clone, Debug, PartialEq)]
+#[repr(u8)]
+pub enum RegExpFlags {
+    IgnoreCase = 0b0001,
+    Multiline = 0b0010,
+}
+
+impl RegExpFlags {
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::IgnoreCase => "i",
+            Self::Multiline => "m",
+        }
+    }
+}
+
+pub trait RegExpCompat {
+    /// Searches for the first match of this regex in the haystack given, and if found,
+    /// returns not only the overall match but also the matches of each capture group in the regex.
+    /// If no match is found, then None is returned.
+    fn captures(&self, message: &str) -> Option<Vec<String>>;
+
+    /// Tests if the regex matches the input string.
+    fn test(&self, message: &str) -> bool;
+}
diff --git a/libs/crosstarget-utils/src/common/spawn.rs b/libs/crosstarget-utils/src/common/spawn.rs
new file mode 100644
index 00000000000..77560452dbd
--- /dev/null
+++ b/libs/crosstarget-utils/src/common/spawn.rs
@@ -0,0 +1,8 @@
+use derive_more::Display;
+
+#[derive(Debug, Display)]
+#[display(fmt = "Failed to spawn a future")]
+
+pub struct SpawnError;
+
+impl std::error::Error for SpawnError {}
diff --git a/libs/crosstarget-utils/src/common/timeout.rs b/libs/crosstarget-utils/src/common/timeout.rs
new file mode 100644
index 00000000000..829abaf0ec0
--- /dev/null
+++ b/libs/crosstarget-utils/src/common/timeout.rs
@@ -0,0 +1,7 @@
+use derive_more::Display;
+
+#[derive(Debug, Display)]
+#[display(fmt = "Operation timed out")]
+pub struct TimeoutError;
+
+impl std::error::Error for TimeoutError {}
diff --git a/libs/crosstarget-utils/src/lib.rs b/libs/crosstarget-utils/src/lib.rs
index a41d8dd0f9a..1cfa25edeed 100644
--- a/libs/crosstarget-utils/src/lib.rs
+++ b/libs/crosstarget-utils/src/lib.rs
@@ -9,4 +9,5 @@ mod native;
 #[cfg(not(target_arch = "wasm32"))]
 pub use crate::native::*;
 
-pub use common::SpawnError;
+pub use crate::common::regex::RegExpCompat;
+pub use crate::common::spawn::SpawnError;
diff --git a/libs/crosstarget-utils/src/native/mod.rs b/libs/crosstarget-utils/src/native/mod.rs
index b19a356ff8f..e3793a1de65 100644
--- a/libs/crosstarget-utils/src/native/mod.rs
+++ b/libs/crosstarget-utils/src/native/mod.rs
@@ -1,3 +1,4 @@
+pub mod regex;
 pub mod spawn;
 pub mod task;
 pub mod time;
diff --git a/libs/crosstarget-utils/src/native/regex.rs b/libs/crosstarget-utils/src/native/regex.rs
new file mode 100644
index 00000000000..59fe88899c6
--- /dev/null
+++ b/libs/crosstarget-utils/src/native/regex.rs
@@ -0,0 +1,41 @@
+use enumflags2::BitFlags;
+use regex::{Regex as NativeRegex, RegexBuilder};
+
+use crate::common::regex::{RegExpCompat, RegExpError, RegExpFlags};
+
+pub struct RegExp {
+    inner: NativeRegex,
+}
+
+impl RegExp {
+    pub fn new(pattern: &str, flags: BitFlags<RegExpFlags>) -> Result<Self, RegExpError> {
+        let mut builder = RegexBuilder::new(pattern);
+
+        if flags.contains(RegExpFlags::Multiline) {
+            builder.multi_line(true);
+        }
+
+        if flags.contains(RegExpFlags::IgnoreCase) {
+            builder.case_insensitive(true);
+        }
+
+        let inner = builder.build().map_err(|e| RegExpError { message: e.to_string() })?;
+
+        Ok(Self { inner })
+    }
+}
+
+impl RegExpCompat for RegExp {
+    fn captures(&self, message: &str) -> Option<Vec<String>> {
+        self.inner.captures(message).map(|captures| {
+            captures
+                .iter()
+                .flat_map(|capture| capture.map(|cap| cap.as_str().to_owned()))
+                .collect()
+        })
+    }
+
+    fn test(&self, message: &str) -> bool {
+        self.inner.is_match(message)
+    }
+}
diff --git a/libs/crosstarget-utils/src/native/spawn.rs b/libs/crosstarget-utils/src/native/spawn.rs
index 70e4c3708f2..31971aa47c4 100644
--- a/libs/crosstarget-utils/src/native/spawn.rs
+++ b/libs/crosstarget-utils/src/native/spawn.rs
@@ -1,7 +1,7 @@
 use futures::TryFutureExt;
 use std::future::Future;
 
-use crate::common::SpawnError;
+use crate::common::spawn::SpawnError;
 
 pub fn spawn_if_possible<F>(future: F) -> impl Future<Output = Result<F::Output, SpawnError>>
 where
diff --git a/libs/crosstarget-utils/src/native/time.rs b/libs/crosstarget-utils/src/native/time.rs
index 3b154a27565..c17cb07c5eb 100644
--- a/libs/crosstarget-utils/src/native/time.rs
+++ b/libs/crosstarget-utils/src/native/time.rs
@@ -3,8 +3,9 @@ use std::{
     time::{Duration, Instant},
 };
 
-use crate::common::TimeoutError;
+use crate::common::timeout::TimeoutError;
 
+#[derive(Clone, Copy)]
 pub struct ElapsedTimeCounter {
     instant: Instant,
 }
diff --git a/libs/crosstarget-utils/src/wasm/mod.rs b/libs/crosstarget-utils/src/wasm/mod.rs
index b19a356ff8f..e3793a1de65 100644
--- a/libs/crosstarget-utils/src/wasm/mod.rs
+++ b/libs/crosstarget-utils/src/wasm/mod.rs
@@ -1,3 +1,4 @@
+pub mod regex;
 pub mod spawn;
 pub mod task;
 pub mod time;
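[Editorial illustration, not part of the patch.] A minimal sketch of how calling code can stay platform-agnostic by programming against the `RegExpCompat` trait introduced above; the helper name and inputs are hypothetical, and only the crate-root re-export and the trait methods shown in this diff are assumed:

    // Works the same with the native `RegExp` (backed by `regex::Regex`) and
    // the Wasm `RegExp` (backed by `js_sys::RegExp`).
    use crosstarget_utils::RegExpCompat;

    fn first_capture_group(re: &impl RegExpCompat, haystack: &str) -> Option<String> {
        if !re.test(haystack) {
            return None;
        }
        // `captures` yields the overall match first, followed by each capture group.
        re.captures(haystack)?.into_iter().nth(1)
    }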
diff --git a/libs/crosstarget-utils/src/wasm/regex.rs b/libs/crosstarget-utils/src/wasm/regex.rs
new file mode 100644
index 00000000000..500f631282e
--- /dev/null
+++ b/libs/crosstarget-utils/src/wasm/regex.rs
@@ -0,0 +1,38 @@
+use enumflags2::BitFlags;
+use js_sys::RegExp as JSRegExp;
+
+use crate::common::regex::{RegExpCompat, RegExpError, RegExpFlags};
+
+pub struct RegExp {
+    inner: JSRegExp,
+}
+
+impl RegExp {
+    pub fn new(pattern: &str, flags: BitFlags<RegExpFlags>) -> Result<Self, RegExpError> {
+        let mut flags: String = flags.into_iter().map(|flag| flag.as_str()).collect();
+
+        // Global flag is implied in `regex::Regex`, so we match that behavior for consistency.
+        flags.push('g');
+
+        Ok(Self {
+            inner: JSRegExp::new(pattern, &flags),
+        })
+    }
+}
+
+impl RegExpCompat for RegExp {
+    fn captures(&self, message: &str) -> Option<Vec<String>> {
+        self.inner.exec(message).map(|matches| {
+            // We keep the same number of captures as the number of groups in the regex pattern,
+            // but we guarantee that the captures are always strings.
+            matches
+                .iter()
+                .map(|match_value| match_value.try_into().ok().unwrap_or_default())
+                .collect()
+        })
+    }
+
+    fn test(&self, input: &str) -> bool {
+        self.inner.test(input)
+    }
+}
diff --git a/libs/crosstarget-utils/src/wasm/spawn.rs b/libs/crosstarget-utils/src/wasm/spawn.rs
index e27104c3b94..f9700d8a007 100644
--- a/libs/crosstarget-utils/src/wasm/spawn.rs
+++ b/libs/crosstarget-utils/src/wasm/spawn.rs
@@ -4,7 +4,7 @@ use futures::TryFutureExt;
 use tokio::sync::oneshot;
 use wasm_bindgen_futures::spawn_local;
 
-use crate::common::SpawnError;
+use crate::common::spawn::SpawnError;
 
 pub fn spawn_if_possible<F>(future: F) -> impl Future<Output = Result<F::Output, SpawnError>>
 where
diff --git a/libs/crosstarget-utils/src/wasm/time.rs b/libs/crosstarget-utils/src/wasm/time.rs
index 18f3394b746..6c36a7b4d40 100644
--- a/libs/crosstarget-utils/src/wasm/time.rs
+++ b/libs/crosstarget-utils/src/wasm/time.rs
@@ -7,7 +7,7 @@ use std::time::Duration;
 use wasm_bindgen::prelude::*;
 use wasm_bindgen_futures::JsFuture;
 
-use crate::common::TimeoutError;
+use crate::common::timeout::TimeoutError;
 
 #[wasm_bindgen]
 extern "C" {
@@ -21,6 +21,7 @@ extern "C" {
 }
 
+#[derive(Clone, Copy)]
 pub struct ElapsedTimeCounter {
     start_time: f64,
 }
diff --git a/query-engine/metrics/Cargo.toml b/libs/metrics/Cargo.toml
similarity index 50%
rename from query-engine/metrics/Cargo.toml
rename to libs/metrics/Cargo.toml
index 5593b246c09..916c464dda5 100644
--- a/query-engine/metrics/Cargo.toml
+++ b/libs/metrics/Cargo.toml
@@ -1,19 +1,22 @@
 [package]
-name = "query-engine-metrics"
+name = "prisma-metrics"
 version = "0.1.0"
 edition = "2021"
 
 [dependencies]
-metrics = "0.18"
-metrics-util = "0.12.1"
-metrics-exporter-prometheus = "0.10.0"
+futures.workspace = true
+derive_more.workspace = true
+metrics = "0.23.0"
+metrics-util = "0.17.0"
+metrics-exporter-prometheus = { version = "0.15.3", default-features = false }
 once_cell = "1.3"
 serde.workspace = true
 serde_json.workspace = true
 tracing.workspace = true
-tracing-futures = "0.2"
+tracing-futures.workspace = true
 tracing-subscriber = "0.3.11"
 parking_lot = "0.12"
+pin-project.workspace = true
 
 [dev-dependencies]
 expect-test = "1"
diff --git a/query-engine/metrics/src/common.rs b/libs/metrics/src/common.rs
similarity index 100%
rename from query-engine/metrics/src/common.rs
rename to libs/metrics/src/common.rs
diff --git a/query-engine/metrics/src/formatters.rs b/libs/metrics/src/formatters.rs
similarity index 100%
rename from query-engine/metrics/src/formatters.rs
rename to libs/metrics/src/formatters.rs
diff --git a/libs/metrics/src/guards.rs b/libs/metrics/src/guards.rs
new file mode 100644
index 00000000000..331db124990
--- /dev/null
+++ b/libs/metrics/src/guards.rs
@@ -0,0 +1,31 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+
+use crate::gauge;
+
+pub struct GaugeGuard {
+    name: &'static str,
+    decremented: AtomicBool,
+}
+
+impl GaugeGuard {
+    pub fn increment(name: &'static str) -> Self {
+        gauge!(name).increment(1.0);
+
+        Self {
+            name,
+            decremented: AtomicBool::new(false),
+        }
+    }
+
+    pub fn decrement(&self) {
+        if !self.decremented.swap(true, Ordering::Relaxed) {
+            gauge!(self.name).decrement(1.0);
+        }
+    }
+}
+
+impl Drop for GaugeGuard {
+    fn drop(&mut self) {
+        self.decrement();
+    }
+}
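[Editorial illustration, not part of the patch.] A minimal sketch of the RAII pattern that `GaugeGuard` above enables; the surrounding function is hypothetical and the metric name is assumed to be one of the gauge constants declared in `libs/metrics/src/lib.rs`:

    use prisma_metrics::{guards::GaugeGuard, PRISMA_CLIENT_QUERIES_ACTIVE};

    async fn handle_query() {
        // Increments the gauge now; it is decremented exactly once, either via an
        // explicit `decrement()` call or when the guard is dropped (including on
        // early return or unwind).
        let _guard = GaugeGuard::increment(PRISMA_CLIENT_QUERIES_ACTIVE);

        // ... execute the query ...
    }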
diff --git a/libs/metrics/src/instrument.rs b/libs/metrics/src/instrument.rs
new file mode 100644
index 00000000000..a2cb16de48f
--- /dev/null
+++ b/libs/metrics/src/instrument.rs
@@ -0,0 +1,83 @@
+use std::{
+    cell::RefCell,
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
+
+use futures::future::Either;
+use pin_project::pin_project;
+
+use crate::MetricRecorder;
+
+thread_local! {
+    /// The current metric recorder temporarily set on the current thread while polling a future.
+    ///
+    /// See the description of `GLOBAL_RECORDER` in [`crate::recorder`] module for more
+    /// information.
+    static CURRENT_RECORDER: RefCell<Option<MetricRecorder>> = const { RefCell::new(None) };
+}
+
+/// Instruments a type with a metrics recorder.
+///
+/// The instrumentation logic is currently only implemented for futures, but it could be extended
+/// to support streams, sinks, and other types later if needed. Right now we only need it to be
+/// able to set the initial recorder in the Node-API engine methods and forward the recorder to
+/// spawned tokio tasks; in other words, to instrument the top-level future of each task.
+pub trait WithMetricsInstrumentation: Sized {
+    /// Instruments the type with a [`MetricRecorder`].
+    fn with_recorder(self, recorder: MetricRecorder) -> WithRecorder<Self> {
+        WithRecorder { inner: self, recorder }
+    }
+
+    /// Instruments the type with an [`MetricRecorder`] if it is a `Some` or returns `self` as is
+    /// if the `recorder` is a `None`.
+    fn with_optional_recorder(self, recorder: Option<MetricRecorder>) -> Either<WithRecorder<Self>, Self> {
+        match recorder {
+            Some(recorder) => Either::Left(self.with_recorder(recorder)),
+            None => Either::Right(self),
+        }
+    }
+
+    /// Instruments the type with the current [`MetricRecorder`] from the parent context on this
+    /// thread, or the default global recorder otherwise. If neither is set, then `self` is
+    /// returned as is.
+    fn with_current_recorder(self) -> Either<WithRecorder<Self>, Self> {
+        CURRENT_RECORDER.with_borrow(|recorder| {
+            let recorder = recorder.clone().or_else(crate::recorder::global_recorder);
+            self.with_optional_recorder(recorder)
+        })
+    }
+}
+
+impl<T> WithMetricsInstrumentation for T {}
+
+/// A type instrumented with a metric recorder.
+///
+/// If `T` is a `Future`, then `WithRecorder<T>` is also a `Future`. When polled, it temporarily
+/// sets the local metric recorder for the duration of polling the inner future, and then restores
+/// the previous recorder on the stack.
+///
+/// Similar logic can be implemented for cases where `T` is another async primitive like a stream
+/// or a sink, or any other type where such instrumentation makes sense (e.g. a function).
+#[pin_project]
+pub struct WithRecorder<T> {
+    #[pin]
+    inner: T,
+    recorder: MetricRecorder,
+}
+
+impl<T: Future> Future for WithRecorder<T> {
+    type Output = T::Output;
+
+    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        let this = self.project();
+        let prev_recorder = CURRENT_RECORDER.replace(Some(this.recorder.clone()));
+
+        let poll = metrics::with_local_recorder(this.recorder, || this.inner.poll(cx));
+
+        CURRENT_RECORDER.set(prev_recorder);
+
+        poll
+    }
+}
diff --git a/query-engine/metrics/src/lib.rs b/libs/metrics/src/lib.rs
similarity index 78%
rename from query-engine/metrics/src/lib.rs
rename to libs/metrics/src/lib.rs
index 1965b56cb07..43aeab77592 100644
--- a/query-engine/metrics/src/lib.rs
+++ b/libs/metrics/src/lib.rs
@@ -1,5 +1,8 @@
-//! Query Engine Metrics
-//! This crate is responsible for capturing and recording metrics in the Query Engine.
+//! # Prisma Metrics
+//!
+//! This crate is responsible for capturing and recording metrics in the Query Engine and its
+//! dependencies.
+//!
 //! Metrics is broken into two parts, `MetricsRecorder` and `MetricsRegistry`, and uses our tracing framework to communicate.
 //! An example best explains this system.
 //! When the engine boots up, the `MetricRegistry` is added to our tracing as a layer and The MetricRecorder is
@@ -19,29 +22,23 @@
 //! * At the moment, with the Histogram we only support one type of bucket which is a bucket for timings in milliseconds.
 //!
-const METRIC_TARGET: &str = "qe_metrics";
-const METRIC_COUNTER: &str = "counter";
-const METRIC_GAUGE: &str = "gauge";
-const METRIC_HISTOGRAM: &str = "histogram";
-const METRIC_DESCRIPTION: &str = "description";
-
 mod common;
 mod formatters;
+mod instrument;
 mod recorder;
 mod registry;
+pub mod guards;
+
 use once_cell::sync::Lazy;
-use recorder::*;
-pub use registry::MetricRegistry;
 use serde::Deserialize;
 use std::collections::HashMap;
-use std::sync::Once;
-pub extern crate metrics;
-pub use metrics::{
-    absolute_counter, decrement_gauge, describe_counter, describe_gauge, describe_histogram, gauge, histogram,
-    increment_counter, increment_gauge,
-};
+pub use metrics::{self, counter, describe_counter, describe_gauge, describe_histogram, gauge, histogram};
+
+pub use instrument::*;
+pub use recorder::MetricRecorder;
+pub use registry::MetricRegistry;
 
 // Metrics that we emit from the engines, third party metrics emitted by libraries and that we rename are omitted.
 pub const PRISMA_CLIENT_QUERIES_TOTAL: &str = "prisma_client_queries_total"; // counter
@@ -94,21 +91,8 @@ static METRIC_RENAMES: Lazy>
     ])
 });
 
-pub fn setup() {
-    set_recorder();
-    initialize_metrics();
-}
-
-static METRIC_RECORDER: Once = Once::new();
-
-fn set_recorder() {
-    METRIC_RECORDER.call_once(|| {
-        metrics::set_boxed_recorder(Box::new(MetricRecorder)).unwrap();
-    });
-}
-
 /// Initialize metrics descriptions and values
-pub fn initialize_metrics() {
+pub(crate) fn initialize_metrics() {
     initialize_metrics_descriptions();
     initialize_metrics_values();
 }
@@ -145,15 +129,15 @@ fn initialize_metrics_descriptions() {
 /// Histograms are excluded, as their initialization will alter the histogram values.
 /// (i.e. histograms don't have a neutral value, like counters or gauges)
 fn initialize_metrics_values() {
-    absolute_counter!(PRISMA_CLIENT_QUERIES_TOTAL, 0);
-    absolute_counter!(PRISMA_DATASOURCE_QUERIES_TOTAL, 0);
-    gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 0.0);
-    absolute_counter!(MOBC_POOL_CONNECTIONS_OPENED_TOTAL, 0);
-    absolute_counter!(MOBC_POOL_CONNECTIONS_CLOSED_TOTAL, 0);
-    gauge!(MOBC_POOL_CONNECTIONS_OPEN, 0.0);
-    gauge!(MOBC_POOL_CONNECTIONS_BUSY, 0.0);
-    gauge!(MOBC_POOL_CONNECTIONS_IDLE, 0.0);
-    gauge!(MOBC_POOL_WAIT_COUNT, 0.0);
+    counter!(PRISMA_CLIENT_QUERIES_TOTAL).absolute(0);
+    counter!(PRISMA_DATASOURCE_QUERIES_TOTAL).absolute(0);
+    gauge!(PRISMA_CLIENT_QUERIES_ACTIVE).set(0.0);
+    counter!(MOBC_POOL_CONNECTIONS_OPENED_TOTAL).absolute(0);
+    counter!(MOBC_POOL_CONNECTIONS_CLOSED_TOTAL).absolute(0);
+    gauge!(MOBC_POOL_CONNECTIONS_OPEN).set(0.0);
+    gauge!(MOBC_POOL_CONNECTIONS_BUSY).set(0.0);
+    gauge!(MOBC_POOL_CONNECTIONS_IDLE).set(0.0);
+    gauge!(MOBC_POOL_WAIT_COUNT).set(0.0);
 }
 
 // At the moment the histogram is only used for timings. So the bounds are hard coded here
@@ -171,24 +155,16 @@ pub enum MetricFormat {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use metrics::{
-        absolute_counter, decrement_gauge, describe_counter, describe_gauge, describe_histogram, gauge, histogram,
-        increment_counter, increment_gauge, register_counter, register_gauge, register_histogram,
-    };
+    use metrics::{describe_counter, describe_gauge, describe_histogram, gauge, histogram};
    use serde_json::json;
    use std::collections::HashMap;
    use std::time::Duration;
-    use tracing::instrument::WithSubscriber;
-    use tracing::{trace, Dispatch};
-    use tracing_subscriber::layer::SubscriberExt;
+    use tracing::trace;
 
     use once_cell::sync::Lazy;
     use tokio::runtime::Runtime;
 
-    static RT: Lazy<Runtime> = Lazy::new(|| {
-        set_recorder();
-        Runtime::new().unwrap()
-    });
+    static RT: Lazy<Runtime> = Lazy::new(|| Runtime::new().unwrap());
 
     const TESTING_ACCEPT_LIST: &[&str] = &[
         "test_counter",
@@ -209,14 +185,14 @@ fn test_counters() {
         RT.block_on(async {
             let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec());
-            let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone()));
-            async {
-                let counter1 = register_counter!("test_counter");
+            let recorder = MetricRecorder::new(metrics.clone());
+            async move {
+                let counter1 = counter!("test_counter");
                 counter1.increment(1);
-                increment_counter!("test_counter");
-                increment_counter!("test_counter");
+                counter!("test_counter").increment(1);
+                counter!("test_counter").increment(1);
 
-                increment_counter!("another_counter");
+                counter!("another_counter").increment(1);
 
                 let val = metrics.counter_value("test_counter").unwrap();
                 assert_eq!(val, 3);
@@ -224,11 +200,11 @@ fn test_counters() {
                 let val2 = metrics.counter_value("another_counter").unwrap();
                 assert_eq!(val2, 1);
 
-                absolute_counter!("test_counter", 5);
+                counter!("test_counter").absolute(5);
                 let val3 = metrics.counter_value("test_counter").unwrap();
                 assert_eq!(val3, 5);
             }
-            .with_subscriber(dispatch)
+            .with_recorder(recorder)
             .await;
         });
     }
@@ -237,13 +213,13 @@ fn test_gauges() {
         RT.block_on(async {
             let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec());
-            let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone()));
-            async {
-                let gauge1 = register_gauge!("test_gauge");
+            let recorder = MetricRecorder::new(metrics.clone());
+            async move {
+                let gauge1 = gauge!("test_gauge");
                 gauge1.increment(1.0);
-                increment_gauge!("test_gauge", 1.0);
-                increment_gauge!("test_gauge", 1.0);
-                increment_gauge!("another_gauge", 1.0);
+                gauge!("test_gauge").increment(1.0);
+                gauge!("test_gauge").increment(1.0);
+                gauge!("another_gauge").increment(1.0);
 
                 let val = metrics.gauge_value("test_gauge").unwrap();
                 assert_eq!(val, 3.0);
@@ -253,15 +229,15 @@
                 assert_eq!(None, metrics.counter_value("test_gauge"));
 
-                gauge!("test_gauge", 5.0);
+                gauge!("test_gauge").set(5.0);
                 let val3 = metrics.gauge_value("test_gauge").unwrap();
                 assert_eq!(val3, 5.0);
 
-                decrement_gauge!("test_gauge", 2.0);
+                gauge!("test_gauge").decrement(2.0);
                 let val4 = metrics.gauge_value("test_gauge").unwrap();
                 assert_eq!(val4, 3.0);
             }
-            .with_subscriber(dispatch)
+            .with_recorder(recorder)
             .await;
         });
     }
 
@@ -270,19 +246,19 @@ fn test_no_panic_and_ignore_other_traces() {
         RT.block_on(async {
             let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec());
-            let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone()));
-            async {
+            let recorder =
MetricRecorder::new(metrics.clone()); + async move { trace!("a fake trace"); - increment_gauge!("test_gauge", 1.0); - increment_counter!("test_counter"); + gauge!("test_gauge").set(1.0); + counter!("test_counter").increment(1); trace!("another fake trace"); assert_eq!(1.0, metrics.gauge_value("test_gauge").unwrap()); assert_eq!(1, metrics.counter_value("test_counter").unwrap()); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -291,15 +267,15 @@ mod tests { fn test_ignore_non_accepted_metrics() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - increment_gauge!("not_accepted", 1.0); - increment_gauge!("test_gauge", 1.0); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + gauge!("not_accepted").set(1.0); + gauge!("test_gauge").set(1.0); assert_eq!(1.0, metrics.gauge_value("test_gauge").unwrap()); assert_eq!(None, metrics.gauge_value("not_accepted")); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -308,17 +284,17 @@ mod tests { fn test_histograms() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let hist = register_histogram!("test_histogram"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let hist = histogram!("test_histogram"); hist.record(Duration::from_millis(9)); - histogram!("test_histogram", Duration::from_millis(100)); - histogram!("test_histogram", Duration::from_millis(1)); + histogram!("test_histogram").record(Duration::from_millis(100)); + histogram!("test_histogram").record(Duration::from_millis(1)); - histogram!("test_histogram", Duration::from_millis(1999)); - histogram!("test_histogram", Duration::from_millis(3999)); - histogram!("test_histogram", Duration::from_millis(610)); + histogram!("test_histogram").record(Duration::from_millis(1999)); + histogram!("test_histogram").record(Duration::from_millis(3999)); + histogram!("test_histogram").record(Duration::from_millis(610)); let hist = metrics.histogram_values("test_histogram").unwrap(); let expected: Vec<(f64, u64)> = Vec::from([ @@ -336,7 +312,7 @@ mod tests { assert_eq!(hist.buckets(), expected); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -345,8 +321,8 @@ mod tests { fn test_set_and_read_descriptions() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { + let recorder = MetricRecorder::new(metrics.clone()); + async move { describe_counter!("test_counter", "This is a counter"); let descriptions = metrics.get_descriptions(); @@ -367,7 +343,7 @@ mod tests { let description = descriptions.get("test_histogram").unwrap(); assert_eq!("This is a hist", description); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -376,8 +352,8 @@ mod tests { fn test_to_json() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { + let recorder = MetricRecorder::new(metrics.clone()); + async move { let empty = json!({ "counters": [], "gauges": [], @@ -386,21 +362,21 @@ 
mod tests { assert_eq!(metrics.to_json(Default::default()), empty); - absolute_counter!("counter_1", 4, "label" => "one"); + counter!("counter_1", "label" => "one").absolute(4); describe_counter!("counter_2", "this is a description for counter 2"); - absolute_counter!("counter_2", 2, "label" => "one", "another_label" => "two"); + counter!("counter_2", "label" => "one", "another_label" => "two").absolute(2); describe_gauge!("gauge_1", "a description for gauge 1"); - gauge!("gauge_1", 7.0); - gauge!("gauge_2", 3.0, "label" => "three"); + gauge!("gauge_1").set(7.0); + gauge!("gauge_2", "label" => "three").set(3.0); describe_histogram!("histogram_1", "a description for histogram"); - let hist = register_histogram!("histogram_1", "label" => "one", "hist_two" => "two"); + let hist = histogram!("histogram_1", "label" => "one", "hist_two" => "two"); hist.record(Duration::from_millis(9)); - histogram!("histogram_2", Duration::from_millis(9)); - histogram!("histogram_2", Duration::from_millis(1000)); - histogram!("histogram_2", Duration::from_millis(40)); + histogram!("histogram_2").record(Duration::from_millis(9)); + histogram!("histogram_2").record(Duration::from_millis(1000)); + histogram!("histogram_2").record(Duration::from_millis(40)); let json = metrics.to_json(Default::default()); let expected = json!({ @@ -448,7 +424,7 @@ mod tests { assert_eq!(json, expected); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -457,12 +433,12 @@ mod tests { fn test_global_and_metric_labels() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - let hist = register_histogram!("test_histogram", "label" => "one", "two" => "another"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + let hist = histogram!("test_histogram", "label" => "one", "two" => "another"); hist.record(Duration::from_millis(9)); - absolute_counter!("counter_1", 1); + counter!("counter_1").absolute(1); let mut global_labels: HashMap = HashMap::new(); global_labels.insert("global_one".to_string(), "one".to_string()); @@ -491,7 +467,7 @@ mod tests { }); assert_eq!(expected, json); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } @@ -500,21 +476,21 @@ mod tests { fn test_prometheus_format() { RT.block_on(async { let metrics = MetricRegistry::new_with_accept_list(TESTING_ACCEPT_LIST.to_vec()); - let dispatch = Dispatch::new(tracing_subscriber::Registry::default().with(metrics.clone())); - async { - absolute_counter!("counter_1", 4, "label" => "one"); + let recorder = MetricRecorder::new(metrics.clone()); + async move { + counter!("counter_1", "label" => "one").absolute(4); describe_counter!("counter_2", "this is a description for counter 2"); - absolute_counter!("counter_2", 2, "label" => "one", "another_label" => "two"); + counter!("counter_2", "label" => "one", "another_label" => "two").absolute(2); describe_gauge!("gauge_1", "a description for gauge 1"); - gauge!("gauge_1", 7.0); - gauge!("gauge_2", 3.0, "label" => "three"); + gauge!("gauge_1").set(7.0); + gauge!("gauge_2", "label" => "three").set(3.0); describe_histogram!("histogram_1", "a description for histogram"); - let hist = register_histogram!("histogram_1", "label" => "one", "hist_two" => "two"); + let hist = histogram!("histogram_1", "label" => "one", "hist_two" => "two"); hist.record(Duration::from_millis(9)); - histogram!("histogram_2", 
Duration::from_millis(1000)); + histogram!("histogram_2").record(Duration::from_millis(1000)); let mut global_labels: HashMap = HashMap::new(); global_labels.insert("global_two".to_string(), "two".to_string()); @@ -574,7 +550,7 @@ mod tests { snapshot.assert_eq(&prometheus); } - .with_subscriber(dispatch) + .with_recorder(recorder) .await; }); } diff --git a/libs/metrics/src/recorder.rs b/libs/metrics/src/recorder.rs new file mode 100644 index 00000000000..a911a45d67e --- /dev/null +++ b/libs/metrics/src/recorder.rs @@ -0,0 +1,190 @@ +use std::sync::{Arc, OnceLock}; + +use derive_more::Display; +use metrics::{Counter, CounterFn, Gauge, GaugeFn, Histogram, HistogramFn, Key, Recorder, Unit}; +use metrics::{KeyName, Metadata, SharedString}; + +use crate::common::{MetricAction, MetricType}; +use crate::registry::MetricVisitor; +use crate::MetricRegistry; + +/// Default global metric recorder. +/// +/// `metrics` crate has the state on its own. It allows setting the global recorder, it allows +/// overriding it for a duration of an async closure, and it allows borrowing the current recorder +/// for a short while. We, however, can't use this in our async instrumentation because we need the +/// current recorder to be `Send + 'static` to be able to store it in a future that would be usable +/// in a work-stealing runtime, especially since we need to be able to instrument the futures +/// spawned as tasks. The solution to this is to maintain our own state in parallel. +/// +/// The APIs exposed by the crate guarantee that the state we modify on our side is updated on the +/// `metrics` side as well. Using `metrics::set_global_recorder` or `metrics::with_local_recorder` +/// in user code won't be detected by us but is safe and won't lead to any issues (even if the new +/// recorder isn't the [`MetricRecorder`] from this crate), we just won't know about any new local +/// recorders on the stack, and calling +/// [`crate::WithMetricsInstrumentation::with_current_recorder`] will re-use the last +/// [`MetricRecorder`] known to us. +static GLOBAL_RECORDER: OnceLock> = const { OnceLock::new() }; + +#[derive(Display, Debug)] +#[display(fmt = "global recorder can only be installed once")] +pub struct AlreadyInstalled; + +impl std::error::Error for AlreadyInstalled {} + +fn set_global_recorder(recorder: MetricRecorder) -> Result<(), AlreadyInstalled> { + GLOBAL_RECORDER.set(Some(recorder)).map_err(|_| AlreadyInstalled) +} + +pub(crate) fn global_recorder() -> Option { + GLOBAL_RECORDER.get()?.clone() +} + +/// Receives the metrics from the macros provided by the `metrics` crate and forwards them to +/// [`MetricRegistry`]. +/// +/// To provide an analogy, `MetricRecorder` to `MetricRegistry` is what `Dispatch` is to +/// `Subscriber` in `tracing`. Just like `Dispatch`, it acts like a handle to the registry and is +/// cheaply clonable with reference-counting semantics. +#[derive(Clone)] +pub struct MetricRecorder { + registry: MetricRegistry, +} + +impl MetricRecorder { + pub fn new(registry: MetricRegistry) -> Self { + Self { registry } + } + + /// Convenience method to call [`Self::init_prisma_metrics`] immediately after creating the + /// recorder. + pub fn with_initialized_prisma_metrics(self) -> Self { + self.init_prisma_metrics(); + self + } + + /// Initializes the default Prisma metrics by dispatching their descriptions and initial values + /// to the registry. + /// + /// Query engine needs this, but the metrics can also be used without this, especially in + /// tests. 
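Taken together, the `MetricRecorder` and the `WithMetricsInstrumentation` trait above replace the old `tracing` layer plus `Dispatch` wiring. A rough sketch of the intended flow in caller code (hypothetical and not part of this diff; the registry is passed in to avoid assuming a particular constructor):

use prisma_metrics::{
    counter, MetricRecorder, MetricRegistry, WithMetricsInstrumentation, PRISMA_CLIENT_QUERIES_TOTAL,
};

async fn serve(registry: MetricRegistry) {
    // One recorder per engine instance; Prisma metric descriptions and
    // initial values are dispatched into the registry up front.
    let recorder = MetricRecorder::new(registry).with_initialized_prisma_metrics();

    async {
        counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1);

        // Spawned tasks do not inherit the thread-local recorder automatically,
        // so each task's top-level future is re-instrumented explicitly.
        tokio::spawn(
            async { counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1) }.with_current_recorder(),
        );
    }
    .with_recorder(recorder)
    .await;
}

Keeping the recorder in a thread-local that is set only while the instrumented future is being polled is what lets several engine instances with separate registries coexist in one work-stealing runtime, as the `GLOBAL_RECORDER` notes above point out.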
+ pub fn init_prisma_metrics(&self) { + metrics::with_local_recorder(self, || { + super::initialize_metrics(); + }); + } + + /// Installs the metrics recorder globally, registering it both with the `metrics` crate and + /// our own instrumentation. + pub fn install_globally(&self) -> Result<(), AlreadyInstalled> { + set_global_recorder(self.clone())?; + metrics::set_global_recorder(self.clone()).map_err(|_| AlreadyInstalled) + } + + fn register_description(&self, name: KeyName, description: &str) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Description, + action: MetricAction::Description(description.to_owned()), + name: Key::from_name(name), + }); + } + + fn record_in_registry(&self, visitor: &MetricVisitor) { + self.registry.record(visitor); + } +} + +impl Recorder for MetricRecorder { + fn describe_counter(&self, key_name: KeyName, _unit: Option, description: SharedString) { + self.register_description(key_name, &description); + } + + fn describe_gauge(&self, key_name: KeyName, _unit: Option, description: SharedString) { + self.register_description(key_name, &description); + } + + fn describe_histogram(&self, key_name: KeyName, _unit: Option, description: SharedString) { + self.register_description(key_name, &description); + } + + fn register_counter(&self, key: &Key, _metadata: &Metadata<'_>) -> Counter { + Counter::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone()))) + } + + fn register_gauge(&self, key: &Key, _metadata: &Metadata<'_>) -> Gauge { + Gauge::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone()))) + } + + fn register_histogram(&self, key: &Key, _metadata: &Metadata<'_>) -> Histogram { + Histogram::from_arc(Arc::new(MetricHandle::new(key.clone(), self.registry.clone()))) + } +} + +pub(crate) struct MetricHandle { + key: Key, + registry: MetricRegistry, +} + +impl MetricHandle { + pub fn new(key: Key, registry: MetricRegistry) -> Self { + Self { key, registry } + } + + fn record_in_registry(&self, visitor: &MetricVisitor) { + self.registry.record(visitor); + } +} + +impl CounterFn for MetricHandle { + fn increment(&self, value: u64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Counter, + action: MetricAction::Increment(value), + name: self.key.clone(), + }); + } + + fn absolute(&self, value: u64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Counter, + action: MetricAction::Absolute(value), + name: self.key.clone(), + }); + } +} + +impl GaugeFn for MetricHandle { + fn increment(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeInc(value), + name: self.key.clone(), + }); + } + + fn decrement(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeDec(value), + name: self.key.clone(), + }); + } + + fn set(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Gauge, + action: MetricAction::GaugeSet(value), + name: self.key.clone(), + }); + } +} + +impl HistogramFn for MetricHandle { + fn record(&self, value: f64) { + self.record_in_registry(&MetricVisitor { + metric_type: MetricType::Histogram, + action: MetricAction::HistRecord(value), + name: self.key.clone(), + }); + } +} diff --git a/query-engine/metrics/src/registry.rs b/libs/metrics/src/registry.rs similarity index 66% rename from query-engine/metrics/src/registry.rs rename to libs/metrics/src/registry.rs index 
6530edbe876..f6b217fcda5 100644 --- a/query-engine/metrics/src/registry.rs +++ b/libs/metrics/src/registry.rs @@ -1,11 +1,7 @@ -use super::formatters::metrics_to_json; -use super::{ - common::{KeyLabels, Metric, MetricAction, MetricType, MetricValue, Snapshot}, - formatters::metrics_to_prometheus, -}; -use super::{ - ACCEPT_LIST, HISTOGRAM_BOUNDS, METRIC_COUNTER, METRIC_DESCRIPTION, METRIC_GAUGE, METRIC_HISTOGRAM, METRIC_TARGET, -}; +use std::collections::HashMap; +use std::fmt; +use std::sync::{atomic::Ordering, Arc}; + use metrics::{CounterFn, GaugeFn, HistogramFn, Key}; use metrics_util::{ registry::{GenerationalAtomicStorage, GenerationalStorage, Registry}, @@ -13,15 +9,13 @@ use metrics_util::{ }; use parking_lot::RwLock; use serde_json::Value; -use std::collections::HashMap; -use std::fmt; -use std::sync::atomic::Ordering; -use std::sync::Arc; -use tracing::{ - field::{Field, Visit}, - Subscriber, + +use super::formatters::metrics_to_json; +use super::{ + common::{Metric, MetricAction, MetricType, MetricValue, Snapshot}, + formatters::metrics_to_prometheus, }; -use tracing_subscriber::Layer; +use super::{ACCEPT_LIST, HISTOGRAM_BOUNDS}; struct Inner { descriptions: RwLock>, @@ -46,7 +40,7 @@ pub struct MetricRegistry { impl fmt::Debug for MetricRegistry { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "Metric Registry") + write!(f, "MetricRegistry {{ .. }}") } } @@ -68,12 +62,14 @@ impl MetricRegistry { } } - fn record(&self, metric: &MetricVisitor) { - match metric.metric_type { - MetricType::Counter => self.handle_counter(metric), - MetricType::Gauge => self.handle_gauge(metric), - MetricType::Histogram => self.handle_histogram(metric), - MetricType::Description => self.handle_description(metric), + pub(crate) fn record(&self, metric: &MetricVisitor) { + if self.is_accepted_metric(metric) { + match metric.metric_type { + MetricType::Counter => self.handle_counter(metric), + MetricType::Gauge => self.handle_gauge(metric), + MetricType::Histogram => self.handle_histogram(metric), + MetricType::Description => self.handle_description(metric), + } } } @@ -223,80 +219,8 @@ impl MetricRegistry { } #[derive(Debug)] -struct MetricVisitor { - metric_type: MetricType, - action: MetricAction, - name: Key, -} - -impl MetricVisitor { - pub fn new() -> Self { - Self { - metric_type: MetricType::Description, - action: MetricAction::Absolute(0), - name: Key::from_name(""), - } - } -} - -impl Visit for MetricVisitor { - fn record_debug(&mut self, _field: &Field, _value: &dyn std::fmt::Debug) {} - - fn record_f64(&mut self, field: &Field, value: f64) { - match field.name() { - "gauge_inc" => self.action = MetricAction::GaugeInc(value), - "gauge_dec" => self.action = MetricAction::GaugeDec(value), - "gauge_set" => self.action = MetricAction::GaugeSet(value), - "hist_record" => self.action = MetricAction::HistRecord(value), - _ => (), - } - } - - fn record_i64(&mut self, field: &Field, value: i64) { - match field.name() { - "increment" => self.action = MetricAction::Increment(value as u64), - "absolute" => self.action = MetricAction::Absolute(value as u64), - _ => (), - } - } - - fn record_u64(&mut self, field: &Field, value: u64) { - match field.name() { - "increment" => self.action = MetricAction::Increment(value), - "absolute" => self.action = MetricAction::Absolute(value), - _ => (), - } - } - - fn record_str(&mut self, field: &Field, value: &str) { - match (field.name(), value) { - ("metric_type", METRIC_COUNTER) => self.metric_type = MetricType::Counter, - ("metric_type", 
METRIC_GAUGE) => self.metric_type = MetricType::Gauge, - ("metric_type", METRIC_HISTOGRAM) => self.metric_type = MetricType::Histogram, - ("metric_type", METRIC_DESCRIPTION) => self.metric_type = MetricType::Description, - ("name", _) => self.name = Key::from_name(value.to_string()), - ("key_labels", _) => { - let key_labels: KeyLabels = serde_json::from_str(value).unwrap(); - self.name = key_labels.into(); - } - (METRIC_DESCRIPTION, _) => self.action = MetricAction::Description(value.to_string()), - _ => (), - } - } -} - -// A tracing layer for receiving metric trace events and storing them in the registry. -impl Layer for MetricRegistry { - fn on_event(&self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { - if event.metadata().target() != METRIC_TARGET { - return; - } - - let mut visitor = MetricVisitor::new(); - event.record(&mut visitor); - - if self.is_accepted_metric(&visitor) { - self.record(&visitor); - } - } +pub(crate) struct MetricVisitor { + pub(crate) metric_type: MetricType, + pub(crate) action: MetricAction, + pub(crate) name: Key, } diff --git a/libs/prisma-value/Cargo.toml b/libs/prisma-value/Cargo.toml index 1a0d28e06db..9833b6ee104 100644 --- a/libs/prisma-value/Cargo.toml +++ b/libs/prisma-value/Cargo.toml @@ -7,7 +7,7 @@ version = "0.1.0" base64 = "0.13" chrono.workspace = true once_cell = "1.3" -regex = "1.2" +regex.workspace = true bigdecimal = "0.3" serde.workspace = true serde_json.workspace = true diff --git a/libs/query-engine-common/Cargo.toml b/libs/query-engine-common/Cargo.toml index daf41ba50f6..258639d2d94 100644 --- a/libs/query-engine-common/Cargo.toml +++ b/libs/query-engine-common/Cargo.toml @@ -8,6 +8,7 @@ thiserror = "1" url.workspace = true query-connector = { path = "../../query-engine/connectors/query-connector" } query-core = { path = "../../query-engine/core" } +telemetry = { path = "../telemetry" } user-facing-errors = { path = "../user-facing-errors" } serde_json.workspace = true serde.workspace = true @@ -16,12 +17,12 @@ psl.workspace = true async-trait.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } [target.'cfg(all(not(target_arch = "wasm32")))'.dependencies] -query-engine-metrics = { path = "../../query-engine/metrics" } +prisma-metrics.path = "../metrics" napi.workspace = true [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/libs/query-engine-common/src/engine.rs b/libs/query-engine-common/src/engine.rs index 5129fca185c..91ddc2acb6d 100644 --- a/libs/query-engine-common/src/engine.rs +++ b/libs/query-engine-common/src/engine.rs @@ -59,7 +59,7 @@ pub struct EngineBuilder { pub struct ConnectedEngineNative { pub config_dir: PathBuf, pub env: HashMap, - pub metrics: Option, + pub metrics: Option, } /// Internal structure for querying and reconnecting with the engine. 
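Downstream, the `metrics` field of `ConnectedEngineNative` above now holds the registry from `prisma-metrics` (presumably `Option<prisma_metrics::MetricRegistry>`). A hypothetical sketch of rendering it for a metrics response; only `to_json` is taken from the registry API exercised in the tests earlier in this diff, the handler itself is illustrative:

use prisma_metrics::MetricRegistry;

// Hypothetical endpoint helper, not part of this diff.
fn metrics_json(metrics: Option<&MetricRegistry>) -> Option<serde_json::Value> {
    // `to_json` takes a map of global labels to attach; the empty default adds none.
    metrics.map(|registry| registry.to_json(Default::default()))
}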
diff --git a/libs/query-engine-common/src/tracer.rs b/libs/query-engine-common/src/tracer.rs index 19d17cf13a0..256ec95c172 100644 --- a/libs/query-engine-common/src/tracer.rs +++ b/libs/query-engine-common/src/tracer.rs @@ -8,7 +8,6 @@ use opentelemetry::{ }, trace::{TraceError, TracerProvider}, }; -use query_core::telemetry; use std::fmt::{self, Debug}; /// Pipeline builder diff --git a/libs/telemetry/Cargo.toml b/libs/telemetry/Cargo.toml new file mode 100644 index 00000000000..4b9f9b79dac --- /dev/null +++ b/libs/telemetry/Cargo.toml @@ -0,0 +1,33 @@ +[package] +edition = "2021" +name = "telemetry" +version = "0.1.0" + +[features] +metrics = ["dep:prisma-metrics"] + +[dependencies] +async-trait.workspace = true +crossbeam-channel = "0.5.6" +psl.workspace = true +futures = "0.3" +indexmap.workspace = true +itertools.workspace = true +once_cell = "1" +opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } +rand.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror = "1.0" +tokio = { version = "1.0", features = ["macros", "time"] } +tracing = { workspace = true, features = ["attributes"] } +tracing-futures = "0.2" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +tracing-opentelemetry = "0.17.4" +uuid.workspace = true +cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +crosstarget-utils = { path = "../crosstarget-utils" } +lru = "0.7.7" +enumflags2.workspace = true +derive_more = "0.99.17" +prisma-metrics = { path = "../metrics", optional = true } diff --git a/query-engine/core/src/telemetry/capturing/capturer.rs b/libs/telemetry/src/capturing/capturer.rs similarity index 95% rename from query-engine/core/src/telemetry/capturing/capturer.rs rename to libs/telemetry/src/capturing/capturer.rs index d0d9886acd2..a978b0766c8 100644 --- a/query-engine/core/src/telemetry/capturing/capturer.rs +++ b/libs/telemetry/src/capturing/capturer.rs @@ -29,6 +29,20 @@ impl Capturer { Self::Disabled } + + pub async fn try_start_capturing(&self) { + if let Capturer::Enabled(capturer) = self { + capturer.start_capturing().await + } + } + + pub async fn try_fetch_captures(&self) -> Option { + if let Capturer::Enabled(capturer) = self { + capturer.fetch_captures().await + } else { + None + } + } } #[derive(Debug, Clone)] @@ -92,7 +106,7 @@ impl SpanProcessor for Processor { /// mongo / relational, the information to build this kind of log event is logged diffeerently in /// the server. /// - /// In the case of the of relational databaes --queried through sql_query_connector and eventually + /// In the case of the of relational database --queried through sql_query_connector and eventually /// through quaint, a trace span describes the query-- `TraceSpan::represents_query_event` /// determines if a span represents a query event. /// diff --git a/query-engine/core/src/telemetry/capturing/helpers.rs b/libs/telemetry/src/capturing/helpers.rs similarity index 100% rename from query-engine/core/src/telemetry/capturing/helpers.rs rename to libs/telemetry/src/capturing/helpers.rs diff --git a/query-engine/core/src/telemetry/capturing/mod.rs b/libs/telemetry/src/capturing/mod.rs similarity index 76% rename from query-engine/core/src/telemetry/capturing/mod.rs rename to libs/telemetry/src/capturing/mod.rs index fc1219d5fe0..0fdf711afb4 100644 --- a/query-engine/core/src/telemetry/capturing/mod.rs +++ b/libs/telemetry/src/capturing/mod.rs @@ -4,97 +4,97 @@ //! The interaction diagram below (soorry width!) 
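Before the diagram, a compressed sketch of how a request handler drives the two convenience methods added to `Capturer` above; apart from `capturer`, `try_start_capturing` and `try_fetch_captures`, the names here are illustrative assumptions:

use opentelemetry::trace::TraceId;
use telemetry::capturing::{self, Settings};

// Hypothetical request handler outline, not part of this diff.
async fn handle_request(trace_id: TraceId, settings: Settings) {
    let capturer = capturing::capturer(trace_id, settings);

    // No-op when capturing is disabled for this request.
    capturer.try_start_capturing().await;

    // ... execute the query here ...

    // `None` when capturing is disabled, otherwise the logs and traces
    // collected for this trace id, ready to be attached to the response.
    let _captures = capturer.try_fetch_captures().await;
}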
shows the different roles at play during telemetry //! capturing. A textual explanatation follows it. For the sake of example a server environment //! --the query-engine crate-- is assumed. -//! # ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ -//! # -//! # │ <> │ -//! # -//! # ╔═══════════════════════╗ │╔═══════════════╗ │ -//! # ║<>║ ║ <> ║ ╔════════════════╗ ╔═══════════════════╗ -//! # ┌───────────────────┐ ║ PROCESSOR ║ │║ Sender ║ ║ Storage ║│ ║ TRACER ║ -//! # │ Server │ ╚═══════════╦═══════════╝ ╚══════╦════════╝ ╚═══════╦════════╝ ╚═════════╦═════════╝ -//! # └─────────┬─────────┘ │ │ │ │ │ │ -//! # │ │ │ │ │ -//! # │ │ │ │ │ │ │ -//! # POST │ │ │ │ │ -//! # (body, headers)│ │ │ │ │ │ │ -//! # ──────────▶┌┴┐ │ │ │ │ -//! # ┌─┐ │ │new(headers)╔════════════╗ │ │ │ │ │ │ -//! # │1│ │ ├───────────▶║s: Settings ║ │ │ │ │ -//! # └─┘ │ │ ╚════════════╝ │ │ │ │ │ │ -//! # │ │ │ │ │ │ -//! # │ │ ╔═══════════════════╗ │ │ │ │ │ │ -//! # │ │ ║ Capturer::Enabled ║ │ │ │ │ ┌────────────┐ -//! # │ │ ╚═══════════════════╝ │ │ │ │ │ │ │<│ -//! # │ │ │ │ │ │ │ └──────┬─────┘ -//! # │ │ ┌─┐ new(trace_id, s) │ │ │ │ │ │ │ │ -//! # │ ├───┤2├───────────────────────▶│ │ │ │ │ │ -//! # │ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │ │ │ │ │ │ │ -//! # │ │ ┌─┐ start_capturing() │ start_capturing │ │ │ │ │ │ │ -//! # │ ├───┤3├───────────────────────▶│ (trace_id, s) │ │ │ │ │ -//! # │ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │ ├─────────────────────▶│ send(StartCapturing, │ │ │ │ -//! # │ │ │ │ trace_id)│ │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ │ │ │ -//! # │ │ │ │ ┌─┐ │ │insert(trace_id, s) │ │ │ │ -//! # │ │ │ │ │4│ │────────────────────▶│ │ │ -//! # │ │ │ │ └─┘ │ │ │ │ ┌─┐ │ process_query │ -//! # │ │──────────────────────────────┼──────────────────────┼───────────────────────┼─────────────────────┼────────────┤5├──────┼──────────────────────────▶┌┴┐ -//! # │ │ │ │ │ │ │ │ └─┘ │ │ │ -//! # │ │ │ │ │ │ │ │ │ +//! # ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ +//! # +//! # │ <> │ +//! # +//! # ╔═══════════════════════╗ │╔═══════════════╗ │ +//! # ║<>║ ║ <> ║ ╔════════════════╗ ╔═══════════════════╗ +//! # ┌───────────────────┐ ║ PROCESSOR ║ │║ Sender ║ ║ Storage ║│ ║ TRACER ║ +//! # │ Server │ ╚═══════════╦═══════════╝ ╚══════╦════════╝ ╚═══════╦════════╝ ╚═════════╦═════════╝ +//! # └─────────┬─────────┘ │ │ │ │ │ │ +//! # │ │ │ │ │ +//! # │ │ │ │ │ │ │ +//! # POST │ │ │ │ │ +//! # (body, headers)│ │ │ │ │ │ │ +//! # ──────────▶┌┴┐ │ │ │ │ +//! # ┌─┐ │ │new(headers)╔════════════╗ │ │ │ │ │ │ +//! # │1│ │ ├───────────▶║s: Settings ║ │ │ │ │ +//! # └─┘ │ │ ╚════════════╝ │ │ │ │ │ │ +//! # │ │ │ │ │ │ +//! # │ │ ╔═══════════════════╗ │ │ │ │ │ │ +//! # │ │ ║ Capturer::Enabled ║ │ │ │ │ ┌────────────┐ +//! # │ │ ╚═══════════════════╝ │ │ │ │ │ │ │<│ +//! # │ │ │ │ │ │ │ └──────┬─────┘ +//! # │ │ ┌─┐ new(trace_id, s) │ │ │ │ │ │ │ │ +//! # │ ├───┤2├───────────────────────▶│ │ │ │ │ │ +//! # │ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │ │ │ │ │ │ │ +//! # │ │ ┌─┐ start_capturing() │ start_capturing │ │ │ │ │ │ │ +//! # │ ├───┤3├───────────────────────▶│ (trace_id, s) │ │ │ │ │ +//! # │ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │ ├─────────────────────▶│ send(StartCapturing, │ │ │ │ +//! # │ │ │ │ trace_id)│ │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ │ │ │ +//! # │ │ │ │ ┌─┐ │ │insert(trace_id, s) │ │ │ │ +//! # │ │ │ │ │4│ │────────────────────▶│ │ │ +//! # │ │ │ │ └─┘ │ │ │ │ ┌─┐ │ process_query │ +//! 
# │ │──────────────────────────────┼──────────────────────┼───────────────────────┼─────────────────────┼────────────┤5├──────┼──────────────────────────▶┌┴┐ +//! # │ │ │ │ │ │ │ │ └─┘ │ │ │ +//! # │ │ │ │ │ │ │ │ │ //! # │ │ │ │ │ │ │ │ │ │ │ ┌─────────────────────┐ //! # │ │ │ │ │ │ │ log! / span! ┌─┐ │ │ │ res: PrismaResponse │ //! # │ │ │ │ │ │ │ │ │◀─────────────────────┤6├──│ │ └──────────┬──────────┘ -//! # │ │ │ │ │ on_end(span_data)│ ┌─┐ │ └─┘ │ │ new │ -//! # │ │ │ │◀──────────────┼───────┼─────────────────────┼─────────┼──┤7├──────┤ │ │────────────▶│ -//! # │ │ │ │ send(SpanDataProcessed│ │ └─┘ │ │ │ │ -//! # │ │ │ │ , trace_id) │ append(trace_id, │ │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ logs, traces) │ │ │ │ │ -//! # │ │ │ │ ┌─┐ │ ├────────────────────▶│ │ │ │ │ │ -//! # │ │ │ │ │8│ │ │ │ │ │ │ -//! # │ │ res: PrismaResponse │ ┌─┐ │ └─┘ │ │ │ │ │ │ │ │ -//! # │ │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┼ ┤9├ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─return ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─└┬┘ │ -//! # │ │ ┌────┐ fetch_captures() │ └─┘ │ │ │ │ │ │ │ │ -//! # │ ├─┤ 10 ├──────────────────────▶│ fetch_captures │ │ │ │ │ │ -//! # │ │ └────┘ │ (trace_id) │ │ │ │ │ │ │ │ -//! # │ │ ├─────────────────────▶│ send(FetchCaptures, │ │ │ x │ -//! # │ │ │ │ trace_id) │ │ │ │ │ -//! # │ │ │ │── ── ── ── ── ── ── ─▶│ get logs/traces │ │ │ -//! # │ │ │ │ ┌────┐ │ ├─────────────────────▶ │ │ │ -//! # │ │ │ │ │ 11 │ │ │ │ │ -//! # │ │ │ │ └────┘ │ │◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ -//! # │ │ │ │ │ │ │ │ -//! # │ │ ◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ -//! # │ │ logs, traces │ │ │ │ │ │ -//! # │ │◁─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ │ -//! # │ │ x ┌────┐ │ │ │ │ res.set_extension(logs) │ -//! # │ ├───────────────────────────────────────┤ 12 ├────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ -//! # │ │ └────┘ │ │ │ │ res.set_extension(traces) │ -//! # │ ├─────────────────────────────────────────────────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ -//! # ◀ ─ ─ ─└┬┘ │ │ │ │ x -//! # json!(res) │ │ │ -//! # ┌────┐ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ -//! # │ 13 │ │ -//! # └────┘ -//! # -//! # ◀─────── call (pseudo-signatures) -//! # -//! # ◀─ ── ── async message passing (channels) -//! # -//! # ◁─ ─ ─ ─ return -//! # -//! +//! # │ │ │ │ │ on_end(span_data)│ ┌─┐ │ └─┘ │ │ new │ +//! # │ │ │ │◀──────────────┼───────┼─────────────────────┼─────────┼──┤7├──────┤ │ │────────────▶│ +//! # │ │ │ │ send(SpanDataProcessed│ │ └─┘ │ │ │ │ +//! # │ │ │ │ , trace_id) │ append(trace_id, │ │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ logs, traces) │ │ │ │ │ +//! # │ │ │ │ ┌─┐ │ ├────────────────────▶│ │ │ │ │ │ +//! # │ │ │ │ │8│ │ │ │ │ │ │ +//! # │ │ res: PrismaResponse │ ┌─┐ │ └─┘ │ │ │ │ │ │ │ │ +//! # │ │─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┼ ┤9├ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─return ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─└┬┘ │ +//! # │ │ ┌────┐ fetch_captures() │ └─┘ │ │ │ │ │ │ │ │ +//! # │ ├─┤ 10 ├──────────────────────▶│ fetch_captures │ │ │ │ │ │ +//! # │ │ └────┘ │ (trace_id) │ │ │ │ │ │ │ │ +//! # │ │ ├─────────────────────▶│ send(FetchCaptures, │ │ │ x │ +//! # │ │ │ │ trace_id) │ │ │ │ │ +//! # │ │ │ │── ── ── ── ── ── ── ─▶│ get logs/traces │ │ │ +//! # │ │ │ │ ┌────┐ │ ├─────────────────────▶ │ │ │ +//! # │ │ │ │ │ 11 │ │ │ │ │ +//! # │ │ │ │ └────┘ │ │◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ +//! 
# │ │ │ │ │ │ │ │ +//! # │ │ ◁ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ +//! # │ │ logs, traces │ │ │ │ │ │ +//! # │ │◁─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─│ │ │ │ │ │ │ │ +//! # │ │ x ┌────┐ │ │ │ │ res.set_extension(logs) │ +//! # │ ├───────────────────────────────────────┤ 12 ├────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ +//! # │ │ └────┘ │ │ │ │ res.set_extension(traces) │ +//! # │ ├─────────────────────────────────────────────────────┼───────────────┼───────┼─────────────────────┼─────────┼───────────┼──────────────────────────────────────────▶│ +//! # ◀ ─ ─ ─└┬┘ │ │ │ │ x +//! # json!(res) │ │ │ +//! # ┌────┐ │ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ +//! # │ 13 │ │ +//! # └────┘ +//! # +//! # ◀─────── call (pseudo-signatures) +//! # +//! # ◀─ ── ── async message passing (channels) +//! # +//! # ◁─ ─ ─ ─ return +//! # +//! //! In the diagram, you will see objects whose lifetime is static. The boxes for those have a double //! width margin. These are: -//! +//! //! - The `server` itself //! - The global `TRACER`, which handles `log!` and `span!` and uses the global `PROCESSOR` to //! process the data constituting a trace `Span`s and log `Event`s //! - The global `PROCESSOR`, which manages the `Storage` set of data structures, holding logs, //! traces (and capture settings) per request. -//! +//! //! Then, through the request lifecycle, different objects are created and dropped: -//! +//! //! - When a request comes in, its headers are processed and a [`Settings`] object is built, this //! object determines, for the request, how logging and tracing are going to be captured: if only //! traces, logs, or both, and which log levels are going to be captured. @@ -105,9 +105,9 @@ //! part of the channel is kept in a global, so it can be cloned and used by a) the Capturer //! (to start capturing / fetch the captures) or by the tracer's SpanProcessor, to extract //! tracing and logging information that's eventually displayed to the user. -//! +//! //! Then the capturing process works in this way: -//! +//! //! - The server receives a query **[1]** //! - It grabs the HTTP headers and builds a `Capture` object **[2]**, which is configured with the settings //! 
denoted by the `X-capture-telemetry` @@ -138,14 +138,16 @@ #![allow(unused_imports, dead_code)] pub use self::capturer::Capturer; pub use self::settings::Settings; -pub use tx_ext::TxTraceExt; use self::capturer::Processor; use once_cell::sync::Lazy; use opentelemetry::{global, sdk, trace}; use tracing::subscriber; use tracing_subscriber::{ - filter::filter_fn, layer::Layered, prelude::__tracing_subscriber_SubscriberExt, Layer, Registry, + filter::filter_fn, + layer::{Layered, SubscriberExt}, + registry::LookupSpan, + Layer, Registry, }; static PROCESSOR: Lazy = Lazy::new(Processor::default); @@ -159,12 +161,8 @@ pub fn capturer(trace_id: trace::TraceId, settings: Settings) -> Capturer { /// Adds a capturing layer to the given subscriber and installs the transformed subscriber as the /// global, default subscriber #[cfg(feature = "metrics")] -#[allow(clippy::type_complexity)] pub fn install_capturing_layer( - subscriber: Layered< - Option, - Layered + Send + Sync>, Registry>, - >, + subscriber: impl SubscriberExt + for<'a> LookupSpan<'a> + Send + Sync + 'static, log_queries: bool, ) { // set a trace context propagator, so that the trace context is propagated via the @@ -198,4 +196,3 @@ mod capturer; mod helpers; mod settings; pub mod storage; -mod tx_ext; diff --git a/query-engine/core/src/telemetry/capturing/settings.rs b/libs/telemetry/src/capturing/settings.rs similarity index 100% rename from query-engine/core/src/telemetry/capturing/settings.rs rename to libs/telemetry/src/capturing/settings.rs diff --git a/query-engine/core/src/telemetry/capturing/storage.rs b/libs/telemetry/src/capturing/storage.rs similarity index 92% rename from query-engine/core/src/telemetry/capturing/storage.rs rename to libs/telemetry/src/capturing/storage.rs index 9c276716911..5c83affc85e 100644 --- a/query-engine/core/src/telemetry/capturing/storage.rs +++ b/libs/telemetry/src/capturing/storage.rs @@ -1,5 +1,5 @@ use super::settings::Settings; -use crate::telemetry::models; +use crate::models; #[derive(Debug, Default)] pub struct Storage { diff --git a/libs/telemetry/src/helpers.rs b/libs/telemetry/src/helpers.rs new file mode 100644 index 00000000000..4a332e86af6 --- /dev/null +++ b/libs/telemetry/src/helpers.rs @@ -0,0 +1,178 @@ +use super::models::TraceSpan; +use derive_more::Display; +use once_cell::sync::Lazy; +use opentelemetry::propagation::Extractor; +use opentelemetry::sdk::export::trace::SpanData; +use opentelemetry::trace::{SpanId, TraceContextExt, TraceFlags, TraceId}; +use serde_json::{json, Value}; +use std::collections::HashMap; +use tracing::Metadata; +use tracing_subscriber::EnvFilter; + +pub static SHOW_ALL_TRACES: Lazy = Lazy::new(|| match std::env::var("PRISMA_SHOW_ALL_TRACES") { + Ok(enabled) => enabled.eq_ignore_ascii_case("true"), + Err(_) => false, +}); + +/// `TraceParent` is a remote span. It is identified by `trace_id` and `span_id`. +/// +/// By "remote" we mean that this span was not emitted in the current process. In real life, it is +/// either: +/// - Emitted by the JS part of the Prisma ORM. This is true both for Accelerate (where the Rust +/// part is deployed as a server) and for the ORM (where the Rust part is a shared library) +/// - Never emitted at all. This happens when the `TraceParent` is created artificially from `TxId` +/// (see `TxId::as_traceparent`). 
In this case, `TraceParent` is used only to correlate logs +/// from different transaction operations - it is never used as a part of the trace +#[derive(Display, Copy, Clone)] +// This conforms with https://www.w3.org/TR/trace-context/#traceparent-header-field-values. Accelerate +// relies on this behaviour. +#[display(fmt = "00-{trace_id:032x}-{span_id:016x}-{flags:02x}")] +pub struct TraceParent { + trace_id: TraceId, + span_id: SpanId, + flags: TraceFlags, +} + +impl TraceParent { + pub fn from_remote_context(context: &opentelemetry::Context) -> Option { + let span = context.span(); + let span_context = span.span_context(); + + if span_context.is_valid() { + Some(Self { + trace_id: span_context.trace_id(), + span_id: span_context.span_id(), + flags: span_context.trace_flags(), + }) + } else { + None + } + } + + // TODO(aqrln): remove this method once the log capturing doesn't rely on trace IDs anymore + #[deprecated = "this must only be used to create an artificial traceparent for log capturing when tracing is disabled on the client"] + pub fn new_random() -> Self { + Self { + trace_id: TraceId::from_bytes(rand::random()), + span_id: SpanId::from_bytes(rand::random()), + flags: TraceFlags::SAMPLED, + } + } + + pub fn trace_id(&self) -> TraceId { + self.trace_id + } + + pub fn sampled(&self) -> bool { + self.flags.is_sampled() + } + + /// Returns a remote `opentelemetry::Context`. By "remote" we mean that it wasn't emitted in the + /// current process. + pub fn to_remote_context(&self) -> opentelemetry::Context { + // This relies on the fact that global text map propagator was installed that + // can handle `traceparent` field (for example, `TraceContextPropagator`). + opentelemetry::global::get_text_map_propagator(|propagator| { + propagator.extract(&TraceParentExtractor::new(self)) + }) + } +} + +/// An extractor to use with `TraceContextPropagator`. It allows to avoid creating a full `HashMap` +/// to convert a `TraceParent` to a `Context`. +pub struct TraceParentExtractor(String); + +impl TraceParentExtractor { + pub fn new(traceparent: &TraceParent) -> Self { + Self(traceparent.to_string()) + } +} + +impl Extractor for TraceParentExtractor { + fn get(&self, key: &str) -> Option<&str> { + if key == "traceparent" { + Some(&self.0) + } else { + None + } + } + + fn keys(&self) -> Vec<&str> { + vec!["traceparent"] + } +} + +pub fn spans_to_json(spans: Vec) -> String { + let json_spans: Vec = spans.into_iter().map(|span| json!(TraceSpan::from(span))).collect(); + let span_result = json!({ + "span": true, + "spans": json_spans + }); + serde_json::to_string(&span_result).unwrap_or_default() +} + +pub fn restore_remote_context_from_json_str(serialized: &str) -> opentelemetry::Context { + // This relies on the fact that global text map propagator was installed that + // can handle `traceparent` field (for example, `TraceContextPropagator`). 
+ let trace: HashMap = serde_json::from_str(serialized).unwrap_or_default(); + opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)) +} + +pub enum QueryEngineLogLevel { + FromEnv, + Override(String), +} + +impl QueryEngineLogLevel { + fn level(self) -> Option { + match self { + Self::FromEnv => std::env::var("QE_LOG_LEVEL").ok(), + Self::Override(l) => Some(l), + } + } +} + +#[rustfmt::skip] +pub fn env_filter(log_queries: bool, qe_log_level: QueryEngineLogLevel) -> EnvFilter { + let mut filter = EnvFilter::from_default_env() + .add_directive("tide=error".parse().unwrap()) + .add_directive("tonic=error".parse().unwrap()) + .add_directive("h2=error".parse().unwrap()) + .add_directive("hyper=error".parse().unwrap()) + .add_directive("tower=error".parse().unwrap()); + + if let Some(ref level) = qe_log_level.level() { + filter = filter + .add_directive(format!("query_engine={}", level).parse().unwrap()) + .add_directive(format!("query_core={}", level).parse().unwrap()) + .add_directive(format!("query_connector={}", level).parse().unwrap()) + .add_directive(format!("sql_query_connector={}", level).parse().unwrap()) + .add_directive(format!("mongodb_query_connector={}", level).parse().unwrap()); + } + + if log_queries { + filter = filter + .add_directive("quaint[{is_query}]=trace".parse().unwrap()) + .add_directive("mongodb_query_connector=debug".parse().unwrap()); + } + + filter +} + +pub fn user_facing_span_only_filter(meta: &Metadata<'_>) -> bool { + if !meta.is_span() { + return false; + } + + if *SHOW_ALL_TRACES { + return true; + } + + if meta.fields().iter().any(|f| f.name() == "user_facing") { + return true; + } + + // spans describing a quaint query. + // TODO: should this span be made user_facing in quaint? + meta.target() == "quaint::connector::metrics" && meta.name() == "quaint:query" +} diff --git a/query-engine/core/src/telemetry/mod.rs b/libs/telemetry/src/lib.rs similarity index 100% rename from query-engine/core/src/telemetry/mod.rs rename to libs/telemetry/src/lib.rs diff --git a/query-engine/core/src/telemetry/models.rs b/libs/telemetry/src/models.rs similarity index 93% rename from query-engine/core/src/telemetry/models.rs rename to libs/telemetry/src/models.rs index c1e9ff0158b..275ec5e5693 100644 --- a/query-engine/core/src/telemetry/models.rs +++ b/libs/telemetry/src/models.rs @@ -7,7 +7,21 @@ use std::{ time::{Duration, SystemTime}, }; -const ACCEPT_ATTRIBUTES: &[&str] = &["db.statement", "itx_id", "db.type"]; +const ACCEPT_ATTRIBUTES: &[&str] = &[ + "db.system", + "db.statement", + "db.collection.name", + "db.operation.name", + "itx_id", +]; + +#[derive(Serialize, Debug, Clone, PartialEq, Eq)] +pub enum SpanKind { + #[serde(rename = "client")] + Client, + #[serde(rename = "internal")] + Internal, +} #[derive(Serialize, Debug, Clone, PartialEq, Eq)] pub struct TraceSpan { @@ -23,6 +37,7 @@ pub struct TraceSpan { pub(super) events: Vec, #[serde(skip_serializing_if = "Vec::is_empty")] pub(super) links: Vec, + pub(super) kind: SpanKind, } #[derive(Serialize, Debug, Clone, PartialEq, Eq)] @@ -39,6 +54,11 @@ impl TraceSpan { impl From for TraceSpan { fn from(span: SpanData) -> Self { + let kind = match span.span_kind { + opentelemetry::trace::SpanKind::Client => SpanKind::Client, + _ => SpanKind::Internal, + }; + let attributes: HashMap = span.attributes .iter() @@ -105,6 +125,7 @@ impl From for TraceSpan { attributes, links, events, + kind, } } } diff --git a/libs/test-cli/Cargo.toml b/libs/test-cli/Cargo.toml index 936ff3d9ee4..48c1f317067 
100644 --- a/libs/test-cli/Cargo.toml +++ b/libs/test-cli/Cargo.toml @@ -18,3 +18,6 @@ tracing.workspace = true tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-error = "0.2" async-trait.workspace = true + +[build-dependencies] +build-utils.path = "../build-utils" diff --git a/libs/test-cli/build.rs b/libs/test-cli/build.rs index 9bd10ecb9c5..33aded23a4a 100644 --- a/libs/test-cli/build.rs +++ b/libs/test-cli/build.rs @@ -1,7 +1,3 @@ -use std::process::Command; - fn main() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); + build_utils::store_git_commit_hash_in_env(); } diff --git a/libs/test-macros/Cargo.toml b/libs/test-macros/Cargo.toml index 1d13b8029c0..eaeedc45a9b 100644 --- a/libs/test-macros/Cargo.toml +++ b/libs/test-macros/Cargo.toml @@ -9,4 +9,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.26" quote = "1.0.2" -syn = "1.0.5" +syn = { version = "1.0.5", features = ["full"] } diff --git a/libs/user-facing-errors/src/lib.rs b/libs/user-facing-errors/src/lib.rs index 7d785683163..a1916e55162 100644 --- a/libs/user-facing-errors/src/lib.rs +++ b/libs/user-facing-errors/src/lib.rs @@ -119,9 +119,9 @@ impl Error { } } - /// Construct a new UnknownError from a `PanicInfo` in a panic hook. `UnknownError`s created - /// with this constructor will have a proper, useful backtrace. - pub fn new_in_panic_hook(panic_info: &std::panic::PanicInfo<'_>) -> Self { + /// Construct a new UnknownError from a [`PanicHookInfo`] in a panic hook. [`UnknownError`]s + /// created with this constructor will have a proper, useful backtrace. + pub fn new_in_panic_hook(panic_info: &std::panic::PanicHookInfo<'_>) -> Self { let message = panic_info .payload() .downcast_ref::<&str>() diff --git a/libs/user-facing-errors/src/query_engine/mod.rs b/libs/user-facing-errors/src/query_engine/mod.rs index e42fbcb03f5..804ab240653 100644 --- a/libs/user-facing-errors/src/query_engine/mod.rs +++ b/libs/user-facing-errors/src/query_engine/mod.rs @@ -68,10 +68,7 @@ pub struct UniqueKeyViolation { } #[derive(Debug, UserFacingError, Serialize)] -#[user_facing( - code = "P2003", - message = "Foreign key constraint failed on the field: `{field_name}`" -)] +#[user_facing(code = "P2003", message = "Foreign key constraint violated: `{field_name}`")] pub struct ForeignKeyViolation { /// Field name from one model from Prisma schema pub field_name: String, diff --git a/libs/user-facing-errors/src/schema_engine.rs b/libs/user-facing-errors/src/schema_engine.rs index 7329461ff2b..a3a81211c12 100644 --- a/libs/user-facing-errors/src/schema_engine.rs +++ b/libs/user-facing-errors/src/schema_engine.rs @@ -15,25 +15,29 @@ pub struct DatabaseCreationFailed { code = "P3001", message = "Migration possible with destructive changes and possible data loss: {destructive_details}" )] +#[allow(dead_code)] pub struct DestructiveMigrationDetected { pub destructive_details: String, } +/// No longer used. #[derive(Debug, UserFacingError, Serialize)] #[user_facing( code = "P3002", message = "The attempted migration was rolled back: {database_error}" )] +#[allow(dead_code)] struct MigrationRollback { pub database_error: String, } -// No longer used. +/// No longer used. #[derive(Debug, SimpleUserFacingError)] #[user_facing( code = "P3003", message = "The format of migrations changed, the saved migrations are no longer valid. 
To solve this problem, please follow the steps at: https://pris.ly/d/migrate" )] +#[allow(dead_code)] pub struct DatabaseMigrationFormatChanged; #[derive(Debug, UserFacingError, Serialize)] diff --git a/nix/publish-engine-size.nix b/nix/publish-engine-size.nix index 11a63d7de7e..7fe34f36d6c 100644 --- a/nix/publish-engine-size.nix +++ b/nix/publish-engine-size.nix @@ -22,12 +22,15 @@ let craneLib = (flakeInputs.crane.mkLib pkgs).overrideToolchain rustToolchain; deps = craneLib.vendorCargoDeps { inherit src; }; libSuffix = stdenv.hostPlatform.extensions.sharedLibrary; + fakeGitHash = "0000000000000000000000000000000000000000"; in { packages.prisma-engines = stdenv.mkDerivation { name = "prisma-engines"; inherit src; + GIT_HASH = "${fakeGitHash}"; + buildInputs = [ pkgs.openssl.out ]; nativeBuildInputs = with pkgs; [ rustToolchain @@ -38,6 +41,7 @@ in ] ++ lib.optionals stdenv.isDarwin [ perl # required to build openssl darwin.apple_sdk.frameworks.Security + darwin.apple_sdk.frameworks.SystemConfiguration iconv ]; @@ -68,6 +72,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = "cargo build --profile=${profile} --bin=test-cli"; installPhase = '' @@ -85,6 +91,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = "cargo build --profile=${profile} --bin=query-engine"; installPhase = '' @@ -105,6 +113,8 @@ in inherit src; inherit (self'.packages.prisma-engines) buildInputs nativeBuildInputs configurePhase dontStrip; + GIT_HASH = "${fakeGitHash}"; + buildPhase = '' cargo build --profile=${profile} --bin=query-engine cargo build --profile=${profile} -p query-engine-node-api @@ -134,6 +144,8 @@ in inherit src; buildInputs = with pkgs; [ iconv ]; + GIT_HASH = "${fakeGitHash}"; + buildPhase = '' export HOME=$(mktemp -dt wasm-engine-home-XXXX) diff --git a/nix/shell.nix b/nix/shell.nix index 14d33f64abf..f2767cbcf2c 100644 --- a/nix/shell.nix +++ b/nix/shell.nix @@ -13,7 +13,7 @@ in nodejs_20 nodejs_20.pkgs.typescript-language-server - nodejs_20.pkgs.pnpm + pnpm_8 binaryen cargo-insta diff --git a/prisma-fmt/Cargo.toml b/prisma-fmt/Cargo.toml index 6778573f3a6..18b9a9042c6 100644 --- a/prisma-fmt/Cargo.toml +++ b/prisma-fmt/Cargo.toml @@ -22,6 +22,9 @@ dissimilar = "1.0.3" once_cell = "1.9.0" expect-test = "1" +[build-dependencies] +build-utils.path = "../libs/build-utils" + [features] # sigh please don't ask :( vendored-openssl = [] diff --git a/prisma-fmt/build.rs b/prisma-fmt/build.rs index 2e8fe20c050..33aded23a4a 100644 --- a/prisma-fmt/build.rs +++ b/prisma-fmt/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/prisma-fmt/src/get_dmmf.rs b/prisma-fmt/src/get_dmmf.rs index 6f3f03aa4f1..d398b3f131b 100644 --- a/prisma-fmt/src/get_dmmf.rs +++ b/prisma-fmt/src/get_dmmf.rs @@ -606,7 +606,7 @@ mod tests { "isNullable": false, "inputTypes": [ { - "type": "BRelationFilter", + "type": "BScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -764,7 +764,7 @@ mod tests { "isNullable": false, "inputTypes": [ { - "type": 
"BRelationFilter", + "type": "BScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -1037,7 +1037,7 @@ mod tests { "isNullable": true, "inputTypes": [ { - "type": "ANullableRelationFilter", + "type": "ANullableScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -1174,7 +1174,7 @@ mod tests { "isNullable": true, "inputTypes": [ { - "type": "ANullableRelationFilter", + "type": "ANullableScalarRelationFilter", "namespace": "prisma", "location": "inputObjectTypes", "isList": false @@ -2037,7 +2037,7 @@ mod tests { ] }, { - "name": "BRelationFilter", + "name": "BScalarRelationFilter", "constraints": { "maxNumFields": null, "minNumFields": null @@ -2436,7 +2436,7 @@ mod tests { ] }, { - "name": "ANullableRelationFilter", + "name": "ANullableScalarRelationFilter", "constraints": { "maxNumFields": null, "minNumFields": null diff --git a/prisma-fmt/src/validate.rs b/prisma-fmt/src/validate.rs index 67b12c45ce2..d458d389793 100644 --- a/prisma-fmt/src/validate.rs +++ b/prisma-fmt/src/validate.rs @@ -59,6 +59,28 @@ mod tests { use super::*; use expect_test::expect; + #[test] + fn validate_non_ascii_identifiers() { + let schema = r#" + datasource db { + provider = "postgresql" + url = env("DBURL") + } + + model Lööps { + id Int @id + läderlappen Boolean + } + "#; + + let request = json!({ + "prismaSchema": schema, + }); + + let response = validate(&request.to_string()); + assert!(response.is_ok()) + } + #[test] fn validate_invalid_schema_with_colors() { let schema = r#" diff --git a/psl/psl-core/Cargo.toml b/psl/psl-core/Cargo.toml index ca108793968..cd069d9bce3 100644 --- a/psl/psl-core/Cargo.toml +++ b/psl/psl-core/Cargo.toml @@ -22,7 +22,7 @@ chrono = { workspace = true } connection-string.workspace = true itertools.workspace = true once_cell = "1.3.1" -regex = "1.3.7" +regex.workspace = true serde.workspace = true serde_json.workspace = true enumflags2.workspace = true diff --git a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs index f14a6b9bf1b..65a0d929995 100644 --- a/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/postgres_datamodel_connector.rs @@ -463,7 +463,10 @@ impl Connector for PostgresDatamodelConnector { } fn validate_url(&self, url: &str) -> Result<(), String> { - if !url.starts_with("postgres://") && !url.starts_with("postgresql://") { + if !url.starts_with("postgres://") + && !url.starts_with("postgresql://") + && !url.starts_with("prisma+postgres://") + { return Err("must start with the protocol `postgresql://` or `postgres://`.".to_owned()); } diff --git a/psl/psl-core/src/common/preview_features.rs b/psl/psl-core/src/common/preview_features.rs index 49e86057551..ea9b0eceea8 100644 --- a/psl/psl-core/src/common/preview_features.rs +++ b/psl/psl-core/src/common/preview_features.rs @@ -81,7 +81,8 @@ features!( ReactNative, PrismaSchemaFolder, OmitApi, - TypedSql + TypedSql, + StrictUndefinedChecks ); /// Generator preview features (alphabetically sorted) @@ -100,6 +101,7 @@ pub const ALL_PREVIEW_FEATURES: FeatureMap = FeatureMap { | RelationJoins | OmitApi | PrismaSchemaFolder + | StrictUndefinedChecks }), deprecated: enumflags2::make_bitflags!(PreviewFeature::{ AtomicNumberOperations diff --git a/psl/psl/tests/config/generators.rs b/psl/psl/tests/config/generators.rs index 273de14c744..20e8f886440 100644 --- 
a/psl/psl/tests/config/generators.rs +++ b/psl/psl/tests/config/generators.rs @@ -258,7 +258,7 @@ fn nice_error_for_unknown_generator_preview_feature() { .unwrap_err(); let expectation = expect![[r#" - error: The preview feature "foo" is not known. Expected one of: deno, driverAdapters, fullTextIndex, fullTextSearch, metrics, multiSchema, nativeDistinct, postgresqlExtensions, tracing, views, relationJoins, prismaSchemaFolder, omitApi + error: The preview feature "foo" is not known. Expected one of: deno, driverAdapters, fullTextIndex, fullTextSearch, metrics, multiSchema, nativeDistinct, postgresqlExtensions, tracing, views, relationJoins, prismaSchemaFolder, omitApi, strictUndefinedChecks --> schema.prisma:3  |   2 |  provider = "prisma-client-js" diff --git a/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 00000000000..9adb651bde5 --- /dev/null +++ b/psl/psl/tests/validation/enums/value_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,12 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +enum CatVariant { + lööps +} diff --git a/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 00000000000..f5752f9f2bf --- /dev/null +++ b/psl/psl/tests/validation/models/field_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,14 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model A { + id Int @id @map("_id") + lööps String[] +} + diff --git a/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma b/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma new file mode 100644 index 00000000000..a2e4ca0c3af --- /dev/null +++ b/psl/psl/tests/validation/models/model_with_non_ascii_ident_should_not_error.prisma @@ -0,0 +1,12 @@ +generator client { + provider = "prisma-client-js" +} + +datasource db { + provider = "postgresql" + url = env("DATABASE_URL") +} + +model Lööp { + id Int @id +} diff --git a/psl/schema-ast/src/parser/datamodel.pest b/psl/schema-ast/src/parser/datamodel.pest index 62bf8459e00..a0b84bd0d06 100644 --- a/psl/schema-ast/src/parser/datamodel.pest +++ b/psl/schema-ast/src/parser/datamodel.pest @@ -130,9 +130,11 @@ doc_content = @{ (!NEWLINE ~ ANY)* } // ###################################### // shared building blocks // ###################################### -identifier = @{ ASCII_ALPHANUMERIC ~ ( "_" | "-" | ASCII_ALPHANUMERIC)* } +unicode_alphanumeric = { LETTER | ASCII_DIGIT } +identifier = @{ unicode_alphanumeric ~ ( "_" | "-" | unicode_alphanumeric)* } path = @{ identifier ~ ("." 
~ path?)* } + WHITESPACE = _{ SPACE_SEPARATOR | "\t" } // tabs are also whitespace NEWLINE = _{ "\n" | "\r\n" | "\r" } empty_lines = @{ (WHITESPACE* ~ NEWLINE)+ } diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 89507de2063..482f6f0a762 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -51,6 +51,8 @@ postgresql-native = [ "bit-vec", "lru-cache", "byteorder", + "dep:ws_stream_tungstenite", + "dep:async-tungstenite" ] postgresql = [] @@ -70,15 +72,17 @@ fmt-sql = ["sqlformat"] connection-string = "0.2" percent-encoding = "2" tracing.workspace = true -tracing-core = "0.1" +tracing-futures.workspace = true async-trait.workspace = true thiserror = "1.0" num_cpus = "1.12" -metrics = "0.18" -futures = "0.3" +prisma-metrics.path = "../libs/metrics" +futures.workspace = true url.workspace = true hex = "0.4" itertools.workspace = true +regex.workspace = true +enumflags2.workspace = true either = { version = "1.6" } base64 = { version = "0.12.3" } @@ -88,7 +92,7 @@ serde_json.workspace = true native-tls = { version = "0.2", optional = true } bit-vec = { version = "0.6.1", optional = true } bytes = { version = "1.0", optional = true } -mobc = { version = "0.8", optional = true } +mobc = { version = "0.8.5", optional = true } serde = { version = "1.0" } sqlformat = { version = "0.2.3", optional = true } uuid.workspace = true @@ -110,6 +114,16 @@ expect-test = "1" version = "0.2" features = ["js"] +[dependencies.ws_stream_tungstenite] +version = "0.14.0" +features = ["tokio_io"] +optional = true + +[dependencies.async-tungstenite] +version = "0.28.0" +features = ["tokio-runtime", "tokio-native-tls"] +optional = true + [dependencies.byteorder] default-features = false optional = true @@ -126,12 +140,12 @@ features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] -version = "0.11.8" +version = "0.12.3" optional = true features = ["sql-browser-tokio", "chrono", "bigdecimal"] [target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies.tiberius] -version = "0.11.8" +version = "0.12.3" optional = true default-features = false features = [ @@ -179,7 +193,7 @@ features = ["rt-multi-thread", "macros", "sync"] optional = true [dependencies.tokio-util] -version = "0.6" +version = "0.7" features = ["compat"] optional = true diff --git a/quaint/src/connector/connection_info.rs b/quaint/src/connector/connection_info.rs index 7dd8a5b5825..80d63089e48 100644 --- a/quaint/src/connector/connection_info.rs +++ b/quaint/src/connector/connection_info.rs @@ -84,7 +84,7 @@ impl ConnectionInfo { } #[cfg(feature = "postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Native(NativeConnectionInfo::Postgres( - PostgresUrl::new(url)?, + super::PostgresUrl::new_native(url)?, ))), #[allow(unreachable_patterns)] _ => unreachable!(), @@ -243,7 +243,7 @@ impl ConnectionInfo { pub fn pg_bouncer(&self) -> bool { match self { #[cfg(all(not(target_arch = "wasm32"), feature = "postgresql"))] - ConnectionInfo::Native(NativeConnectionInfo::Postgres(url)) => url.pg_bouncer(), + ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(url))) => url.pg_bouncer(), _ => false, } } diff --git a/quaint/src/connector/metrics.rs b/quaint/src/connector/metrics.rs index a0c4ef42698..78fc7f99c72 100644 --- a/quaint/src/connector/metrics.rs +++ b/quaint/src/connector/metrics.rs @@ -1,15 +1,23 @@ +use prisma_metrics::{counter, histogram}; use tracing::{info_span, Instrument}; use crate::ast::{Params, Value}; use 
crosstarget_utils::time::ElapsedTimeCounter; use std::future::Future; -pub async fn query<'a, F, T, U>(tag: &'static str, query: &'a str, params: &'a [Value<'_>], f: F) -> crate::Result +pub async fn query<'a, F, T, U>( + tag: &'static str, + db_system_name: &'static str, + query: &'a str, + params: &'a [Value<'_>], + f: F, +) -> crate::Result where F: FnOnce() -> U + 'a, U: Future>, { - let span = info_span!("quaint:query", "db.statement" = %query); + let span = + info_span!("quaint:query", "db.system" = db_system_name, "db.statement" = %query, "otel.kind" = "client"); do_query(tag, query, params, f).instrument(span).await } @@ -46,9 +54,9 @@ where trace_query(query, params, result, &start); } - histogram!(format!("{tag}.query.time"), start.elapsed_time()); - histogram!("prisma_datasource_queries_duration_histogram_ms", start.elapsed_time()); - increment_counter!("prisma_datasource_queries_total"); + histogram!(format!("{tag}.query.time")).record(start.elapsed_time()); + histogram!("prisma_datasource_queries_duration_histogram_ms").record(start.elapsed_time()); + counter!("prisma_datasource_queries_total").increment(1); res } @@ -74,7 +82,7 @@ where result, ); - histogram!("pool.check_out", start.elapsed_time()); + histogram!("pool.check_out").record(start.elapsed_time()); res } diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 7383e503d0a..fe7751ddf37 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -30,6 +30,7 @@ use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; pub use tiberius; static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; +const DB_SYSTEM_NAME: &str = "mssql"; #[async_trait] impl TransactionCapable for Mssql { @@ -130,7 +131,7 @@ impl Queryable for Mssql { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { + metrics::query("mssql.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let mut client = self.client.lock().await; let mut query = tiberius::Query::new(sql); @@ -193,7 +194,7 @@ impl Queryable for Mssql { } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { + metrics::query("mssql.execute_raw", DB_SYSTEM_NAME, sql, params, move || async move { let mut query = tiberius::Query::new(sql); for param in params { @@ -213,7 +214,7 @@ impl Queryable for Mssql { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + metrics::query("mssql.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { let mut client = self.client.lock().await; self.perform_io(client.simple_query(cmd)).await?.into_results().await?; Ok(()) diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index b4b23ab94cb..2c8a757a48e 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -68,6 +68,8 @@ impl MysqlUrl { } } +const DB_SYSTEM_NAME: &str = "mysql"; + /// A connector interface for the MySQL database. 
#[derive(Debug)] pub struct Mysql { @@ -195,7 +197,7 @@ impl Queryable for Mysql { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { + metrics::query("mysql.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { self.prepared(sql, |stmt| async move { let mut conn = self.conn.lock().await; let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; @@ -280,7 +282,7 @@ impl Queryable for Mysql { } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { + metrics::query("mysql.execute_raw", DB_SYSTEM_NAME, sql, params, move || async move { self.prepared(sql, |stmt| async move { let mut conn = self.conn.lock().await; conn.exec_drop(stmt, conversion::conv_params(params)?).await?; @@ -297,7 +299,7 @@ impl Queryable for Mysql { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + metrics::query("mysql.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { self.perform_io(|| async move { let mut conn = self.conn.lock().await; let mut result = cmd.run(&mut *conn).await?; diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/error.rs index 3dcc481eccb..0f19cf99e13 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/error.rs @@ -1,3 +1,5 @@ +use crosstarget_utils::{regex::RegExp, RegExpCompat}; +use enumflags2::BitFlags; use std::fmt::{Display, Formatter}; use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; @@ -28,6 +30,11 @@ impl Display for PostgresError { } } +fn extract_fk_constraint_name(message: &str) -> Option { + let re = RegExp::new(r#"foreign key constraint "([^"]+)""#, BitFlags::empty()).unwrap(); + re.captures(message).and_then(|caps| caps.get(1).cloned()) +} + impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -89,12 +96,8 @@ impl From for Error { builder.build() } None => { - let constraint = value - .message - .split_whitespace() - .nth(10) - .and_then(|s| s.split('"').nth(1)) - .map(ToString::to_string) + // `value.message` looks like `update on table "Child" violates foreign key constraint "Child_parent_id_fkey"` + let constraint = extract_fk_constraint_name(value.message.as_str()) .map(DatabaseConstraint::Index) .unwrap_or(DatabaseConstraint::CannotParse); diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 805ba13a602..eb6618ce9dc 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -5,8 +5,9 @@ pub(crate) mod column_type; mod conversion; mod error; mod explain; +mod websocket; -pub(crate) use crate::connector::postgres::url::PostgresUrl; +pub(crate) use crate::connector::postgres::url::PostgresNativeUrl; use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{ timeout, ColumnType, DescribedColumn, DescribedParameter, DescribedQuery, IsolationLevel, Transaction, @@ -27,22 +28,27 @@ use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; use postgres_native_tls::MakeTlsConnector; use postgres_types::{Kind as PostgresKind, Type as PostgresType}; +use prisma_metrics::WithMetricsInstrumentation; use std::hash::{DefaultHasher, Hash, Hasher}; use std::{ - borrow::Borrow, fmt::{Debug, Display}, fs, 
future::Future, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; +use tokio::sync::OnceCell; use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; +use tracing_futures::WithSubscriber; +use websocket::connect_via_websocket; /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. #[cfg(feature = "expose-drivers")] pub use tokio_postgres; +use super::PostgresWebSocketUrl; + struct PostgresClient(Client); impl Debug for PostgresClient { @@ -51,6 +57,9 @@ impl Debug for PostgresClient { } } +const DB_SYSTEM_NAME_POSTGRESQL: &str = "postgresql"; +const DB_SYSTEM_NAME_COCKROACHDB: &str = "cockroachdb"; + /// A connector interface for the PostgreSQL database. #[derive(Debug)] pub struct PostgreSql { @@ -61,6 +70,7 @@ pub struct PostgreSql { is_healthy: AtomicBool, is_cockroachdb: bool, is_materialize: bool, + db_system_name: &'static str, } /// Key uniquely representing an SQL statement in the prepared statements cache. @@ -160,7 +170,7 @@ impl SslParams { } } -impl PostgresUrl { +impl PostgresNativeUrl { pub(crate) fn cache(&self) -> StatementCache { if self.query_params.pg_bouncer { StatementCache::new(0) @@ -197,8 +207,8 @@ impl PostgresUrl { pub(crate) fn to_config(&self) -> Config { let mut config = Config::new(); - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); + config.user(self.username().as_ref()); + config.password(self.password().as_ref()); config.host(self.host()); config.port(self.port()); config.dbname(self.dbname()); @@ -228,38 +238,24 @@ impl PostgresUrl { impl PostgreSql { /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { + pub async fn new(url: PostgresNativeUrl, tls_manager: &MakeTlsConnectorManager) -> crate::Result { let config = url.to_config(); - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); + let tls = tls_manager.get_connector().await?; let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; let is_cockroachdb = conn.parameter("crdb_version").is_some(); let is_materialize = conn.parameter("mz_version").is_some(); - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); + tokio::spawn( + conn.map(|r| { + if let Err(e) = r { + tracing::error!("Error in PostgreSQL connection: {e:?}"); + } + }) + .with_current_subscriber() + .with_current_recorder(), + ); // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. 
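
A minimal usage sketch of the connection flow introduced in this hunk, assuming quaint's public `connector` re-exports and the `postgresql-native` feature; the `url::Url` call and error handling are illustrative, while the type and function names are the ones added in this diff:

    use quaint::connector::{MakeTlsConnectorManager, PostgresNativeUrl, PostgreSql};

    async fn connect(conn_str: &str) -> quaint::Result<PostgreSql> {
        // Parse into the native URL type; WebSocket URLs instead go through
        // `PostgresWebSocketUrl` and `PostgreSql::new_with_websocket`.
        let url = PostgresNativeUrl::new(url::Url::parse(conn_str).expect("valid URL"))?;

        // Built once per URL: the TLS connector inside is created lazily on the
        // first `get_connector()` call and cloned for every later connection.
        let tls_manager = MakeTlsConnectorManager::new(url.clone());

        PostgreSql::new(url, &tls_manager).await
    }

This mirrors what the pooled `QuaintManager::Postgres` variant now carries alongside the URL, so a pool reuses a single lazily built TLS connector instead of rebuilding it for every new connection.
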
@@ -281,6 +277,12 @@ impl PostgreSql { } } + let db_system_name = if is_cockroachdb { + DB_SYSTEM_NAME_COCKROACHDB + } else { + DB_SYSTEM_NAME_POSTGRESQL + }; + Ok(Self { client: PostgresClient(client), socket_timeout: url.query_params.socket_timeout, @@ -289,6 +291,23 @@ impl PostgreSql { is_healthy: AtomicBool::new(true), is_cockroachdb, is_materialize, + db_system_name, + }) + } + + /// Create a new websocket connection to managed database + pub async fn new_with_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let client = connect_via_websocket(url).await?; + + Ok(Self { + client: PostgresClient(client), + socket_timeout: None, + pg_bouncer: false, + statement_cache: Mutex::new(StatementCache::new(0)), + is_healthy: AtomicBool::new(true), + is_cockroachdb: false, + is_materialize: false, + db_system_name: DB_SYSTEM_NAME_POSTGRESQL, }) } @@ -520,72 +539,84 @@ impl Queryable for PostgreSql { async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.query_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); - for row in rows { - result.rows.push(row.get_result_row()?); - } + for row in rows { + result.rows.push(row.get_result_row()?); + } - Ok(result) - }) + Ok(result) + }, + ) .await } async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.query_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let col_types = stmt - .columns() - .iter() - .map(|c| PGColumnType::from_pg_type(c.type_())) - .map(ColumnType::from) - .collect::>(); - 
let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let col_types = stmt + .columns() + .iter() + .map(|c| PGColumnType::from_pg_type(c.type_())) + .map(ColumnType::from) + .collect::>(); + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; - let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); + let mut result = ResultSet::new(stmt.to_column_names(), col_types, Vec::new()); - for row in rows { - result.rows.push(row.get_result_row()?); - } + for row in rows { + result.rows.push(row.get_result_row()?); + } - Ok(result) - }) + Ok(result) + }, + ) .await } @@ -673,53 +704,65 @@ impl Queryable for PostgreSql { async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.execute_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; - Ok(changes) - }) + Ok(changes) + }, + ) .await } async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { self.check_bind_variables_len(params)?; - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } + metrics::query( + "postgres.execute_raw", + self.db_system_name, + sql, + params, + move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; - Ok(changes) - }) + Ok(changes) + }, + ) .await } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + metrics::query("postgres.raw_cmd", self.db_system_name, cmd, &[], move || async move { self.perform_io(self.client.0.simple_query(cmd)).await?; Ok(()) }) @@ -907,6 +950,48 @@ fn is_safe_identifier(ident: &str) -> bool { true } +pub struct MakeTlsConnectorManager { + url: PostgresNativeUrl, + connector: OnceCell, +} + +impl MakeTlsConnectorManager { + pub fn new(url: PostgresNativeUrl) -> Self { + 
MakeTlsConnectorManager { + url, + connector: OnceCell::new(), + } + } + + pub async fn get_connector(&self) -> crate::Result { + self.connector + .get_or_try_init(|| async { + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = self.url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls_connector = MakeTlsConnector::new(tls_builder.build()?); + + Ok(tls_connector) + }) + .await + .cloned() + } +} + #[cfg(test)] mod tests { use super::*; @@ -922,10 +1007,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -974,10 +1061,12 @@ mod tests { url.query_pairs_mut().append_pair("schema", schema_name); url.query_pairs_mut().append_pair("pbbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1025,10 +1114,12 @@ mod tests { let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1076,10 +1167,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); + let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); @@ -1127,10 +1220,12 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", schema_name); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Unknown); - let client = PostgreSql::new(pg_url).await.unwrap(); 
+ let tls_manager = MakeTlsConnectorManager::new(pg_url.clone()); + + let client = PostgreSql::new(pg_url, &tls_manager).await.unwrap(); let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); let row = result_set.first().unwrap(); diff --git a/quaint/src/connector/postgres/native/websocket.rs b/quaint/src/connector/postgres/native/websocket.rs new file mode 100644 index 00000000000..f278c9f099b --- /dev/null +++ b/quaint/src/connector/postgres/native/websocket.rs @@ -0,0 +1,97 @@ +use std::str::FromStr; + +use async_tungstenite::{ + tokio::connect_async, + tungstenite::{ + self, + client::IntoClientRequest, + http::{HeaderMap, HeaderValue, StatusCode}, + Error as TungsteniteError, + }, +}; +use futures::FutureExt; +use postgres_native_tls::TlsConnector; +use prisma_metrics::WithMetricsInstrumentation; +use tokio_postgres::{Client, Config}; +use tracing_futures::WithSubscriber; +use ws_stream_tungstenite::WsStream; + +use crate::{ + connector::PostgresWebSocketUrl, + error::{self, Error, ErrorKind, Name, NativeErrorKind}, +}; + +const CONNECTION_PARAMS_HEADER: &str = "Prisma-Connection-Parameters"; +const HOST_HEADER: &str = "Prisma-Db-Host"; + +pub(crate) async fn connect_via_websocket(url: PostgresWebSocketUrl) -> crate::Result { + let db_name = url.overriden_db_name().map(ToOwned::to_owned); + let (ws_stream, response) = connect_async(url).await?; + + let connection_params = require_header_value(response.headers(), CONNECTION_PARAMS_HEADER)?; + let db_host = require_header_value(response.headers(), HOST_HEADER)?; + + let mut config = Config::from_str(connection_params)?; + if let Some(db_name) = db_name { + config.dbname(&db_name); + } + let ws_byte_stream = WsStream::new(ws_stream); + + let tls = TlsConnector::new(native_tls::TlsConnector::new()?, db_host); + let (client, connection) = config.connect_raw(ws_byte_stream, tls).await?; + tokio::spawn( + connection + .map(|r| { + if let Err(e) = r { + tracing::error!("Error in PostgreSQL WebSocket connection: {e:?}"); + } + }) + .with_current_subscriber() + .with_current_recorder(), + ); + Ok(client) +} + +fn require_header_value<'a>(headers: &'a HeaderMap, name: &str) -> crate::Result<&'a str> { + let Some(header) = headers.get(name) else { + let message = format!("Missing response header {name}"); + let error = Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(message.into()))).build(); + return Err(error); + }; + + let value = header.to_str().map_err(|inner| { + Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(inner)))).build() + })?; + + Ok(value) +} + +impl IntoClientRequest for PostgresWebSocketUrl { + fn into_client_request(self) -> tungstenite::Result { + let mut request = self.url.to_string().into_client_request()?; + let bearer = format!("Bearer {}", self.api_key()); + let auth_header = HeaderValue::from_str(&bearer)?; + request.headers_mut().insert("Authorization", auth_header); + Ok(request) + } +} + +impl From for error::Error { + fn from(value: TungsteniteError) -> Self { + let builder = match value { + TungsteniteError::Tls(tls_error) => Error::builder(ErrorKind::Native(NativeErrorKind::TlsError { + message: tls_error.to_string(), + })), + + TungsteniteError::Http(response) if response.status() == StatusCode::UNAUTHORIZED => { + Error::builder(ErrorKind::DatabaseAccessDenied { + db_name: Name::Unavailable, + }) + } + + _ => Error::builder(ErrorKind::Native(NativeErrorKind::ConnectionError(Box::new(value)))), + }; + + builder.build() + } +} diff --git 
a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 844da48c8d6..096484cdc87 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -63,16 +63,74 @@ impl PostgresFlavour { } } +#[derive(Debug, Clone)] +pub enum PostgresUrl { + Native(Box), + WebSocket(PostgresWebSocketUrl), +} + +impl PostgresUrl { + pub fn new_native(url: Url) -> Result { + Ok(Self::Native(Box::new(PostgresNativeUrl::new(url)?))) + } + + pub fn new_websocket(url: Url, api_key: String) -> Result { + Ok(Self::WebSocket(PostgresWebSocketUrl::new(url, api_key))) + } + + pub fn dbname(&self) -> &str { + match self { + Self::Native(url) => url.dbname(), + Self::WebSocket(url) => url.dbname(), + } + } + + pub fn host(&self) -> &str { + match self { + Self::Native(native_url) => native_url.host(), + Self::WebSocket(ws_url) => ws_url.host(), + } + } + + pub fn port(&self) -> u16 { + match self { + Self::Native(native_url) => native_url.port(), + Self::WebSocket(ws_url) => ws_url.port(), + } + } + + pub fn username(&self) -> Cow<'_, str> { + match self { + Self::Native(native_url) => native_url.username(), + Self::WebSocket(_) => Cow::Borrowed(""), + } + } + + pub fn schema(&self) -> &str { + match self { + Self::Native(native_url) => native_url.schema(), + Self::WebSocket(_) => "public", + } + } + + pub fn socket_timeout(&self) -> Option { + match self { + Self::Native(native_url) => native_url.socket_timeout(), + Self::WebSocket(_) => None, + } + } +} + /// Wraps a connection url and exposes the parsing logic used by Quaint, /// including default values. #[derive(Debug, Clone)] -pub struct PostgresUrl { +pub struct PostgresNativeUrl { pub(crate) url: Url, pub(crate) query_params: PostgresUrlQueryParams, pub(crate) flavour: PostgresFlavour, } -impl PostgresUrl { +impl PostgresNativeUrl { /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection /// parameters. 
pub fn new(url: Url) -> Result { @@ -431,6 +489,47 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) ssl_mode: SslMode, } +#[derive(Debug, Clone)] +pub struct PostgresWebSocketUrl { + pub(crate) url: Url, + pub(crate) api_key: String, + pub(crate) db_name: Option, +} + +impl PostgresWebSocketUrl { + pub fn new(url: Url, api_key: String) -> Self { + Self { + url, + api_key, + db_name: None, + } + } + + pub fn override_db_name(&mut self, name: String) { + self.db_name = Some(name) + } + + pub fn api_key(&self) -> &str { + &self.api_key + } + + pub fn dbname(&self) -> &str { + self.overriden_db_name().unwrap_or("postgres") + } + + pub fn overriden_db_name(&self) -> Option<&str> { + self.db_name.as_deref() + } + + pub fn host(&self) -> &str { + self.url.host_str().unwrap_or("localhost") + } + + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(80) + } +} + #[cfg(test)] mod tests { use super::*; @@ -442,14 +541,15 @@ mod tests { #[test] fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/psql.sock", url.host()); } #[test] fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("/var/run/postgresql", url.host()); } @@ -457,63 +557,69 @@ mod tests { #[test] fn should_allow_changing_of_cache_size() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()) + .unwrap(); assert_eq!(420, url.cache().capacity()); } #[test] fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); assert_eq!(100, url.cache().capacity()); } #[test] fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()) + .unwrap(); assert_eq!(Some("test"), url.application_name()); } #[test] fn should_have_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Require, url.channel_binding()); } #[test] fn should_have_default_channel_binding() { let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()) + .unwrap(); assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); 
assert_eq!(ChannelBinding::Prefer, url.channel_binding()); } #[test] fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); assert_eq!(0, url.cache().capacity()); } #[test] fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + let url = PostgresNativeUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); assert_eq!("dbname", url.dbname()); assert_eq!("localhost", url.host()); } #[test] fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); assert_eq!("2001:db8:1234::ffff", url.host()); } #[test] fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); + let url = + PostgresNativeUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); assert_eq!("--cluster=my_cluster", url.options().unwrap()); } @@ -600,7 +706,7 @@ mod tests { url.query_pairs_mut().append_pair("schema", "hello"); url.query_pairs_mut().append_pair("pgbouncer", "true"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -616,7 +722,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Postgres); let config = pg_url.to_config(); @@ -630,7 +736,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "hello"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); @@ -644,7 +750,7 @@ mod tests { let mut url = Url::parse(&CONN_STR).unwrap(); url.query_pairs_mut().append_pair("schema", "HeLLo"); - let mut pg_url = PostgresUrl::new(url).unwrap(); + let mut pg_url = PostgresNativeUrl::new(url).unwrap(); pg_url.set_flavour(PostgresFlavour::Cockroach); let config = pg_url.to_config(); diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index abcec7410a6..2d738a7f087 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -29,6 +29,8 @@ pub struct Sqlite { pub(crate) client: Mutex, } +const DB_SYSTEM_NAME: &str = "sqlite"; + impl TryFrom<&str> for Sqlite { type Error = Error; @@ -100,7 +102,7 @@ impl Queryable for Sqlite { } async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { + metrics::query("sqlite.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let client = self.client.lock().await; let mut stmt = client.prepare_cached(sql)?; @@ -134,7 +136,7 @@ impl Queryable for Sqlite { } async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - 
metrics::query("sqlite.query_raw", sql, params, move || async move { + metrics::query("sqlite.query_raw", DB_SYSTEM_NAME, sql, params, move || async move { let client = self.client.lock().await; let mut stmt = client.prepare_cached(sql)?; let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; @@ -149,7 +151,7 @@ impl Queryable for Sqlite { } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + metrics::query("sqlite.raw_cmd", DB_SYSTEM_NAME, cmd, &[], move || async move { let client = self.client.lock().await; client.execute_batch(cmd)?; Ok(()) diff --git a/quaint/src/connector/transaction.rs b/quaint/src/connector/transaction.rs index df4084883e8..599efe1d99f 100644 --- a/quaint/src/connector/transaction.rs +++ b/quaint/src/connector/transaction.rs @@ -1,13 +1,13 @@ +use std::{fmt, str::FromStr}; + +use async_trait::async_trait; +use prisma_metrics::guards::GaugeGuard; + use super::*; use crate::{ ast::*, error::{Error, ErrorKind}, }; -use async_trait::async_trait; -use metrics::{decrement_gauge, increment_gauge}; -use std::{fmt, str::FromStr}; - -extern crate metrics as metrics; #[async_trait] pub trait Transaction: Queryable { @@ -36,6 +36,7 @@ pub(crate) struct TransactionOptions { /// transaction object will panic. pub struct DefaultTransaction<'a> { pub inner: &'a dyn Queryable, + gauge: GaugeGuard, } impl<'a> DefaultTransaction<'a> { @@ -44,7 +45,10 @@ impl<'a> DefaultTransaction<'a> { begin_stmt: &str, tx_opts: TransactionOptions, ) -> crate::Result> { - let this = Self { inner }; + let this = Self { + inner, + gauge: GaugeGuard::increment("prisma_client_queries_active"), + }; if tx_opts.isolation_first { if let Some(isolation) = tx_opts.isolation_level { @@ -62,7 +66,6 @@ impl<'a> DefaultTransaction<'a> { inner.server_reset_query(&this).await?; - increment_gauge!("prisma_client_queries_active", 1.0); Ok(this) } } @@ -71,7 +74,7 @@ impl<'a> DefaultTransaction<'a> { impl<'a> Transaction for DefaultTransaction<'a> { /// Commit the changes to the database and consume the transaction. async fn commit(&self) -> crate::Result<()> { - decrement_gauge!("prisma_client_queries_active", 1.0); + self.gauge.decrement(); self.inner.raw_cmd("COMMIT").await?; Ok(()) @@ -79,7 +82,7 @@ impl<'a> Transaction for DefaultTransaction<'a> { /// Rolls back the changes to the database. async fn rollback(&self) -> crate::Result<()> { - decrement_gauge!("prisma_client_queries_active", 1.0); + self.gauge.decrement(); self.inner.raw_cmd("ROLLBACK").await?; Ok(()) diff --git a/quaint/src/lib.rs b/quaint/src/lib.rs index 45c2a10a169..ab73ef7e66a 100644 --- a/quaint/src/lib.rs +++ b/quaint/src/lib.rs @@ -110,11 +110,8 @@ compile_error!("one of 'sqlite', 'postgresql', 'mysql' or 'mssql' features must #[macro_use] mod macros; -#[macro_use] -extern crate metrics; - -pub extern crate bigdecimal; -pub extern crate chrono; +pub use bigdecimal; +pub use chrono; pub mod ast; pub mod connector; diff --git a/quaint/src/pooled.rs b/quaint/src/pooled.rs index 381f0c82414..389005ab7bd 100644 --- a/quaint/src/pooled.rs +++ b/quaint/src/pooled.rs @@ -307,12 +307,14 @@ impl Builder { /// - Defaults to `PostgresFlavour::Unknown`. 
#[cfg(feature = "postgresql-native")] pub fn set_postgres_flavour(&mut self, flavour: crate::connector::PostgresFlavour) { - use crate::connector::NativeConnectionInfo; - if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(ref mut url)) = self.connection_info { + use crate::connector::{NativeConnectionInfo, PostgresUrl}; + if let ConnectionInfo::Native(NativeConnectionInfo::Postgres(PostgresUrl::Native(ref mut url))) = + self.connection_info + { url.set_flavour(flavour); } - if let QuaintManager::Postgres { ref mut url } = self.manager { + if let QuaintManager::Postgres { ref mut url, .. } = self.manager { url.set_flavour(flavour); } } @@ -415,13 +417,14 @@ impl Quaint { } #[cfg(feature = "postgresql")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = crate::connector::PostgresUrl::new(url::Url::parse(s)?)?; + let url = crate::connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; let connection_limit = url.connection_limit(); let pool_timeout = url.pool_timeout(); let max_connection_lifetime = url.max_connection_lifetime(); let max_idle_connection_lifetime = url.max_idle_connection_lifetime(); - let manager = QuaintManager::Postgres { url }; + let tls_manager = crate::connector::MakeTlsConnectorManager::new(url.clone()); + let manager = QuaintManager::Postgres { url, tls_manager }; let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index 7533dffcfcc..bf4d50eeea8 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,16 +1,21 @@ +use std::future::Future; + +use async_trait::async_trait; +use mobc::{Connection as MobcPooled, Manager}; +use prisma_metrics::WithMetricsInstrumentation; +use tracing_futures::WithSubscriber; + #[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; #[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql-native")] -use crate::connector::PostgresUrl; +use crate::connector::{MakeTlsConnectorManager, PostgresNativeUrl}; use crate::{ ast, connector::{self, impl_default_TransactionCapable, IsolationLevel, Queryable, Transaction, TransactionCapable}, error::Error, }; -use async_trait::async_trait; -use mobc::{Connection as MobcPooled, Manager}; /// A connection from the pool. Implements /// [Queryable](connector/trait.Queryable.html). @@ -85,7 +90,10 @@ pub enum QuaintManager { Mysql { url: MysqlUrl }, #[cfg(feature = "postgresql")] - Postgres { url: PostgresUrl }, + Postgres { + url: PostgresNativeUrl, + tls_manager: MakeTlsConnectorManager, + }, #[cfg(feature = "sqlite")] Sqlite { url: String, db_name: String }, @@ -117,9 +125,9 @@ impl Manager for QuaintManager { } #[cfg(feature = "postgresql-native")] - QuaintManager::Postgres { url } => { + QuaintManager::Postgres { url, tls_manager } => { use crate::connector::PostgreSql; - Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) + Ok(Box::new(PostgreSql::new(url.clone(), tls_manager).await?) 
as Self::Connection) } #[cfg(feature = "mssql-native")] @@ -143,6 +151,14 @@ impl Manager for QuaintManager { fn validate(&self, conn: &mut Self::Connection) -> bool { conn.is_healthy() } + + fn spawn_task(&self, task: T) + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(task.with_current_subscriber().with_current_recorder()); + } } #[cfg(test)] diff --git a/quaint/src/single.rs b/quaint/src/single.rs index cbf460c4150..13be8c4bc85 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -148,8 +148,9 @@ impl Quaint { } #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { - let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; - let psql = connector::PostgreSql::new(url).await?; + let url = connector::PostgresNativeUrl::new(url::Url::parse(s)?)?; + let tls_manager = connector::MakeTlsConnectorManager::new(url.clone()); + let psql = connector::PostgreSql::new(url, &tls_manager).await?; Arc::new(psql) as Arc } #[cfg(feature = "mssql-native")] diff --git a/query-engine/black-box-tests/Cargo.toml b/query-engine/black-box-tests/Cargo.toml index c5f88c844dc..e08ebb962e1 100644 --- a/query-engine/black-box-tests/Cargo.toml +++ b/query-engine/black-box-tests/Cargo.toml @@ -14,5 +14,5 @@ tokio.workspace = true user-facing-errors.workspace = true insta = "1.7.1" enumflags2.workspace = true -query-engine-metrics = {path = "../metrics"} -regex = "1.9.3" +prisma-metrics.path = "../../libs/metrics" +regex.workspace = true diff --git a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs index b21f22265f9..91c3e719456 100644 --- a/query-engine/black-box-tests/tests/metrics/smoke_tests.rs +++ b/query-engine/black-box-tests/tests/metrics/smoke_tests.rs @@ -122,7 +122,7 @@ mod smoke_tests { assert_eq!(metrics.matches("# TYPE prisma_datasource_queries_duration_histogram_ms histogram").count(), 1); // Check that exist as many metrics as being accepted - let accepted_metric_count = query_engine_metrics::ACCEPT_LIST.len(); + let accepted_metric_count = prisma_metrics::ACCEPT_LIST.len(); let displayed_metric_count = metrics.matches("# TYPE").count(); let non_prisma_metric_count = displayed_metric_count - metrics.matches("# TYPE prisma").count(); diff --git a/query-engine/connector-test-kit-rs/README.md b/query-engine/connector-test-kit-rs/README.md index d896358d06e..ef8396f4804 100644 --- a/query-engine/connector-test-kit-rs/README.md +++ b/query-engine/connector-test-kit-rs/README.md @@ -83,7 +83,7 @@ drivers the code that actually communicates with the databases. See [`adapter-*` To run tests through a driver adapters, you should also configure the following environment variables: * `DRIVER_ADAPTER`: tells the test executor to use a particular driver adapter. Set to `neon`, `planetscale` or any other supported adapter. -* `DRIVER_ADAPTER_CONFIG`: a json string with the configuration for the driver adapter. This is adapter specific. See the [github workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. +* `DRIVER_ADAPTER_CONFIG`: a json string with the configuration for the driver adapter. This is adapter specific. See the [GitHub workflow for driver adapter tests](.github/workflows/query-engine-driver-adapters.yml) for examples on how to configure the driver adapters. * `ENGINE`: can be used to run either `wasm` or `napi` or `c-abi` version of the engine. 
Example: @@ -339,7 +339,7 @@ run_query!( **Accepting a snapshot update will replace, directly in your code, the expected output in the assertion.** -If you dislike the interactive view, you can also run `cargo insta accept` to automatically accept all snapshots and then use your git diff to check if everything is as intented. +If you dislike the interactive view, you can also run `cargo insta accept` to automatically accept all snapshots and then use your git diff to check if everything is as intended. ##### Without `cargo-insta` diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs index 0d876d6b4dc..c901b9ef388 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/cockroachdb.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn cockroach_setup(url: String, prisma_schema: &str) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Cockroach); let db_name = quaint_url.dbname(); diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs b/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs index 530717cc94d..17d9ec06ab5 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/lib.rs @@ -65,14 +65,11 @@ fn parse_configuration(datamodel: &str) -> ConnectorResult<(Datasource, String, /// (rather than just the Schema Engine), this function will call [`ExternalInitializer::init_with_migration`]. /// Otherwise, it will call [`ExternalInitializer::init`], and then proceed with the standard /// setup based on the Schema Engine. 
-pub async fn setup_external<'a, EI>( +pub async fn setup_external<'a>( driver_adapter: DriverAdapter, - initializer: EI, + initializer: impl ExternalInitializer<'a>, db_schemas: &[&str], -) -> ConnectorResult -where - EI: ExternalInitializer<'a> + ?Sized, -{ +) -> ConnectorResult { let prisma_schema = initializer.datamodel(); let (source, url, _preview_features) = parse_configuration(prisma_schema)?; diff --git a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs index 6bbba8564ca..536f51eb483 100644 --- a/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs +++ b/query-engine/connector-test-kit-rs/qe-setup/src/postgres.rs @@ -5,7 +5,7 @@ use url::Url; pub(crate) async fn postgres_setup(url: String, prisma_schema: &str, db_schemas: &[&str]) -> ConnectorResult<()> { let mut parsed_url = Url::parse(&url).map_err(ConnectorError::url_parse_error)?; - let mut quaint_url = quaint::connector::PostgresUrl::new(parsed_url.clone()).unwrap(); + let mut quaint_url = quaint::connector::PostgresNativeUrl::new(parsed_url.clone()).unwrap(); quaint_url.set_flavour(PostgresFlavour::Postgres); let (db_name, schema) = (quaint_url.dbname(), quaint_url.schema()); diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml index 46d1d4b845f..60cfbca7ca1 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-engine-tests/Cargo.toml @@ -11,7 +11,7 @@ query-test-macros = { path = "../query-test-macros" } query-tests-setup = { path = "../query-tests-setup" } indoc.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true colored = "2" chrono.workspace = true psl.workspace = true @@ -20,9 +20,9 @@ uuid.workspace = true tokio.workspace = true user-facing-errors.workspace = true prisma-value = { path = "../../../libs/prisma-value" } -query-engine-metrics = { path = "../../metrics"} +prisma-metrics.path = "../../../libs/metrics" once_cell = "1.15.0" -futures = "0.3" +futures.workspace = true paste = "1.0.14" [dev-dependencies] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs b/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs index 268fddef976..ce64623645b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/src/utils/querying.rs @@ -34,6 +34,28 @@ macro_rules! match_connector_result { }; } +#[macro_export] +macro_rules! assert_connector_error { + ($runner:expr, $q:expr, $code:expr, $( $($matcher:pat_param)|+ $( if $pred:expr )? => $msg:expr ),*) => { + use query_tests_setup::*; + use query_tests_setup::ConnectorVersion::*; + + let connector = $runner.connector_version(); + + let mut results = match &connector { + $( + $( $matcher )|+ $( if $pred )? => $msg.to_string() + ),* + }; + + if results.len() == 0 { + panic!("No assertion failure defined for connector {connector}."); + } + + $runner.query($q).await?.assert_failure($code, Some(results)); + }; +} + #[macro_export] macro_rules! 
is_one_of { ($result:expr, $potential_results:expr) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs index da0db0a0e70..10082869704 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/interactive_tx.rs @@ -3,6 +3,8 @@ use std::borrow::Cow; #[test_suite(schema(generic), exclude(Sqlite("cfd1")))] mod interactive_tx { + use std::time::{Duration, Instant}; + use query_engine_tests::*; use tokio::time; @@ -213,7 +215,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(Sqlite("cfd1")))] async fn batch_queries_failure(mut runner: Runner) -> TestResult<()> { // Tx expires after five second. let tx_id = runner.start_tx(5000, 5000, None).await?; @@ -231,7 +233,9 @@ mod interactive_tx { let batch_results = runner.batch(queries, false, None).await?; batch_results.assert_failure(2002, None); + let now = Instant::now(); let res = runner.commit_tx(tx_id.clone()).await?; + assert!(now.elapsed() <= Duration::from_millis(5000)); if matches!(runner.connector_version(), ConnectorVersion::MongoDb(_)) { assert!(res.is_err()); @@ -256,7 +260,7 @@ mod interactive_tx { Ok(()) } - #[connector_test(exclude(Vitess("planetscale.js.wasm")))] + #[connector_test] async fn tx_expiration_failure_cycle(mut runner: Runner) -> TestResult<()> { // Tx expires after one seconds. let tx_id = runner.start_tx(5000, 1000, None).await?; @@ -570,13 +574,13 @@ mod interactive_tx { #[test_suite(schema(generic), exclude(Sqlite("cfd1")))] mod itx_isolation { + use std::sync::Arc; + use query_engine_tests::*; + use tokio::task::JoinSet; // All (SQL) connectors support serializable. 
- // However, there's a bug in the PlanetScale driver adapter: - // "Transaction characteristics can't be changed while a transaction is in progress - // (errno 1568) (sqlstate 25001) during query: SET TRANSACTION ISOLATION LEVEL SERIALIZABLE" - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn basic_serializable(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("Serializable".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -598,9 +602,7 @@ mod itx_isolation { Ok(()) } - // On PlanetScale, this fails with: - // `InteractiveTransactionError("Error in connector: Error querying the database: Server error: `ERROR 25001 (1568): Transaction characteristics can't be changed while a transaction is in progress'")` - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm"), Sqlite("cfd1")))] + #[connector_test(exclude(MongoDb, Sqlite("cfd1")))] async fn casing_doesnt_matter(mut runner: Runner) -> TestResult<()> { let tx_id = runner.start_tx(5000, 5000, Some("sErIaLiZaBlE".to_owned())).await?; runner.set_active_tx(tx_id.clone()); @@ -654,4 +656,45 @@ mod itx_isolation { Ok(()) } + + #[connector_test(exclude(Sqlite))] + async fn high_concurrency(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + let mut set = JoinSet::>::new(); + + for i in 1..=20 { + set.spawn({ + let runner = Arc::clone(&runner); + async move { + let tx_id = runner.start_tx(5000, 5000, None).await?; + + runner + .query_in_tx( + &tx_id, + format!( + r#"mutation {{ + createOneTestModel( + data: {{ + id: {i} + }} + ) {{ id }} + }}"# + ), + ) + .await? + .assert_success(); + + runner.commit_tx(tx_id).await?.expect("commit must succeed"); + + Ok(()) + } + }); + } + + while let Some(handle) = set.join_next().await { + handle.expect("task panicked or canceled")?; + } + + Ok(()) + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs index 2a1cf89e9d3..323f162a211 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/metrics.rs @@ -9,9 +9,7 @@ use query_engine_tests::test_suite; ) )] mod metrics { - use query_engine_metrics::{ - PRISMA_CLIENT_QUERIES_ACTIVE, PRISMA_CLIENT_QUERIES_TOTAL, PRISMA_DATASOURCE_QUERIES_TOTAL, - }; + use prisma_metrics::{PRISMA_CLIENT_QUERIES_ACTIVE, PRISMA_CLIENT_QUERIES_TOTAL, PRISMA_DATASOURCE_QUERIES_TOTAL}; use query_engine_tests::ConnectorVersion::*; use query_engine_tests::*; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs index 131dbcf8959..9e347d47605 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_delete/set_default.rs @@ -80,7 +80,7 @@ mod one2one_req { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -187,7 +187,7 @@ mod one2one_opt { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + 
"Foreign key constraint violated" ); Ok(()) @@ -293,7 +293,7 @@ mod one2many_req { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -397,7 +397,7 @@ mod one2many_opt { runner, "mutation { deleteOneParent(where: { id: 1 }) { id }}", 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs index 99c2ffb63a5..73ee612bbfc 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/ref_actions/on_update/set_default.rs @@ -79,7 +79,7 @@ mod one2one_req { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -182,7 +182,7 @@ mod one2one_opt { &runner, r#"mutation { updateOneParent(where: { id: 1 } data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -287,7 +287,7 @@ mod one2many_req { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) @@ -390,7 +390,7 @@ mod one2many_opt { &runner, r#"mutation { updateOneParent(where: { id: 1 }, data: { uniq: "u1" }) { id }}"#, 2003, - "Foreign key constraint failed on the field" + "Foreign key constraint violated" ); Ok(()) diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs index 4b014fa53f6..6ab70f2975d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/mod.rs @@ -1,6 +1,7 @@ mod max_integer; mod prisma_10098; mod prisma_10935; +mod prisma_11750; mod prisma_11789; mod prisma_12572; mod prisma_12929; @@ -27,6 +28,7 @@ mod prisma_21901; mod prisma_22007; mod prisma_22298; mod prisma_22971; +mod prisma_24072; mod prisma_5952; mod prisma_6173; mod prisma_7010; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs new file mode 100644 index 00000000000..907aae408bf --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_11750.rs @@ -0,0 +1,127 @@ +use query_engine_tests::*; + +/// Regression test for . +/// +/// See also and +/// . +/// +/// This is a port of the TypeScript test from the client test suite. +/// +/// The test creates a user and then tries to update the same row in multiple concurrent +/// transactions. We don't assert that most operations succeed and merely log the errors happening +/// during update or commit, as those are expected to happen. We do fail the test if creating the +/// user fails, or if we fail to start a transaction, as those operations are expected to succeed. 
+/// +/// What we really test here, though, is that the query engine must not deadlock (leading to the +/// test never finishing). +/// +/// Some providers are skipped because these concurrent conflicting transactions cause trouble on +/// the database side and failures to start new transactions. +/// +/// For an example of an equivalent test that passes on all databases where the transactions don't +/// conflict and don't cause issues on the database side, see the `high_concurrency` test in the +/// `new::interactive_tx::interactive_tx` test suite. +#[test_suite(schema(user), exclude(Sqlite, MySql(8), SqlServer))] +mod prisma_11750 { + use std::sync::Arc; + use tokio::task::JoinSet; + + #[connector_test] + async fn test_itx_concurrent_updates_single_thread(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_user(&runner, 1, "x").await?; + + for _ in 0..10 { + tokio::try_join!( + update_user(Arc::clone(&runner), "a"), + update_user(Arc::clone(&runner), "b"), + update_user(Arc::clone(&runner), "c"), + update_user(Arc::clone(&runner), "d"), + update_user(Arc::clone(&runner), "e"), + update_user(Arc::clone(&runner), "f"), + update_user(Arc::clone(&runner), "g"), + update_user(Arc::clone(&runner), "h"), + update_user(Arc::clone(&runner), "i"), + update_user(Arc::clone(&runner), "j"), + )?; + } + + create_user(&runner, 2, "y").await?; + + Ok(()) + } + + #[connector_test] + async fn test_itx_concurrent_updates_multi_thread(runner: Runner) -> TestResult<()> { + let runner = Arc::new(runner); + + create_user(&runner, 1, "x").await?; + + for _ in 0..10 { + let mut set = JoinSet::new(); + + for email in ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] { + set.spawn(update_user(Arc::clone(&runner), email)); + } + + while let Some(handle) = set.join_next().await { + handle.expect("task panicked or canceled")?; + } + } + + create_user(&runner, 2, "y").await?; + + Ok(()) + } + + async fn create_user(runner: &Runner, id: u32, email: &str) -> TestResult<()> { + run_query!( + &runner, + format!( + r#"mutation {{ + createOneUser( + data: {{ + id: {id}, + first_name: "{email}", + last_name: "{email}", + email: "{email}" + }} + ) {{ id }} + }}"# + ) + ); + + Ok(()) + } + + async fn update_user(runner: Arc<Runner>, new_email: &str) -> TestResult<()> { + let tx_id = runner.start_tx(2000, 25, None).await?; + + let result = runner + .query_in_tx( + &tx_id, + format!( + r#"mutation {{ + updateOneUser( + where: {{ id: 1 }}, + data: {{ email: "{new_email}" }} + ) {{ id }} + }}"# + ), + ) + .await; + + if let Err(err) = result { + tracing::error!(%err, "query error"); + } + + let result = runner.commit_tx(tx_id).await?; + + if let Err(err) = result { + tracing::error!(?err, "commit error"); + } + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs index e026a90016b..7ed3cb9a859 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_15607.rs @@ -4,6 +4,7 @@ //! actors to allow test to continue even if one query is blocking.
use indoc::indoc; +use prisma_metrics::{MetricRecorder, WithMetricsInstrumentation}; use query_engine_tests::{ query_core::TxId, render_test_datamodel, setup_metrics, test_tracing_subscriber, LogEmit, QueryResult, Runner, TestError, TestLogCapture, TestResult, WithSubscriber, CONFIG, ENV_LOG_LEVEL, @@ -50,13 +51,12 @@ impl Actor { /// Spawns a new query engine to the runtime. pub async fn spawn() -> TestResult { let (log_capture, log_tx) = TestLogCapture::new(); - async fn with_logs(fut: impl Future, log_tx: LogEmit) -> T { - fut.with_subscriber(test_tracing_subscriber( - ENV_LOG_LEVEL.to_string(), - setup_metrics(), - log_tx, - )) - .await + let (metrics, recorder) = setup_metrics(); + + async fn with_observability(fut: impl Future, log_tx: LogEmit, recorder: MetricRecorder) -> T { + fut.with_subscriber(test_tracing_subscriber(ENV_LOG_LEVEL.to_string(), log_tx)) + .with_recorder(recorder) + .await } let (query_sender, mut query_receiver) = mpsc::channel(100); @@ -73,21 +73,24 @@ impl Actor { Some("READ COMMITTED"), ); - let mut runner = Runner::load(datamodel, &[], version, tag, None, setup_metrics(), log_capture).await?; + let mut runner = Runner::load(datamodel, &[], version, tag, None, metrics, log_capture).await?; tokio::spawn(async move { while let Some(message) = query_receiver.recv().await { match message { Message::Query(query) => { - let result = with_logs(runner.query(query), log_tx.clone()).await; + let result = with_observability(runner.query(query), log_tx.clone(), recorder.clone()).await; response_sender.send(Response::Query(result)).await.unwrap(); } Message::BeginTransaction => { - let response = with_logs(runner.start_tx(10000, 10000, None), log_tx.clone()).await; + let response = + with_observability(runner.start_tx(10000, 10000, None), log_tx.clone(), recorder.clone()) + .await; response_sender.send(Response::Tx(response)).await.unwrap(); } Message::RollbackTransaction(tx_id) => { - let response = with_logs(runner.rollback_tx(tx_id), log_tx.clone()).await?; + let response = + with_observability(runner.rollback_tx(tx_id), log_tx.clone(), recorder.clone()).await?; response_sender.send(Response::Rollback(response)).await.unwrap(); } Message::SetActiveTx(tx_id) => { diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs new file mode 100644 index 00000000000..14d402ee6af --- /dev/null +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/regressions/prisma_24072.rs @@ -0,0 +1,54 @@ +use indoc::indoc; +use query_engine_tests::*; + +// Skip databases that don't support `onDelete: SetDefault` +#[test_suite( + schema(schema), + exclude( + MongoDb, + MySql(5.6), + MySql(5.7), + Vitess("planetscale.js"), + Vitess("planetscale.js.wasm") + ) +)] +mod prisma_24072 { + fn schema() -> String { + let schema = indoc! { + r#"model Parent { + #id(id, Int, @id) + child Child? + } + + model Child { + #id(id, Int, @id) + parent_id Int? @default(2) @unique + parent Parent? @relation(fields: [parent_id], references: [id], onDelete: NoAction) + }"# + }; + + schema.to_owned() + } + + // Deleting the parent without cascading to the child should fail with an explicitly named constraint violation, + // without any "(not available)" names. 
+ #[connector_test] + async fn test_24072(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!(&runner, r#"mutation { createOneParent(data: { id: 1, child: { create: { id: 1 }}}) { id }}"#), + @r###"{"data":{"createOneParent":{"id":1}}}"### + ); + + assert_connector_error!( + &runner, + "mutation { deleteOneParent(where: { id: 1 }) { id }}", + 2003, + CockroachDb(_) | Postgres(_) | SqlServer(_) | Vitess(_) => "Foreign key constraint violated: `Child_parent_id_fkey (index)`", + MySql(_) => "Foreign key constraint violated: `parent_id`", + Sqlite(_) => "Foreign key constraint violated: `foreign key`", + _ => "Foreign key constraint violated" + ); + + Ok(()) + } +} diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs index 022f8f9e96a..cf0a769d354 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/batching/transactional_batch.rs @@ -145,9 +145,7 @@ mod transactional { Ok(()) } - // On PlanetScale, this fails with: - // "Error in connector: Error querying the database: Server error: `ERROR 25001 (1568): Transaction characteristics can't be changed while a transaction is in progress'"" - #[connector_test(exclude(MongoDb, Vitess("planetscale.js", "planetscale.js.wasm")))] + #[connector_test(exclude(MongoDb))] async fn valid_isolation_level(runner: Runner) -> TestResult<()> { let queries = vec![r#"mutation { createOneModelB(data: { id: 1 }) { id }}"#.to_string()]; diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs index 3e179e8a4e2..30e39078527 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/filters/many_relation.rs @@ -514,6 +514,186 @@ mod many_relation { Ok(()) } + fn schema_25103() -> String { + let schema = indoc! { + r#"model Contact { + #id(id, String, @id) + identities Identity[] + } + + model Identity { + #id(id, String, @id) + contactId String + contact Contact @relation(fields: [contactId], references: [id]) + subscriptions Subscription[] + } + + model Subscription { + #id(id, String, @id) + identityId String + audienceId String + optedOutAt DateTime? + audience Audience @relation(fields: [audienceId], references: [id]) + identity Identity @relation(fields: [identityId], references: [id]) + } + + model Audience { + #id(id, String, @id) + deletedAt DateTime? + subscriptions Subscription[] + }"# + }; + + schema.to_owned() + } + + // Regression test for https://github.com/prisma/prisma/issues/25103 + // SQL Server excluded because the m2m fragment does not support onUpdate/onDelete args which are needed. 
+ #[connector_test(schema(schema_25103), exclude(SqlServer))] + async fn prisma_25103(runner: Runner) -> TestResult<()> { + // Create some sample audiences + run_query!( + &runner, + r#"mutation { + createOneAudience(data: { + id: "audience1", + deletedAt: null + }) { + id + }}"# + ); + run_query!( + &runner, + r#"mutation { + createOneAudience(data: { + id: "audience2", + deletedAt: null + }) { + id + }}"# + ); + // Create a contact with identities and subscriptions + insta::assert_snapshot!( + run_query!( + &runner, + r#"mutation { + createOneContact(data: { + id: "contact1", + identities: { + create: [ + { + id: "identity1", + subscriptions: { + create: [ + { + id: "subscription1", + audienceId: "audience1", + optedOutAt: null + }, + { + id: "subscription2", + audienceId: "audience2", + optedOutAt: null + } + ] + } + } + ] + } + }) { + id, + identities (orderBy: { id: asc }) { + id, + subscriptions (orderBy: { id: asc }) { + id, + audienceId + } + } + }}"# + ), + @r###"{"data":{"createOneContact":{"id":"contact1","identities":[{"id":"identity1","subscriptions":[{"id":"subscription1","audienceId":"audience1"},{"id":"subscription2","audienceId":"audience2"}]}]}}}"### + ); + // Find contacts that include identities whose subscriptions have `optedOutAt = null` and include audiences with `deletedAt = null`` + insta::assert_snapshot!( + run_query!( + &runner, + r#"query { + findManyContact(orderBy: { id: asc }) { + id, + identities(orderBy: { id: asc }) { + id, + subscriptions(orderBy: { id: asc }, where: { optedOutAt: null, audience: { deletedAt: null } }) { + id, + identityId, + audience { + id, + deletedAt + } + } + } + } + }"# + ), + @r###"{"data":{"findManyContact":[{"id":"contact1","identities":[{"id":"identity1","subscriptions":[{"id":"subscription1","identityId":"identity1","audience":{"id":"audience1","deletedAt":null}},{"id":"subscription2","identityId":"identity1","audience":{"id":"audience2","deletedAt":null}}]}]}]}}"### + ); + + Ok(()) + } + + fn schema_25104() -> String { + let schema = indoc! { + r#" + model A { + #id(id, String, @id) + bs B[] + } + + model B { + #id(id, String, @id) + a A @relation(fields: [aId], references: [id]) + aId String + + cs C[] + } + + model C { + #id(id, String, @id) + name String + bs B[] + } + "# + }; + + schema.to_owned() + } + + #[connector_test(schema(schema_25104), exclude(MongoDb))] + async fn prisma_25104(runner: Runner) -> TestResult<()> { + insta::assert_snapshot!( + run_query!( + &runner, + r#" + query { + findManyA { + bs(where: { + cs: { + every: { + name: { equals: "a" } + } + } + }) { + id + } + } + } + "# + ), + @r###"{"data":{"findManyA":[]}}"### + ); + + Ok(()) + } + fn schema_23742() -> String { let schema = indoc! { r#"model Top { @@ -522,7 +702,7 @@ mod many_relation { middleId Int? middle Middle? @relation(fields: [middleId], references: [id]) - #m2m(bottoms, Bottom[], id, Int) + #m2m(bottoms, Bottom[], id, Int) } model Middle { @@ -579,6 +759,71 @@ mod many_relation { Ok(()) } + fn schema_nested_some_filter_m2m_different_pk() -> String { + let schema = indoc! { + r#" + model Top { + #id(topId, Int, @id) + + relatedMiddleId Int? + middle Middle? @relation(fields: [relatedMiddleId], references: [middleId]) + + #m2m(bottoms, Bottom[], bottomId, Int) + } + + model Middle { + #id(middleId, Int, @id) + + bottoms Bottom[] + tops Top[] + } + + model Bottom { + #id(bottomId, Int, @id) + + relatedMiddleId Int? + middle Middle? 
@relation(fields: [relatedMiddleId], references: [middleId]) + + #m2m(tops, Top[], topId, Int) + } + "# + }; + + schema.to_owned() + } + + #[connector_test(schema(schema_nested_some_filter_m2m_different_pk), exclude(SqlServer))] + async fn nested_some_filter_m2m_different_pk(runner: Runner) -> TestResult<()> { + run_query!( + &runner, + r#"mutation { + createOneTop(data: { + topId: 1, + middle: { create: { middleId: 1, bottoms: { create: { bottomId: 1, tops: { create: { topId: 2 } } } } } } + }) { + topId + }}"# + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ + findUniqueTop(where: { topId: 1 }) { + middle { + bottoms( + where: { tops: { some: { topId: 2 } } } + ) { + bottomId + } + } + } + } + "#), + @r###"{"data":{"findUniqueTop":{"middle":{"bottoms":[{"bottomId":1}]}}}}"### + ); + + Ok(()) + } + async fn test_data(runner: &Runner) -> TestResult<()> { runner .query(indoc! { r#" diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs index f1f80eb93f2..61996fd993d 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many.rs @@ -355,7 +355,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_1(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` (`opt`, `req`) VALUES (null, ?), (?, ?) params=[1,2,2] @@ -397,7 +397,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_2(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `opt_default_static`, `req_default_static`, `opt`, `req` ) VALUES (?, ?, null, ?), (?, ?, null, ?), (?, ?, null, ?) params=[1,1,1,2,1,2,1,3,3] @@ -436,7 +436,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_3(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `req_default_static`, `req`, `opt_default`, `opt_default_static` ) VALUES (?, ?, ?, ?) 
params=[1,6,3,1] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs index a55efb4e0cc..9b92d99404b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/writes/top_level_mutations/create_many_and_return.rs @@ -650,7 +650,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_1(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` (`opt`, `req`) VALUES (null, ?), (?, ?) params=[1,2,2] @@ -692,7 +692,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_2(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `opt_default_static`, `req_default_static`, `opt`, `req` ) VALUES (?, ?, null, ?), (?, ?, null, ?), (?, ?, null, ?) params=[1,1,1,2,1,2,1,3,3] @@ -731,7 +731,7 @@ mod create_many { // LibSQL & co are ignored because they don't support metrics #[connector_test(schema(schema_7), only(Sqlite("3")))] async fn create_many_by_shape_counter_3(runner: Runner) -> TestResult<()> { - use query_engine_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; + use prisma_metrics::PRISMA_DATASOURCE_QUERIES_TOTAL; // Generated queries: // INSERT INTO `main`.`Test` ( `req_default_static`, `req`, `opt_default`, `opt_default_static` ) VALUES (?, ?, ?, ?) 
params=[1,6,3,1] diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index cd8abc07331..b2016e602b9 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -15,14 +15,15 @@ sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine" } psl.workspace = true user-facing-errors = { path = "../../../libs/user-facing-errors" } +telemetry = { path = "../../../libs/telemetry" } thiserror = "1.0" async-trait.workspace = true nom = "7.1" itertools.workspace = true -regex = "1" +regex.workspace = true serde.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] } tracing-error = "0.2" colored = "2" @@ -30,7 +31,7 @@ indoc.workspace = true enumflags2.workspace = true hyper = { version = "0.14", features = ["full"] } indexmap.workspace = true -query-engine-metrics = { path = "../../metrics" } +prisma-metrics.path = "../../../libs/metrics" quaint.workspace = true jsonrpc-core = "17" insta = "1.7.1" diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs index e94c14c6c57..87e9241fb46 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/lib.rs @@ -22,10 +22,9 @@ pub use templating::*; use colored::Colorize; use once_cell::sync::Lazy; +use prisma_metrics::{MetricRecorder, MetricRegistry, WithMetricsInstrumentation}; use psl::datamodel_connector::ConnectorCapabilities; -use query_engine_metrics::MetricRegistry; use std::future::Future; -use std::sync::Once; use tokio::runtime::Builder; use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}; use tracing_futures::WithSubscriber; @@ -61,14 +60,10 @@ fn run_with_tokio>(fut: F) -> O { .block_on(fut) } -static METRIC_RECORDER: Once = Once::new(); - -pub fn setup_metrics() -> MetricRegistry { +pub fn setup_metrics() -> (MetricRegistry, MetricRecorder) { let metrics = MetricRegistry::new(); - METRIC_RECORDER.call_once(|| { - query_engine_metrics::setup(); - }); - metrics + let recorder = MetricRecorder::new(metrics.clone()).with_initialized_prisma_metrics(); + (metrics, recorder) } /// Taken from Reddit. 
Enables taking an async function pointer which takes references as param @@ -161,8 +156,7 @@ fn run_relation_link_test_impl( let datamodel = render_test_datamodel(&test_db_name, template, &[], None, Default::default(), Default::default(), None); let (connector_tag, version) = CONFIG.test_connector().unwrap(); - let metrics = setup_metrics(); - let metrics_for_subscriber = metrics.clone(); + let (metrics, recorder) = setup_metrics(); let (log_capture, log_tx) = TestLogCapture::new(); run_with_tokio( @@ -176,9 +170,8 @@ fn run_relation_link_test_impl( test_fn(&runner, &dm).with_subscriber(test_tracing_subscriber( ENV_LOG_LEVEL.to_string(), - metrics_for_subscriber, log_tx, - )) + )).with_recorder(recorder) .await.unwrap(); teardown_project(&datamodel, Default::default(), runner.schema_id()) @@ -275,8 +268,7 @@ fn run_connector_test_impl( None, ); let (connector_tag, version) = CONFIG.test_connector().unwrap(); - let metrics = crate::setup_metrics(); - let metrics_for_subscriber = metrics.clone(); + let (metrics, recorder) = crate::setup_metrics(); let (log_capture, log_tx) = TestLogCapture::new(); @@ -297,11 +289,8 @@ fn run_connector_test_impl( let schema_id = runner.schema_id(); if let Err(err) = test_fn(runner) - .with_subscriber(test_tracing_subscriber( - ENV_LOG_LEVEL.to_string(), - metrics_for_subscriber, - log_tx, - )) + .with_subscriber(test_tracing_subscriber(ENV_LOG_LEVEL.to_string(), log_tx)) + .with_recorder(recorder) .await { panic!("💥 Test failed due to an error: {err:?}"); diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs index 5520075e6d3..d95867c39b8 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/logging.rs @@ -1,12 +1,11 @@ -use query_core::telemetry::helpers as telemetry_helpers; -use query_engine_metrics::MetricRegistry; +use telemetry::helpers as telemetry_helpers; use tracing::Subscriber; use tracing_error::ErrorLayer; use tracing_subscriber::{prelude::*, Layer}; use crate::LogEmit; -pub fn test_tracing_subscriber(log_config: String, metrics: MetricRegistry, log_tx: LogEmit) -> impl Subscriber { +pub fn test_tracing_subscriber(log_config: String, log_tx: LogEmit) -> impl Subscriber { let filter = telemetry_helpers::env_filter(true, telemetry_helpers::QueryEngineLogLevel::Override(log_config)); let fmt_layer = tracing_subscriber::fmt::layer() @@ -15,7 +14,6 @@ pub fn test_tracing_subscriber(log_config: String, metrics: MetricRegistry, log_ tracing_subscriber::registry() .with(fmt_layer.boxed()) - .with(metrics.boxed()) .with(ErrorLayer::default()) } diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs index e5808ace7fc..de8ee9bd33b 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs +++ b/query-engine/connector-test-kit-rs/query-tests-setup/src/runner/mod.rs @@ -8,13 +8,13 @@ use crate::{ ENGINE_PROTOCOL, }; use colored::Colorize; +use prisma_metrics::MetricRegistry; use query_core::{ protocol::EngineProtocol, relation_load_strategy, schema::{self, QuerySchemaRef}, QueryExecutor, TransactionOptions, TxId, }; -use query_engine_metrics::MetricRegistry; use request_handlers::{ BatchTransactionOption, ConnectorKind, GraphqlBody, JsonBatchQuery, JsonBody, JsonSingleQuery, MultiQuery, RequestBody, RequestHandler, @@ -306,7 +306,15 @@ impl Runner { 
}) } - pub async fn query(&self, query: T) -> TestResult + pub async fn query(&self, query: impl Into) -> TestResult { + self.query_with_maybe_tx_id(self.current_tx_id.as_ref(), query).await + } + + pub async fn query_in_tx(&self, tx_id: &TxId, query: impl Into) -> TestResult { + self.query_with_maybe_tx_id(Some(tx_id), query).await + } + + async fn query_with_maybe_tx_id(&self, tx_id: Option<&TxId>, query: T) -> TestResult where T: Into, { @@ -316,7 +324,7 @@ impl Runner { RunnerExecutor::Builtin(e) => e, RunnerExecutor::External(external) => match JsonRequest::from_graphql(&query, self.query_schema()) { Ok(json_query) => { - let mut response = external.query(json_query, self.current_tx_id.as_ref()).await?; + let mut response = external.query(json_query, tx_id).await?; response.detag(); return Ok(response); } @@ -353,7 +361,7 @@ impl Runner { } }; - let response = handler.handle(request_body, self.current_tx_id.clone(), None).await; + let response = handler.handle(request_body, tx_id.cloned(), None).await; let result: QueryResult = match self.protocol { EngineProtocol::Json => JsonResponse::from_graphql(response).into(), diff --git a/query-engine/connectors/mongodb-query-connector/Cargo.toml b/query-engine/connectors/mongodb-query-connector/Cargo.toml index 05bf5968a59..0c4cefcce84 100644 --- a/query-engine/connectors/mongodb-query-connector/Cargo.toml +++ b/query-engine/connectors/mongodb-query-connector/Cargo.toml @@ -7,22 +7,22 @@ version = "0.1.0" anyhow = "1.0" async-trait.workspace = true bigdecimal = "0.3" -futures = "0.3" +futures.workspace = true itertools.workspace = true mongodb.workspace = true bson.workspace = true rand.workspace = true -regex = "1" +regex.workspace = true serde_json.workspace = true thiserror = "1.0" tokio.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true uuid.workspace = true indexmap.workspace = true -query-engine-metrics = { path = "../../metrics" } +prisma-metrics.path = "../../../libs/metrics" cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } -derive_more = "0.99.17" +derive_more.workspace = true [dependencies.query-structure] path = "../../query-structure" @@ -37,6 +37,9 @@ path = "../query-connector" [dependencies.prisma-value] path = "../../../libs/prisma-value" +[dependencies.telemetry] +path = "../../../libs/telemetry" + [dependencies.chrono] features = ["serde"] version = "0.4" diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs index 9f923251752..72b0c6a3afb 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/connection.rs @@ -11,6 +11,7 @@ use connector_interface::{ use mongodb::{ClientSession, Database}; use query_structure::{prelude::*, RelationLoadStrategy, SelectionResult}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub struct MongoDbConnection { /// The session to use for operations. @@ -57,7 +58,7 @@ impl WriteOperations for MongoDbConnection { args: WriteArgs, // The field selection on a create is never used on MongoDB as it cannot return more than the ID. 
_selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_record(&self.database, &mut self.session, model, args)).await } @@ -67,7 +68,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, args: Vec, skip_duplicates: bool, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_records( &self.database, @@ -85,7 +86,7 @@ impl WriteOperations for MongoDbConnection { _args: Vec, _skip_duplicates: bool, _selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!() } @@ -95,7 +96,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, record_filter: connector_interface::RecordFilter, args: WriteArgs, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(async move { let result = write::update_records( @@ -119,7 +120,7 @@ impl WriteOperations for MongoDbConnection { record_filter: connector_interface::RecordFilter, args: WriteArgs, selected_fields: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(async move { let result = write::update_records( @@ -150,7 +151,7 @@ impl WriteOperations for MongoDbConnection { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( &self.database, @@ -166,7 +167,7 @@ impl WriteOperations for MongoDbConnection { model: &Model, record_filter: connector_interface::RecordFilter, selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_record( &self.database, @@ -183,7 +184,7 @@ impl WriteOperations for MongoDbConnection { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_connect( &self.database, @@ -200,7 +201,7 @@ impl WriteOperations for MongoDbConnection { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_disconnect( &self.database, @@ -235,7 +236,7 @@ impl WriteOperations for MongoDbConnection { async fn native_upsert_record( &mut self, _upsert: connector_interface::NativeUpsert, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!("Native upsert is not currently supported.") } @@ -249,7 +250,7 @@ impl ReadOperations for MongoDbConnection { filter: &query_structure::Filter, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_single_record( &self.database, @@ -267,7 +268,7 @@ impl ReadOperations for MongoDbConnection { query_arguments: query_structure::QueryArguments, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(read::get_many_records( &self.database, @@ -283,7 +284,7 @@ impl ReadOperations for MongoDbConnection { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_related_m2m_record_ids( &self.database, @@ 
-301,7 +302,7 @@ impl ReadOperations for MongoDbConnection { selections: Vec, group_by: Vec, having: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(aggregate::aggregate( &self.database, diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs index 2fe5d840fa1..6045d06b442 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/transaction.rs @@ -1,16 +1,20 @@ +use std::collections::HashMap; + +use connector_interface::{ConnectionLike, ReadOperations, Transaction, UpdateType, WriteOperations}; +use mongodb::options::{Acknowledgment, ReadConcern, TransactionOptions, WriteConcern}; +use prisma_metrics::{guards::GaugeGuard, PRISMA_CLIENT_QUERIES_ACTIVE}; +use query_structure::{RelationLoadStrategy, SelectionResult}; +use telemetry::helpers::TraceParent; + use super::*; use crate::{ error::MongoError, root_queries::{aggregate, read, write}, }; -use connector_interface::{ConnectionLike, ReadOperations, Transaction, UpdateType, WriteOperations}; -use mongodb::options::{Acknowledgment, ReadConcern, TransactionOptions, WriteConcern}; -use query_engine_metrics::{decrement_gauge, increment_gauge, metrics, PRISMA_CLIENT_QUERIES_ACTIVE}; -use query_structure::{RelationLoadStrategy, SelectionResult}; -use std::collections::HashMap; pub struct MongoDbTransaction<'conn> { connection: &'conn mut MongoDbConnection, + gauge: GaugeGuard, } impl<'conn> ConnectionLike for MongoDbTransaction<'conn> {} @@ -31,16 +35,17 @@ impl<'conn> MongoDbTransaction<'conn> { .await .map_err(|err| MongoError::from(err).into_connector_error())?; - increment_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); - - Ok(Self { connection }) + Ok(Self { + connection, + gauge: GaugeGuard::increment(PRISMA_CLIENT_QUERIES_ACTIVE), + }) } } #[async_trait] impl<'conn> Transaction for MongoDbTransaction<'conn> { async fn commit(&mut self) -> connector_interface::Result<()> { - decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + self.gauge.decrement(); utils::commit_with_retry(&mut self.connection.session) .await @@ -50,7 +55,7 @@ impl<'conn> Transaction for MongoDbTransaction<'conn> { } async fn rollback(&mut self) -> connector_interface::Result<()> { - decrement_gauge!(PRISMA_CLIENT_QUERIES_ACTIVE, 1.0); + self.gauge.decrement(); self.connection .session @@ -78,7 +83,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { args: connector_interface::WriteArgs, // The field selection on a create is never used on MongoDB as it cannot return more than the ID. 
_selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_record( &self.connection.database, @@ -94,7 +99,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, args: Vec, skip_duplicates: bool, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::create_records( &self.connection.database, @@ -112,7 +117,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { _args: Vec, _skip_duplicates: bool, _selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!() } @@ -122,7 +127,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, record_filter: connector_interface::RecordFilter, args: connector_interface::WriteArgs, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(async move { let result = write::update_records( @@ -145,7 +150,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { record_filter: connector_interface::RecordFilter, args: connector_interface::WriteArgs, selected_fields: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(async move { let result = write::update_records( @@ -175,7 +180,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { &mut self, model: &Model, record_filter: connector_interface::RecordFilter, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_records( &self.connection.database, @@ -191,7 +196,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { model: &Model, record_filter: connector_interface::RecordFilter, selected_fields: FieldSelection, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(write::delete_record( &self.connection.database, @@ -206,7 +211,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { async fn native_upsert_record( &mut self, _upsert: connector_interface::NativeUpsert, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { unimplemented!("Native upsert is not currently supported.") } @@ -216,7 +221,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_connect( &self.connection.database, @@ -233,7 +238,7 @@ impl<'conn> WriteOperations for MongoDbTransaction<'conn> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result<()> { catch(write::m2m_disconnect( &self.connection.database, @@ -279,7 +284,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { filter: &query_structure::Filter, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_single_record( &self.connection.database, @@ -297,7 +302,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { query_arguments: query_structure::QueryArguments, selected_fields: &FieldSelection, _relation_load_strategy: RelationLoadStrategy, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result { catch(read::get_many_records( &self.connection.database, @@ -313,7 +318,7 @@ 
impl<'conn> ReadOperations for MongoDbTransaction<'conn> { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(read::get_related_m2m_record_ids( &self.connection.database, @@ -331,7 +336,7 @@ impl<'conn> ReadOperations for MongoDbTransaction<'conn> { selections: Vec, group_by: Vec, having: Option, - _trace_id: Option, + _traceparent: Option, ) -> connector_interface::Result> { catch(aggregate::aggregate( &self.connection.database, diff --git a/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs b/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs index 6b089ff700f..21e35604045 100644 --- a/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs +++ b/query-engine/connectors/mongodb-query-connector/src/interface/utils.rs @@ -1,7 +1,7 @@ use std::time::{Duration, Instant}; use mongodb::{ - error::{Result, TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT}, + error::{CommandError, ErrorKind, Result, TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT}, ClientSession, }; @@ -14,9 +14,15 @@ pub async fn commit_with_retry(session: &mut ClientSession) -> Result<()> { let timeout = Instant::now(); while let Err(err) = session.commit_transaction().await { - if (err.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) || err.contains_label(TRANSIENT_TRANSACTION_ERROR)) - && timeout.elapsed() < MAX_TX_TIMEOUT_COMMIT_RETRY_LIMIT - { + // For some reason, MongoDB adds `TRANSIENT_TRANSACTION_ERROR` to errors about aborted + // transactions. Since a transaction will not become less aborted in the future, we handle + // this case separately. + let is_aborted = matches!(err.kind.as_ref(), ErrorKind::Command(CommandError { code: 251, ..
})); + let is_in_unknown_state = err.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT); + let is_transient = err.contains_label(TRANSIENT_TRANSACTION_ERROR); + let is_retryable = !is_aborted && (is_in_unknown_state || is_transient); + + if is_retryable && timeout.elapsed() < MAX_TX_TIMEOUT_COMMIT_RETRY_LIMIT { tokio::time::sleep(TX_RETRY_BACKOFF).await; continue; } else { diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs index 05ff57053e9..797e34127f8 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/aggregate.rs @@ -108,7 +108,7 @@ fn to_aggregation_rows( for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("count_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("count_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Count(Some(field.clone()), field_val)); @@ -117,7 +117,7 @@ fn to_aggregation_rows( AggregationSelection::Average(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("avg_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("avg_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Average(field.clone(), field_val)); @@ -126,7 +126,7 @@ fn to_aggregation_rows( AggregationSelection::Sum(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("sum_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("sum_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Sum(field.clone(), field_val)); @@ -135,7 +135,7 @@ fn to_aggregation_rows( AggregationSelection::Min(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("min_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("min_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Min(field.clone(), field_val)); @@ -144,7 +144,7 @@ fn to_aggregation_rows( AggregationSelection::Max(fields) => { for field in fields { let meta = selection_meta.get(field.db_name()).unwrap(); - let bson = doc.remove(&format!("max_{}", field.db_name())).unwrap(); + let bson = doc.remove(format!("max_{}", field.db_name())).unwrap(); let field_val = value_from_bson(bson, meta)?; row.push(AggregationResult::Max(field.clone(), field_val)); diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs index 96b6d1fe73f..85099be0a69 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/mod.rs @@ -13,13 +13,16 @@ use crate::{ use bson::Bson; use bson::Document; use futures::Future; -use query_engine_metrics::{ - histogram, increment_counter, metrics, PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, - PRISMA_DATASOURCE_QUERIES_TOTAL, +use prisma_metrics::{ + counter, histogram, PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_DATASOURCE_QUERIES_TOTAL, }; use query_structure::*; +use std::sync::Arc; use 
std::time::Instant; -use tracing::debug; +use tracing::{debug, info_span}; +use tracing_futures::Instrument; + +const DB_SYSTEM_NAME: &str = "mongodb"; /// Transforms a document to a `Record`, fields ordered as defined in `fields`. fn document_to_record(mut doc: Document, fields: &[String], meta_mapping: &OutputMetaMapping) -> crate::Result { @@ -59,19 +62,34 @@ where F: FnOnce() -> U + 'a, U: Future>, { + // TODO: build the string lazily in the Display impl so it doesn't have to be built if neither + // logs nor traces are enabled. This is tricky because whatever we store in the span has to be + // 'static, and all `QueryString` implementations aren't, so this requires some refactoring. + let query_string: Arc = builder.build().into(); + + let span = info_span!( + "prisma:engine:db_query", + user_facing = true, + "db.system" = DB_SYSTEM_NAME, + "db.statement" = %Arc::clone(&query_string), + "db.operation.name" = builder.query_type(), + "otel.kind" = "client" + ); + + if let Some(coll) = builder.collection() { + span.record("db.collection.name", coll); + } + let start = Instant::now(); - let res = f().await; + let res = f().instrument(span).await; let elapsed = start.elapsed().as_millis() as f64; - histogram!(PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS, elapsed); - increment_counter!(PRISMA_DATASOURCE_QUERIES_TOTAL); + histogram!(PRISMA_DATASOURCE_QUERIES_DURATION_HISTOGRAM_MS).record(elapsed); + counter!(PRISMA_DATASOURCE_QUERIES_TOTAL).increment(1); - // TODO: emit tracing event only when "debug" level query logs are enabled. // TODO prisma/team-orm#136: fix log subscription. - let query_string = builder.build(); // NOTE: `params` is a part of the interface for query logs. - let params: Vec = vec![]; - debug!(target: "mongodb_query_connector::query", item_type = "query", is_query = true, query = %query_string, params = ?params, duration_ms = elapsed); + debug!(target: "mongodb_query_connector::query", item_type = "query", is_query = true, query = %query_string, params = %"[]", duration_ms = elapsed); res } diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs index b27fb527249..acec7c57ead 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/read.rs @@ -6,7 +6,6 @@ use crate::{ use mongodb::{bson::doc, options::FindOptions, ClientSession, Database}; use query_structure::*; use std::future::IntoFuture; -use tracing::{info_span, Instrument}; /// Finds a single record. Joins are not required at the moment because the selector is always a unique one. pub async fn get_single_record<'conn>( @@ -18,12 +17,6 @@ pub async fn get_single_record<'conn>( ) -> crate::Result> { let coll = database.collection(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findOne(*)", coll.name()) - ); - let meta_mapping = output_meta::from_selected_fields(selected_fields); let query_arguments: QueryArguments = (model.clone(), filter.clone()).into(); let query = MongoReadQueryBuilder::from_args(query_arguments)? @@ -31,7 +24,7 @@ pub async fn get_single_record<'conn>( .with_virtual_fields(selected_fields.virtuals())? 
.build()?; - let docs = query.execute(coll, session).instrument(span).await?; + let docs = query.execute(coll, session).await?; if docs.is_empty() { Ok(None) @@ -60,12 +53,6 @@ pub async fn get_many_records<'conn>( ) -> crate::Result { let coll = database.collection(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findMany(*)", coll.name()) - ); - let reverse_order = query_arguments.take.map(|t| t < 0).unwrap_or(false); let field_names: Vec<_> = selected_fields.db_names().collect(); @@ -81,7 +68,7 @@ pub async fn get_many_records<'conn>( .with_virtual_fields(selected_fields.virtuals())? .build()?; - let docs = query.execute(coll, session).instrument(span).await?; + let docs = query.execute(coll, session).await?; for doc in docs { let record = document_to_record(doc, &field_names, &meta_mapping)?; records.push(record) diff --git a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs index 76eed77e186..2564b56e371 100644 --- a/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs +++ b/query-engine/connectors/mongodb-query-connector/src/root_queries/write.rs @@ -18,7 +18,6 @@ use mongodb::{ use query_structure::{Model, PrismaValue, SelectionResult}; use std::future::IntoFuture; use std::{collections::HashMap, convert::TryInto}; -use tracing::{info_span, Instrument}; use update::IntoUpdateDocumentExtension; /// Create a single record to the database resulting in a @@ -31,12 +30,6 @@ pub async fn create_record<'conn>( ) -> crate::Result { let coll = database.collection::(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.insertOne(*)", coll.name()) - ); - let id_field = pick_singular_id(model); // Fields to write to the document. @@ -66,9 +59,7 @@ pub async fn create_record<'conn>( } let query_builder = InsertOne::new(&doc, coll.name()); - let insert_result = observing(&query_builder, || coll.insert_one(&doc).session(session).into_future()) - .instrument(span) - .await?; + let insert_result = observing(&query_builder, || coll.insert_one(&doc).session(session).into_future()).await?; let id_value = value_from_bson(insert_result.inserted_id, &id_meta)?; Ok(SingleRecord { @@ -86,12 +77,6 @@ pub async fn create_records<'conn>( ) -> crate::Result { let coll = database.collection::(model.db_name()); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.insertMany(*)", coll.name()) - ); - let num_records = args.len(); let fields: Vec<_> = model.fields().non_relational(); @@ -128,8 +113,7 @@ pub async fn create_records<'conn>( .with_options(options) .session(session) .into_future() - }) - .instrument(span); + }); match insert.await { Ok(insert_result) => Ok(insert_result.inserted_ids.len()), @@ -184,19 +168,13 @@ pub async fn update_records<'conn>( .collect::>>()? } else { let filter = MongoFilterVisitor::new(FilterPrefix::default(), false).visit(record_filter.filter)?; - find_ids(database, coll.clone(), session, model, filter).await? + find_ids(coll.clone(), session, model, filter).await? }; if ids.is_empty() { return Ok(vec![]); } - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.updateMany(*)", coll.name()) - ); - let filter = doc! 
{ id_field.db_name(): { "$in": ids.clone() } }; let fields: Vec<_> = model .fields() @@ -222,7 +200,6 @@ pub async fn update_records<'conn>( .session(session) .into_future() }) - .instrument(span) .await?; // It's important we check the `matched_count` and not the `modified_count` here. @@ -266,25 +243,18 @@ pub async fn delete_records<'conn>( .collect::>>()? } else { let filter = MongoFilterVisitor::new(FilterPrefix::default(), false).visit(record_filter.filter)?; - find_ids(database, coll.clone(), session, model, filter).await? + find_ids(coll.clone(), session, model, filter).await? }; if ids.is_empty() { return Ok(0); } - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.deleteMany(*)", coll.name()) - ); - let filter = doc! { id_field.db_name(): { "$in": ids } }; let query_string_builder = DeleteMany::new(&filter, coll.name()); let delete_result = observing(&query_string_builder, || { coll.delete_many(filter.clone()).session(session).into_future() }) - .instrument(span) .await?; Ok(delete_result.deleted_count as usize) @@ -312,16 +282,10 @@ pub async fn delete_record<'conn>( "$expr": filter, }; - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findAndModify(*)", coll.name()) - ); let query_string_builder = DeleteOne::new(&filter, coll.name()); let document = observing(&query_string_builder, || { coll.find_one_and_delete(filter.clone()).session(session).into_future() }) - .instrument(span) .await? .ok_or(MongoError::RecordDoesNotExist { cause: "Record to delete does not exist.".to_owned(), @@ -335,20 +299,11 @@ pub async fn delete_record<'conn>( /// Retrives document ids based on the given filter. async fn find_ids( - database: &Database, collection: Collection, session: &mut ClientSession, model: &Model, filter: MongoFilter, ) -> crate::Result> { - let coll = database.collection::(model.db_name()); - - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &format_args!("db.{}.findMany(*)", coll.name()) - ); - let id_field = model.primary_identifier(); let mut builder = MongoReadQueryBuilder::new(model.clone()); @@ -363,7 +318,7 @@ async fn find_ids( let builder = builder.with_model_projection(id_field)?; let query = builder.build()?; - let docs = query.execute(collection, session).instrument(span).await?; + let docs = query.execute(collection, session).await?; let ids = docs.into_iter().map(|mut doc| doc.remove("_id").unwrap()).collect(); Ok(ids) @@ -533,13 +488,6 @@ pub async fn query_raw<'conn>( inputs: HashMap, query_type: Option, ) -> crate::Result { - let db_statement = get_raw_db_statement(&query_type, &model, database); - let span = info_span!( - "prisma:engine:db_query", - user_facing = true, - "db.statement" = &&db_statement.as_str() - ); - let mongo_command = MongoCommand::from_raw_query(model, inputs, query_type)?; async { @@ -601,17 +549,5 @@ pub async fn query_raw<'conn>( Ok(RawJson::try_new(json_result)?) 
} - .instrument(span) .await } - -fn get_raw_db_statement(query_type: &Option, model: &Option<&Model>, database: &Database) -> String { - match (query_type.as_deref(), model) { - (Some("findRaw"), Some(m)) => format!("db.{}.findRaw(*)", database.collection::(m.db_name()).name()), - (Some("aggregateRaw"), Some(m)) => format!( - "db.{}.aggregateRaw(*)", - database.collection::(m.db_name()).name() - ), - _ => "db.runCommandRaw(*)".to_string(), - } -} diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index 52555d256ba..125be089549 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -7,7 +7,7 @@ version = "0.1.0" anyhow = "1.0" async-trait.workspace = true chrono.workspace = true -futures = "0.3" +futures.workspace = true itertools.workspace = true query-structure = {path = "../../query-structure"} prisma-value = {path = "../../../libs/prisma-value"} @@ -17,3 +17,4 @@ thiserror = "1.0" user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid.workspace = true indexmap.workspace = true +telemetry = {path = "../../../libs/telemetry"} diff --git a/query-engine/connectors/query-connector/src/interface.rs b/query-engine/connectors/query-connector/src/interface.rs index cbdafcaeeee..05e8f1e1098 100644 --- a/query-engine/connectors/query-connector/src/interface.rs +++ b/query-engine/connectors/query-connector/src/interface.rs @@ -3,13 +3,17 @@ use async_trait::async_trait; use prisma_value::PrismaValue; use query_structure::{ast::FieldArity, *}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; #[async_trait] pub trait Connector { /// Returns a connection to a data source. async fn get_connection(&self) -> crate::Result>; - /// Returns the name of the connector. + /// Returns the database system name, as per the OTEL spec. + /// Reference: + /// - https://opentelemetry.io/docs/specs/semconv/database/sql/ + /// - https://opentelemetry.io/docs/specs/semconv/database/mongodb/ fn name(&self) -> &'static str; /// Returns whether a connector should retry an entire transaction when that transaction failed during its execution @@ -194,7 +198,7 @@ pub trait ReadOperations { filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Gets multiple records from the database. @@ -209,7 +213,7 @@ pub trait ReadOperations { query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Retrieves pairs of IDs that belong together from a intermediate join @@ -223,7 +227,7 @@ pub trait ReadOperations { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Aggregates records for a specific model based on the given selections. @@ -238,7 +242,7 @@ pub trait ReadOperations { selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; } @@ -250,7 +254,7 @@ pub trait WriteOperations { model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Inserts many records at once into the database. 
@@ -259,7 +263,7 @@ pub trait WriteOperations { model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Inserts many records at once into the database and returns their @@ -272,7 +276,7 @@ pub trait WriteOperations { args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Update records in the `Model` with the given `WriteArgs` filtered by the @@ -282,7 +286,7 @@ pub trait WriteOperations { model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Update record in the `Model` with the given `WriteArgs` filtered by the @@ -293,7 +297,7 @@ pub trait WriteOperations { record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> crate::Result>; /// Native upsert @@ -301,7 +305,7 @@ pub trait WriteOperations { async fn native_upsert_record( &mut self, upsert: NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Delete records in the `Model` with the given `Filter`. @@ -309,7 +313,7 @@ pub trait WriteOperations { &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> crate::Result; /// Delete single record in the `Model` with the given `Filter` and returns @@ -321,7 +325,7 @@ pub trait WriteOperations { model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> crate::Result; // We plan to remove the methods below in the future. We want emulate them with the ones above. Those should suffice. @@ -332,7 +336,7 @@ pub trait WriteOperations { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result<()>; /// Disconnect the children from the parent (m2m relation only). @@ -341,7 +345,7 @@ pub trait WriteOperations { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> crate::Result<()>; /// Execute the raw query in the database as-is. 
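Every signature in the two traits above now takes `Option<TraceParent>` in place of `Option<String>`. As a reading aid only, here is a minimal sketch of what such a wrapper could look like, inferred from how it is used elsewhere in this diff (passed by value without cloning, a `sampled()` accessor, and a `Display` impl rendering the W3C traceparent value); the real `telemetry::helpers::TraceParent` may differ:

// --- illustrative sketch, not part of the diff ---
use std::fmt;

#[derive(Clone, Copy, Debug)]
pub struct TraceParent {
    trace_id: u128, // 16 bytes -> 32 hex chars
    span_id: u64,   // 8 bytes  -> 16 hex chars
    flags: u8,      // bit 0 = sampled
}

impl TraceParent {
    pub fn sampled(&self) -> bool {
        self.flags & 0x01 == 0x01
    }
}

impl fmt::Display for TraceParent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // See https://www.w3.org/TR/trace-context/#traceparent-header-field-values
        write!(f, "00-{:032x}-{:016x}-{:02x}", self.trace_id, self.span_id, self.flags)
    }
}
// --- end sketch ---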
diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 1826cac2681..2e3e0fe2fe5 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -38,7 +38,7 @@ psl.workspace = true anyhow = "1.0" async-trait.workspace = true bigdecimal = "0.3" -futures = "0.3" +futures.workspace = true itertools.workspace = true once_cell = "1.3" rand.workspace = true @@ -46,7 +46,7 @@ serde_json.workspace = true thiserror = "1.0" tokio = { version = "1", features = ["macros", "time"] } tracing = { workspace = true, features = ["log"] } -tracing-futures = "0.2" +tracing-futures.workspace = true uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" @@ -66,6 +66,9 @@ path = "../../query-structure" [dependencies.prisma-value] path = "../../../libs/prisma-value" +[dependencies.telemetry] +path = "../../../libs/telemetry" + [dependencies.chrono] features = ["serde"] version = "0.4" diff --git a/query-engine/connectors/sql-query-connector/src/context.rs b/query-engine/connectors/sql-query-connector/src/context.rs index 5b2887451ef..2519439b13f 100644 --- a/query-engine/connectors/sql-query-connector/src/context.rs +++ b/query-engine/connectors/sql-query-connector/src/context.rs @@ -1,8 +1,9 @@ use quaint::prelude::ConnectionInfo; +use telemetry::helpers::TraceParent; pub(super) struct Context<'a> { connection_info: &'a ConnectionInfo, - pub(crate) trace_id: Option<&'a str>, + pub(crate) traceparent: Option, /// Maximum rows allowed at once for an insert query. /// None is unlimited. pub(crate) max_insert_rows: Option, @@ -12,13 +13,13 @@ pub(super) struct Context<'a> { } impl<'a> Context<'a> { - pub(crate) fn new(connection_info: &'a ConnectionInfo, trace_id: Option<&'a str>) -> Self { + pub(crate) fn new(connection_info: &'a ConnectionInfo, traceparent: Option) -> Self { let max_insert_rows = connection_info.max_insert_rows(); let max_bind_values = connection_info.max_bind_values(); Context { connection_info, - trace_id, + traceparent, max_insert_rows, max_bind_values: Some(max_bind_values), } diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index f928fcacdfa..1222f0425ea 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -15,6 +15,7 @@ use quaint::{ }; use query_structure::{prelude::*, Filter, QueryArguments, RelationLoadStrategy, SelectionResult}; use std::{collections::HashMap, str::FromStr}; +use telemetry::helpers::TraceParent; pub(crate) struct SqlConnection { inner: C, @@ -89,10 +90,10 @@ where filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { // [Composites] todo: FieldSelection -> ModelProjection conversion - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_single_record( @@ -113,9 +114,9 @@ where query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx 
= Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_many_records( @@ -134,9 +135,9 @@ where &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_related_m2m_record_ids(&self.inner, from_field, from_record_ids, &ctx), @@ -151,9 +152,9 @@ where selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::aggregate(&self.inner, model, query_arguments, selections, group_by, having, &ctx), @@ -172,9 +173,9 @@ where model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_record( @@ -194,9 +195,9 @@ where model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_count(&self.inner, model, args, skip_duplicates, &ctx), @@ -210,9 +211,9 @@ where args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_returning(&self.inner, model, args, skip_duplicates, selected_fields, &ctx), @@ -225,9 +226,9 @@ where model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_records(&self.inner, model, record_filter, args, &ctx), @@ -241,9 +242,9 @@ where record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_record(&self.inner, model, record_filter, args, selected_fields, &ctx), @@ -255,9 +256,9 @@ where &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::delete_records(&self.inner, model, record_filter, &ctx), @@ -270,9 +271,9 @@ where model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( 
&self.connection_info, write::delete_record(&self.inner, model, record_filter, selected_fields, &ctx), @@ -283,9 +284,9 @@ where async fn native_upsert_record( &mut self, upsert: connector_interface::NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch(&self.connection_info, upsert::native_upsert(&self.inner, upsert, &ctx)).await } @@ -294,9 +295,9 @@ where field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::m2m_connect(&self.inner, field, parent_id, child_ids, &ctx), @@ -309,9 +310,9 @@ where field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::m2m_disconnect(&self.inner, field, parent_id, child_ids, &ctx), diff --git a/query-engine/connectors/sql-query-connector/src/database/js.rs b/query-engine/connectors/sql-query-connector/src/database/js.rs index d771eb51e40..40ca0caa002 100644 --- a/query-engine/connectors/sql-query-connector/src/database/js.rs +++ b/query-engine/connectors/sql-query-connector/src/database/js.rs @@ -48,7 +48,16 @@ impl Connector for Js { } fn name(&self) -> &'static str { - "js" + match self.connection_info.sql_family() { + #[cfg(feature = "postgresql")] + SqlFamily::Postgres => "postgresql", + #[cfg(feature = "mysql")] + SqlFamily::Mysql => "mysql", + #[cfg(feature = "sqlite")] + SqlFamily::Sqlite => "sqlite", + #[cfg(feature = "mssql")] + SqlFamily::Mssql => "mssql", + } } fn should_retry_on_transient_error(&self) -> bool { diff --git a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 701fa9a1a0c..9e59e4f232c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -13,6 +13,7 @@ pub struct PostgreSql { pool: Quaint, connection_info: ConnectionInfo, features: psl::PreviewFeatures, + flavour: PostgresFlavour, } impl PostgreSql { @@ -60,6 +61,7 @@ impl FromSource for PostgreSql { pool, connection_info, features, + flavour, }) } } @@ -76,7 +78,10 @@ impl Connector for PostgreSql { } fn name(&self) -> &'static str { - "postgres" + match self.flavour { + PostgresFlavour::Postgres | PostgresFlavour::Unknown => "postgresql", + PostgresFlavour::Cockroach => "cockroachdb", + } } fn should_retry_on_transient_error(&self) -> bool { diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 33cb0cfa3ac..137bff50ca5 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -72,7 +72,7 @@ async fn generate_id( // db generate values only if needed if need_select { - let pk_select = 
id_select.add_trace_id(ctx.trace_id); + let pk_select = id_select.add_traceparent(ctx.traceparent); let pk_result = conn.query(pk_select.into()).await?; let result = try_convert(&(id_field.into()), pk_result)?; diff --git a/query-engine/connectors/sql-query-connector/src/database/transaction.rs b/query-engine/connectors/sql-query-connector/src/database/transaction.rs index 263c541f6b4..387b18f63ee 100644 --- a/query-engine/connectors/sql-query-connector/src/database/transaction.rs +++ b/query-engine/connectors/sql-query-connector/src/database/transaction.rs @@ -10,6 +10,7 @@ use prisma_value::PrismaValue; use quaint::prelude::ConnectionInfo; use query_structure::{prelude::*, Filter, QueryArguments, RelationLoadStrategy, SelectionResult}; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub struct SqlConnectorTransaction<'tx> { inner: Box, @@ -73,9 +74,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { filter: &Filter, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_single_record( @@ -96,9 +97,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { query_arguments: QueryArguments, selected_fields: &FieldSelection, relation_load_strategy: RelationLoadStrategy, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::get_many_records( @@ -117,9 +118,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { &mut self, from_field: &RelationFieldRef, from_record_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch(&self.connection_info, async { read::get_related_m2m_record_ids(self.inner.as_queryable(), from_field, from_record_ids, &ctx).await }) @@ -133,9 +134,9 @@ impl<'tx> ReadOperations for SqlConnectorTransaction<'tx> { selections: Vec, group_by: Vec, having: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, read::aggregate( @@ -159,9 +160,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, args: WriteArgs, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_record( @@ -181,9 +182,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, args: Vec, skip_duplicates: bool, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_count(self.inner.as_queryable(), model, args, skip_duplicates, &ctx), @@ -197,9 +198,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { 
args: Vec, skip_duplicates: bool, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::create_records_returning( @@ -219,9 +220,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, record_filter: RecordFilter, args: WriteArgs, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_records(self.inner.as_queryable(), model, record_filter, args, &ctx), @@ -235,9 +236,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { record_filter: RecordFilter, args: WriteArgs, selected_fields: Option, - trace_id: Option, + traceparent: Option, ) -> connector::Result> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::update_record( @@ -256,10 +257,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { &mut self, model: &Model, record_filter: RecordFilter, - trace_id: Option, + traceparent: Option, ) -> connector::Result { catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); write::delete_records(self.inner.as_queryable(), model, record_filter, &ctx).await }) .await @@ -270,9 +271,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { model: &Model, record_filter: RecordFilter, selected_fields: FieldSelection, - trace_id: Option, + traceparent: Option, ) -> connector::Result { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, write::delete_record(self.inner.as_queryable(), model, record_filter, selected_fields, &ctx), @@ -283,10 +284,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { async fn native_upsert_record( &mut self, upsert: connector_interface::NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> connector::Result { catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); upsert::native_upsert(self.inner.as_queryable(), upsert, &ctx).await }) .await @@ -297,10 +298,10 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { catch(&self.connection_info, async { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); write::m2m_connect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx).await }) .await @@ -311,9 +312,9 @@ impl<'tx> WriteOperations for SqlConnectorTransaction<'tx> { field: &RelationFieldRef, parent_id: &SelectionResult, child_ids: &[SelectionResult], - trace_id: Option, + traceparent: Option, ) -> connector::Result<()> { - let ctx = Context::new(&self.connection_info, trace_id.as_deref()); + let ctx = Context::new(&self.connection_info, traceparent); catch( &self.connection_info, 
write::m2m_disconnect(self.inner.as_queryable(), field, parent_id, child_ids, &ctx), diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs index 6a0572ecc0d..84323e0f52b 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs @@ -100,7 +100,7 @@ impl SelectDefinition for QueryArguments { .so_that(conditions) .offset(skip as usize) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id); + .add_traceparent(ctx.traceparent); let select_ast = order_by_definitions .iter() @@ -137,7 +137,7 @@ where let (select, additional_selection_set) = query.into_select(model, virtual_selections, ctx); let select = columns.fold(select, |acc, col| acc.column(col)); - let select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let select = select.append_trace(&Span::current()).add_traceparent(ctx.traceparent); additional_selection_set .into_iter() @@ -183,7 +183,7 @@ pub(crate) fn aggregate( selections.iter().fold( Select::from_table(sub_table) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id), + .add_traceparent(ctx.traceparent), |select, next_op| match next_op { AggregationSelection::Field(field) => select.column( Column::from(field.db_name().to_owned()) @@ -269,7 +269,9 @@ pub(crate) fn group_by_aggregate( }); let grouped = group_by.into_iter().fold( - select_query.append_trace(&Span::current()).add_trace_id(ctx.trace_id), + select_query + .append_trace(&Span::current()) + .add_traceparent(ctx.traceparent), |query, field| query.group_by(field.as_column(ctx)), ); diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs index c7c49ed3bc3..9c0139c6cd8 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/select/mod.rs @@ -648,20 +648,13 @@ fn extract_filter_scalars(f: &Filter) -> Vec { Filter::Scalar(x) => x.scalar_fields().into_iter().map(ToOwned::to_owned).collect(), Filter::ScalarList(x) => vec![x.field.clone()], Filter::OneRelationIsNull(x) => join_fields(&x.field), - Filter::Relation(x) => vec![join_fields(&x.field), extract_filter_scalars(&x.nested_filter)] - .into_iter() - .flatten() - .collect(), + Filter::Relation(x) => join_fields(&x.field), _ => Vec::new(), } } fn join_fields(rf: &RelationField) -> Vec { - if rf.is_inlined_on_enclosing_model() { - rf.scalar_fields() - } else { - rf.related_field().referenced_fields() - } + rf.linking_fields().as_scalar_fields().unwrap_or_default() } fn join_alias_name(rf: &RelationField) -> String { diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs index c089f0834dc..c07e3600e14 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/write.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/write.rs @@ -34,7 +34,7 @@ pub(crate) fn create_record( Insert::from(insert) .returning(selected_fields.as_columns(ctx).map(|c| c.set_is_selected(true))) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) } /// `INSERT` new records into the database based on the given write arguments, @@ -84,7 +84,7 @@ pub(crate) fn create_records_nonempty( let 
insert = Insert::multi_into(model.as_table(ctx), columns); let insert = values.into_iter().fold(insert, |stmt, values| stmt.values(values)); let insert: Insert = insert.into(); - let mut insert = insert.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let mut insert = insert.append_trace(&Span::current()).add_traceparent(ctx.traceparent); if let Some(selected_fields) = selected_fields { insert = insert.returning(projection_into_columns(selected_fields, ctx)); @@ -105,7 +105,7 @@ pub(crate) fn create_records_empty( ctx: &Context<'_>, ) -> Insert<'static> { let insert: Insert<'static> = Insert::single_into(model.as_table(ctx)).into(); - let mut insert = insert.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let mut insert = insert.append_trace(&Span::current()).add_traceparent(ctx.traceparent); if let Some(selected_fields) = selected_fields { insert = insert.returning(projection_into_columns(selected_fields, ctx)); @@ -175,7 +175,7 @@ pub(crate) fn build_update_and_set_query( acc.set(name, value) }); - let query = query.append_trace(&Span::current()).add_trace_id(ctx.trace_id); + let query = query.append_trace(&Span::current()).add_traceparent(ctx.traceparent); let query = if let Some(selected_fields) = selected_fields { query.returning(selected_fields.as_columns(ctx).map(|c| c.set_is_selected(true))) @@ -222,7 +222,7 @@ pub(crate) fn delete_returning( .so_that(filter) .returning(projection_into_columns(selected_fields, ctx)) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .into() } @@ -234,7 +234,7 @@ pub(crate) fn delete_many_from_filter( Delete::from_table(model.as_table(ctx)) .so_that(filter_condition) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .into() } @@ -301,5 +301,5 @@ pub(crate) fn delete_relation_table_records( Delete::from_table(relation.as_table(ctx)) .so_that(parent_id_criteria.and(child_id_criteria)) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) } diff --git a/query-engine/connectors/sql-query-connector/src/query_ext.rs b/query-engine/connectors/sql-query-connector/src/query_ext.rs index c0d511f9e6d..a843f4a1525 100644 --- a/query-engine/connectors/sql-query-connector/src/query_ext.rs +++ b/query-engine/connectors/sql-query-connector/src/query_ext.rs @@ -25,18 +25,18 @@ impl QueryExt for Q { idents: &[ColumnMetadata<'_>], ctx: &Context<'_>, ) -> crate::Result> { - let span = info_span!("filter read query"); + let span = info_span!("prisma:engine:filter_read_query"); let otel_ctx = span.context(); let span_ref = otel_ctx.span(); let span_ctx = span_ref.span_context(); - let q = match (q, ctx.trace_id) { + let q = match (q, ctx.traceparent) { (Query::Select(x), _) if span_ctx.trace_flags() == TraceFlags::SAMPLED => { Query::Select(Box::from(x.comment(trace_parent_to_string(span_ctx)))) } // This is part of the required changes to pass a traceid - (Query::Select(x), trace_id) => Query::Select(Box::from(x.add_trace_id(trace_id))), + (Query::Select(x), traceparent) => Query::Select(Box::from(x.add_traceparent(traceparent))), (q, _) => q, }; @@ -119,7 +119,7 @@ impl QueryExt for Q { let select = Select::from_table(model.as_table(ctx)) .columns(id_cols) .append_trace(&Span::current()) - .add_trace_id(ctx.trace_id) + .add_traceparent(ctx.traceparent) .so_that(condition); self.select_ids(select, model_id, ctx).await diff --git a/query-engine/connectors/sql-query-connector/src/sql_trace.rs 
b/query-engine/connectors/sql-query-connector/src/sql_trace.rs index bffaf117431..4fa88a64d2e 100644 --- a/query-engine/connectors/sql-query-connector/src/sql_trace.rs +++ b/query-engine/connectors/sql-query-connector/src/sql_trace.rs @@ -1,5 +1,6 @@ use opentelemetry::trace::{SpanContext, TraceContextExt, TraceFlags}; use quaint::ast::{Delete, Insert, Select, Update}; +use telemetry::helpers::TraceParent; use tracing::Span; use tracing_opentelemetry::OpenTelemetrySpanExt; @@ -8,12 +9,12 @@ pub fn trace_parent_to_string(context: &SpanContext) -> String { let span_id = context.span_id(); // see https://www.w3.org/TR/trace-context/#traceparent-header-field-values - format!("traceparent='00-{trace_id:032x}-{span_id:032x}-01'") + format!("traceparent='00-{trace_id:032x}-{span_id:016x}-01'") } pub trait SqlTraceComment: Sized { fn append_trace(self, span: &Span) -> Self; - fn add_trace_id(self, trace_id: Option<&str>) -> Self; + fn add_traceparent(self, traceparent: Option) -> Self; } macro_rules! sql_trace { @@ -30,14 +31,15 @@ macro_rules! sql_trace { self } } + // Temporary method to pass the traceid in an operation - fn add_trace_id(self, trace_id: Option<&str>) -> Self { - if let Some(traceparent) = trace_id { - if should_sample(&traceparent) { - self.comment(format!("traceparent='{}'", traceparent)) - } else { - self - } + fn add_traceparent(self, traceparent: Option) -> Self { + let Some(traceparent) = traceparent else { + return self; + }; + + if traceparent.sampled() { + self.comment(format!("traceparent='{traceparent}'")) } else { self } @@ -46,10 +48,6 @@ macro_rules! sql_trace { }; } -fn should_sample(traceparent: &str) -> bool { - traceparent.split('-').count() == 4 && traceparent.ends_with("-01") -} - sql_trace!(Insert<'_>); sql_trace!(Update<'_>); diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index a1aa9a326f7..6005b091f55 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -4,7 +4,7 @@ name = "query-core" version = "0.1.0" [features] -metrics = ["query-engine-metrics"] +metrics = ["prisma-metrics"] graphql-protocol = [] [dependencies] @@ -15,7 +15,7 @@ connection-string.workspace = true connector = { path = "../connectors/query-connector", package = "query-connector" } crossbeam-channel = "0.5.6" psl.workspace = true -futures = "0.3" +futures.workspace = true indexmap.workspace = true itertools.workspace = true once_cell = "1" @@ -24,13 +24,13 @@ query-structure = { path = "../query-structure", features = [ "default_generators", ] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = { path = "../metrics", optional = true } +prisma-metrics = { path = "../../libs/metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" tokio = { version = "1", features = ["macros", "time"] } tracing = { workspace = true, features = ["attributes"] } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-opentelemetry = "0.17.4" user-facing-errors = { path = "../../libs/user-facing-errors" } @@ -38,5 +38,7 @@ uuid.workspace = true cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } schema = { path = "../schema" } crosstarget-utils = { path = "../../libs/crosstarget-utils" } +telemetry = { path = "../../libs/telemetry" } lru = "0.7.7" enumflags2.workspace = true +derive_more.workspace = true diff --git a/query-engine/core/src/error.rs 
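The `{span_id:032x}` to `{span_id:016x}` correction in `trace_parent_to_string` above is worth spelling out: a W3C span id is 8 bytes, i.e. 16 hex characters, so the old width zero-padded it to twice the allowed length. A small self-check (the span id below is an arbitrary example value):

// --- illustrative sketch, not part of the diff ---
let span_id: u64 = 0xd755_97de_e50b_0cac;
// old format string: 32 hex chars, violating https://www.w3.org/TR/trace-context/
assert_eq!(format!("{span_id:032x}"), "0000000000000000d75597dee50b0cac");
// fixed format string: 16 hex chars, as the spec requires
assert_eq!(format!("{span_id:016x}"), "d75597dee50b0cac");
// --- end sketch ---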
b/query-engine/core/src/error.rs index 3a3803bf0d6..b067a325a4a 100644 --- a/query-engine/core/src/error.rs +++ b/query-engine/core/src/error.rs @@ -4,6 +4,8 @@ use query_structure::DomainError; use thiserror::Error; use user_facing_errors::UnknownError; +use crate::response_ir::{Item, Map}; + #[derive(Debug, Error)] #[error( "Error converting field \"{field}\" of expected non-nullable type \"{expected_type}\", found incompatible value of \"{found}\"." @@ -62,6 +64,9 @@ pub enum CoreError { #[error("Error in batch request {request_idx}: {error}")] BatchError { request_idx: usize, error: Box }, + + #[error("Query timed out")] + QueryTimeout, } impl CoreError { @@ -227,3 +232,27 @@ impl From for user_facing_errors::Error { } } } + +#[derive(Debug, serde::Serialize, PartialEq)] +pub struct ExtendedUserFacingError { + #[serde(flatten)] + user_facing_error: user_facing_errors::Error, + + #[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")] + extensions: Map, +} + +impl ExtendedUserFacingError { + pub fn set_extension(&mut self, key: String, val: serde_json::Value) { + self.extensions.entry(key).or_insert(Item::Json(val)); + } +} + +impl From for ExtendedUserFacingError { + fn from(error: CoreError) -> Self { + ExtendedUserFacingError { + user_facing_error: error.into(), + extensions: Default::default(), + } + } +} diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 6ef445d8364..986741182b9 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -10,13 +10,17 @@ use connector::{Connection, ConnectionLike, Connector}; use crosstarget_utils::time::ElapsedTimeCounter; use futures::future; +#[cfg(not(feature = "metrics"))] +use crate::metrics::MetricsInstrumentationStub; #[cfg(feature = "metrics")] -use query_engine_metrics::{ - histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, +use prisma_metrics::{ + counter, histogram, WithMetricsInstrumentation, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, + PRISMA_CLIENT_QUERIES_TOTAL, }; use schema::{QuerySchema, QuerySchemaRef}; use std::time::Duration; +use telemetry::helpers::TraceParent; use tracing::Instrument; use tracing_futures::WithSubscriber; @@ -24,18 +28,15 @@ pub async fn execute_single_operation( query_schema: QuerySchemaRef, conn: &mut dyn ConnectionLike, operation: &Operation, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let operation_timer = ElapsedTimeCounter::start(); let (graph, serializer) = build_graph(&query_schema, operation.clone())?; - let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + let result = execute_on(conn, graph, serializer, query_schema.as_ref(), traceparent).await; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); result } @@ -44,7 +45,7 @@ pub async fn execute_many_operations( query_schema: QuerySchemaRef, conn: &mut dyn ConnectionLike, operations: &[Operation], - trace_id: Option, + traceparent: Option, ) -> crate::Result>> { let queries = operations .iter() @@ -55,13 +56,10 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = ElapsedTimeCounter::start(); - let result = execute_on(conn, graph, serializer, 
query_schema.as_ref(), trace_id.clone()).await; + let result = execute_on(conn, graph, serializer, query_schema.as_ref(), traceparent).await; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); match result { Ok(result) => results.push(Ok(result)), @@ -81,13 +79,13 @@ pub async fn execute_single_self_contained( connector: &C, query_schema: QuerySchemaRef, operation: Operation, - trace_id: Option, + traceparent: Option, force_transactions: bool, ) -> crate::Result { let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name() + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -97,7 +95,7 @@ pub async fn execute_single_self_contained( operation, force_transactions, connector.should_retry_on_transient_error(), - trace_id, + traceparent, ) .await } @@ -106,21 +104,20 @@ pub async fn execute_many_self_contained( connector: &C, query_schema: QuerySchemaRef, operations: &[Operation], - trace_id: Option, + traceparent: Option, force_transactions: bool, engine_protocol: EngineProtocol, ) -> crate::Result>> { let mut futures = Vec::with_capacity(operations.len()); - let dispatcher = crate::get_current_dispatcher(); for op in operations { #[cfg(feature = "metrics")] - increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); + counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1); let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -133,10 +130,11 @@ pub async fn execute_many_self_contained( op.clone(), force_transactions, connector.should_retry_on_transient_error(), - trace_id.clone(), + traceparent, ), ) - .with_subscriber(dispatcher.clone()), + .with_current_subscriber() + .with_current_recorder(), )); } @@ -156,7 +154,7 @@ async fn execute_self_contained( operation: Operation, force_transactions: bool, retry_on_transient_error: bool, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let operation_timer = ElapsedTimeCounter::start(); let result = if retry_on_transient_error { @@ -166,20 +164,18 @@ async fn execute_self_contained( operation, force_transactions, ElapsedTimeCounter::start(), - trace_id, + traceparent, ) .await } else { let (graph, serializer) = build_graph(&query_schema, operation)?; - execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await + execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, traceparent) + .await }; #[cfg(feature = "metrics")] - histogram!( - PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, - operation_timer.elapsed_time() - ); + histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS).record(operation_timer.elapsed_time()); result } @@ -190,13 +186,13 @@ async fn execute_self_contained_without_retry<'a>( serializer: IrSerializer<'a>, force_transactions: bool, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { if force_transactions || graph.needs_transaction() { - return execute_in_tx(&mut conn, graph, serializer, query_schema, trace_id).await; + return execute_in_tx(&mut conn, graph, serializer, query_schema, traceparent).await; } - execute_on(conn.as_connection_like(), graph, serializer, 
query_schema, trace_id).await + execute_on(conn.as_connection_like(), graph, serializer, query_schema, traceparent).await } // As suggested by the MongoDB documentation @@ -212,12 +208,12 @@ async fn execute_self_contained_with_retry( operation: Operation, force_transactions: bool, retry_timeout: ElapsedTimeCounter, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let (graph, serializer) = build_graph(&query_schema, operation.clone())?; if force_transactions || graph.needs_transaction() { - let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), traceparent).await; if !is_transient_error(&res) { return res; @@ -225,7 +221,7 @@ async fn execute_self_contained_with_retry( loop { let (graph, serializer) = build_graph(&query_schema, operation.clone())?; - let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + let res = execute_in_tx(conn, graph, serializer, query_schema.as_ref(), traceparent).await; if is_transient_error(&res) && retry_timeout.elapsed_time() < MAX_TX_TIMEOUT_RETRY_LIMIT { crosstarget_utils::time::sleep(TX_RETRY_BACKOFF).await; @@ -240,7 +236,7 @@ async fn execute_self_contained_with_retry( graph, serializer, query_schema.as_ref(), - trace_id, + traceparent, ) .await } @@ -251,17 +247,10 @@ async fn execute_in_tx<'a>( graph: QueryGraph, serializer: IrSerializer<'a>, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let mut tx = conn.start_transaction(None).await?; - let result = execute_on( - tx.as_connection_like(), - graph, - serializer, - query_schema, - trace_id.clone(), - ) - .await; + let result = execute_on(tx.as_connection_like(), graph, serializer, query_schema, traceparent).await; if result.is_ok() { tx.commit().await?; @@ -278,14 +267,14 @@ async fn execute_on<'a>( graph: QueryGraph, serializer: IrSerializer<'a>, query_schema: &'a QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { #[cfg(feature = "metrics")] - increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); + counter!(PRISMA_CLIENT_QUERIES_TOTAL).increment(1); let interpreter = QueryInterpreter::new(conn); QueryPipeline::new(graph, interpreter, serializer) - .execute(query_schema, trace_id) + .execute(query_schema, traceparent) .await } diff --git a/query-engine/core/src/executor/interpreting_executor.rs b/query-engine/core/src/executor/interpreting_executor.rs index 0408361b766..2e391461c71 100644 --- a/query-engine/core/src/executor/interpreting_executor.rs +++ b/query-engine/core/src/executor/interpreting_executor.rs @@ -1,13 +1,15 @@ use super::execute_operation::{execute_many_operations, execute_many_self_contained, execute_single_self_contained}; use super::request_context; +use crate::ItxManager; use crate::{ protocol::EngineProtocol, BatchDocumentTransaction, CoreError, Operation, QueryExecutor, ResponseData, - TransactionActorManager, TransactionError, TransactionManager, TransactionOptions, TxId, + TransactionError, TransactionManager, TransactionOptions, TxId, }; use async_trait::async_trait; use connector::Connector; use schema::QuerySchemaRef; +use telemetry::helpers::TraceParent; use tokio::time::Duration; use tracing_futures::Instrument; @@ -16,7 +18,7 @@ pub struct InterpretingExecutor { /// The loaded connector connector: C, - itx_manager: TransactionActorManager, + itx_manager: ItxManager, /// Flag that forces individual operations to run in a transaction. 
/// Does _not_ force batches to use transactions. @@ -31,7 +33,7 @@ where InterpretingExecutor { connector, force_transactions, - itx_manager: TransactionActorManager::new(), + itx_manager: ItxManager::new(), } } } @@ -48,25 +50,24 @@ where tx_id: Option, operation: Operation, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result { - // If a Tx id is provided, execute on that one. Else execute normally as a single operation. - if let Some(tx_id) = tx_id { - self.itx_manager.execute(&tx_id, operation, trace_id).await - } else { - request_context::with_request_context(engine_protocol, async move { + request_context::with_request_context(engine_protocol, async move { + if let Some(tx_id) = tx_id { + self.itx_manager.execute(&tx_id, operation, traceparent).await + } else { execute_single_self_contained( &self.connector, query_schema, operation, - trace_id, + traceparent, self.force_transactions, ) .await - }) - .await - } + } + }) + .await } /// Executes a batch of operations. @@ -87,53 +88,50 @@ where operations: Vec, transaction: Option, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result>> { - if let Some(tx_id) = tx_id { - let batch_isolation_level = transaction.and_then(|t| t.isolation_level()); - if batch_isolation_level.is_some() { - return Err(CoreError::UnsupportedFeatureError( - "Can not set batch isolation level within interactive transaction".into(), - )); - } - self.itx_manager.batch_execute(&tx_id, operations, trace_id).await - } else if let Some(transaction) = transaction { - let conn_span = info_span!( - "prisma:engine:connection", - user_facing = true, - "db.type" = self.connector.name(), - ); - let mut conn = self.connector.get_connection().instrument(conn_span).await?; - let mut tx = conn.start_transaction(transaction.isolation_level()).await?; - - let results = request_context::with_request_context( - engine_protocol, - execute_many_operations(query_schema, tx.as_connection_like(), &operations, trace_id), - ) - .await; - - if results.is_err() { - tx.rollback().await?; + request_context::with_request_context(engine_protocol, async move { + if let Some(tx_id) = tx_id { + let batch_isolation_level = transaction.and_then(|t| t.isolation_level()); + if batch_isolation_level.is_some() { + return Err(CoreError::UnsupportedFeatureError( + "Can not set batch isolation level within interactive transaction".into(), + )); + } + self.itx_manager.batch_execute(&tx_id, operations, traceparent).await + } else if let Some(transaction) = transaction { + let conn_span = info_span!( + "prisma:engine:connection", + user_facing = true, + "db.system" = self.connector.name(), + ); + let mut conn = self.connector.get_connection().instrument(conn_span).await?; + let mut tx = conn.start_transaction(transaction.isolation_level()).await?; + + let results = + execute_many_operations(query_schema, tx.as_connection_like(), &operations, traceparent).await; + + if results.is_err() { + tx.rollback().await?; + } else { + tx.commit().await?; + } + + results } else { - tx.commit().await?; - } - - results - } else { - request_context::with_request_context(engine_protocol, async move { execute_many_self_contained( &self.connector, query_schema, &operations, - trace_id, + traceparent, self.force_transactions, engine_protocol, ) .await - }) - .await - } + } + }) + .await } fn primary_connector(&self) -> &(dyn Connector + Send + Sync) { @@ -158,11 +156,10 @@ where let 
valid_for_millis = tx_opts.valid_for_millis; let id = tx_opts.new_tx_id.unwrap_or_default(); - trace!("[{}] Starting...", id); let conn_span = info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = self.connector.name() + "db.system" = self.connector.name() ); let conn = crosstarget_utils::time::timeout( Duration::from_millis(tx_opts.max_acquisition_millis), @@ -180,23 +177,19 @@ where conn, isolation_level, Duration::from_millis(valid_for_millis), - engine_protocol, ) .await?; - debug!("[{}] Started.", id); Ok(id) }) .await } async fn commit_tx(&self, tx_id: TxId) -> crate::Result<()> { - trace!("[{}] Committing.", tx_id); self.itx_manager.commit_tx(&tx_id).await } async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()> { - trace!("[{}] Rolling back.", tx_id); self.itx_manager.rollback_tx(&tx_id).await } } diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index fee7bc68fe7..c7846f7ff7c 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -14,6 +14,7 @@ mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; pub(crate) use request_context::*; +use telemetry::helpers::TraceParent; use crate::{ protocol::EngineProtocol, query_document::Operation, response_ir::ResponseData, schema::QuerySchemaRef, @@ -22,7 +23,6 @@ use crate::{ use async_trait::async_trait; use connector::Connector; use serde::{Deserialize, Serialize}; -use tracing::Dispatch; #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] #[cfg_attr(not(target_arch = "wasm32"), async_trait)] @@ -35,7 +35,7 @@ pub trait QueryExecutor: TransactionManager { tx_id: Option, operation: Operation, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result; @@ -51,7 +51,7 @@ pub trait QueryExecutor: TransactionManager { operations: Vec, transaction: Option, query_schema: QuerySchemaRef, - trace_id: Option, + traceparent: Option, engine_protocol: EngineProtocol, ) -> crate::Result>>; @@ -89,10 +89,10 @@ impl TransactionOptions { /// Generates a new transaction id before the transaction is started and returns a modified version /// of self with the new predefined_id set. - pub fn with_new_transaction_id(&mut self) -> TxId { - let tx_id: TxId = Default::default(); + pub fn with_new_transaction_id(mut self) -> Self { + let tx_id = TxId::default(); self.new_tx_id = Some(tx_id.clone()); - tx_id + self } } #[cfg_attr(target_arch = "wasm32", async_trait(?Send))] @@ -116,20 +116,3 @@ pub trait TransactionManager { /// Rolls back a transaction. async fn rollback_tx(&self, tx_id: TxId) -> crate::Result<()>; } - -// With the node-api when a future is spawned in a new thread `tokio:spawn` it will not -// use the current dispatcher and its logs will not be captured anymore. 
We can use this -// method to get the current dispatcher and combine it with `with_subscriber` -// let dispatcher = get_current_dispatcher(); -// tokio::spawn(async { -// my_async_ops.await -// }.with_subscriber(dispatcher)); -// -// -// Finally, this can be replaced with with_current_collector -// https://github.com/tokio-rs/tracing/blob/master/tracing-futures/src/lib.rs#L234 -// once this is in a release - -pub fn get_current_dispatcher() -> Dispatch { - tracing::dispatcher::get_default(|current| current.clone()) -} diff --git a/query-engine/core/src/executor/pipeline.rs b/query-engine/core/src/executor/pipeline.rs index 2193410a57e..bd1ba73d5e8 100644 --- a/query-engine/core/src/executor/pipeline.rs +++ b/query-engine/core/src/executor/pipeline.rs @@ -1,5 +1,6 @@ use crate::{Env, Expressionista, IrSerializer, QueryGraph, QueryInterpreter, ResponseData}; use schema::QuerySchema; +use telemetry::helpers::TraceParent; use tracing::Instrument; #[derive(Debug)] @@ -25,7 +26,7 @@ impl<'conn, 'schema> QueryPipeline<'conn, 'schema> { pub(crate) async fn execute( mut self, query_schema: &'schema QuerySchema, - trace_id: Option, + traceparent: Option, ) -> crate::Result { let serializer = self.serializer; let expr = Expressionista::translate(self.graph)?; @@ -34,7 +35,7 @@ impl<'conn, 'schema> QueryPipeline<'conn, 'schema> { let result = self .interpreter - .interpret(expr, Env::default(), 0, trace_id) + .interpret(expr, Env::default(), 0, traceparent) .instrument(span) .await; diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs deleted file mode 100644 index e6c1c7fbd1d..00000000000 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ /dev/null @@ -1,160 +0,0 @@ -use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; -use connector::Connection; -use crosstarget_utils::task::JoinHandle; -use lru::LruCache; -use once_cell::sync::Lazy; -use schema::QuerySchemaRef; -use std::{collections::HashMap, sync::Arc}; -use tokio::{ - sync::{ - mpsc::{channel, Sender}, - RwLock, - }, - time::Duration, -}; - -use super::{spawn_client_list_clear_actor, spawn_itx_actor, ITXClient, TransactionError, TxId}; - -pub static CLOSED_TX_CACHE_SIZE: Lazy = Lazy::new(|| match std::env::var("CLOSED_TX_CACHE_SIZE") { - Ok(size) => size.parse().unwrap_or(100), - Err(_) => 100, -}); - -static CHANNEL_SIZE: usize = 100; - -pub struct TransactionActorManager { - /// Map of active ITx clients - pub(crate) clients: Arc>>, - /// Cache of closed transactions. We keep the last N closed transactions in memory to - /// return better error messages if operations are performed on closed transactions. - pub(crate) closed_txs: Arc>>>, - /// Channel used to signal an ITx is closed and can be moved to the list of closed transactions. - send_done: Sender<(TxId, Option)>, - /// Handle to the task in charge of clearing actors. - /// Used to abort the task when the TransactionActorManager is dropped. 
- bg_reader_clear: JoinHandle<()>, -} - -impl Drop for TransactionActorManager { - fn drop(&mut self) { - self.bg_reader_clear.abort(); - } -} - -impl Default for TransactionActorManager { - fn default() -> Self { - Self::new() - } -} - -impl TransactionActorManager { - pub fn new() -> Self { - let clients = Arc::new(RwLock::new(HashMap::new())); - let closed_txs = Arc::new(RwLock::new(LruCache::new(*CLOSED_TX_CACHE_SIZE))); - - let (send_done, rx) = channel(CHANNEL_SIZE); - let handle = spawn_client_list_clear_actor(clients.clone(), closed_txs.clone(), rx); - - Self { - clients, - closed_txs, - send_done, - bg_reader_clear: handle, - } - } - - pub(crate) async fn create_tx( - &self, - query_schema: QuerySchemaRef, - tx_id: TxId, - conn: Box, - isolation_level: Option, - timeout: Duration, - engine_protocol: EngineProtocol, - ) -> crate::Result<()> { - let client = spawn_itx_actor( - query_schema.clone(), - tx_id.clone(), - conn, - isolation_level, - timeout, - CHANNEL_SIZE, - self.send_done.clone(), - engine_protocol, - ) - .await?; - - self.clients.write().await.insert(tx_id, client); - Ok(()) - } - - async fn get_client(&self, tx_id: &TxId, from_operation: &str) -> crate::Result { - if let Some(client) = self.clients.read().await.get(tx_id) { - Ok(client.clone()) - } else if let Some(closed_tx) = self.closed_txs.read().await.peek(tx_id) { - Err(TransactionError::Closed { - reason: match closed_tx { - Some(ClosedTx::Committed) => { - format!("A {from_operation} cannot be executed on a committed transaction") - } - Some(ClosedTx::RolledBack) => { - format!("A {from_operation} cannot be executed on a transaction that was rolled back") - } - Some(ClosedTx::Expired { start_time, timeout }) => { - format!( - "A {from_operation} cannot be executed on an expired transaction. \ - The timeout for this transaction was {} ms, however {} ms passed since the start \ - of the transaction. 
Consider increasing the interactive transaction timeout \ - or doing less work in the transaction", - timeout.as_millis(), - start_time.elapsed_time().as_millis(), - ) - } - None => { - error!("[{tx_id}] no details about closed transaction"); - format!("A {from_operation} cannot be executed on a closed transaction") - } - }, - } - .into()) - } else { - Err(TransactionError::NotFound.into()) - } - } - - pub async fn execute( - &self, - tx_id: &TxId, - operation: Operation, - traceparent: Option, - ) -> crate::Result { - let client = self.get_client(tx_id, "query").await?; - - client.execute(operation, traceparent).await - } - - pub async fn batch_execute( - &self, - tx_id: &TxId, - operations: Vec, - traceparent: Option, - ) -> crate::Result>> { - let client = self.get_client(tx_id, "batch query").await?; - - client.batch_execute(operations, traceparent).await - } - - pub async fn commit_tx(&self, tx_id: &TxId) -> crate::Result<()> { - let client = self.get_client(tx_id, "commit").await?; - client.commit().await?; - - Ok(()) - } - - pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { - let client = self.get_client(tx_id, "rollback").await?; - client.rollback().await?; - - Ok(()) - } -} diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs deleted file mode 100644 index 86ebd5c13b8..00000000000 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ /dev/null @@ -1,425 +0,0 @@ -use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; -use crate::{ - execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, - TxId, -}; -use connector::Connection; -use crosstarget_utils::task::{spawn, spawn_controlled, JoinHandle}; -use crosstarget_utils::time::ElapsedTimeCounter; -use schema::QuerySchemaRef; -use std::{collections::HashMap, sync::Arc}; -use tokio::{ - sync::{ - mpsc::{channel, Receiver, Sender}, - oneshot, RwLock, - }, - time::Duration, -}; -use tracing::Span; -use tracing_futures::Instrument; -use tracing_futures::WithSubscriber; - -#[cfg(feature = "metrics")] -use crate::telemetry::helpers::set_span_link_from_traceparent; - -#[derive(PartialEq)] -enum RunState { - Continue, - Finished, -} - -pub struct ITXServer<'a> { - id: TxId, - pub cached_tx: CachedTx<'a>, - pub timeout: Duration, - receive: Receiver, - query_schema: QuerySchemaRef, -} - -impl<'a> ITXServer<'a> { - pub fn new( - id: TxId, - tx: CachedTx<'a>, - timeout: Duration, - receive: Receiver, - query_schema: QuerySchemaRef, - ) -> Self { - Self { - id, - cached_tx: tx, - timeout, - receive, - query_schema, - } - } - - // RunState is used to tell if the run loop should continue - async fn process_msg(&mut self, op: TxOpRequest) -> RunState { - match op.msg { - TxOpRequestMsg::Single(ref operation, traceparent) => { - let result = self.execute_single(operation, traceparent).await; - let _ = op.respond_to.send(TxOpResponse::Single(result)); - RunState::Continue - } - TxOpRequestMsg::Batch(ref operations, traceparent) => { - let result = self.execute_batch(operations, traceparent).await; - let _ = op.respond_to.send(TxOpResponse::Batch(result)); - RunState::Continue - } - TxOpRequestMsg::Commit => { - let resp = self.commit().await; - let _ = op.respond_to.send(TxOpResponse::Committed(resp)); - RunState::Finished - } - TxOpRequestMsg::Rollback => { - let resp = self.rollback(false).await; - let _ = op.respond_to.send(TxOpResponse::RolledBack(resp)); - 
RunState::Finished - } - } - } - - async fn execute_single( - &mut self, - operation: &Operation, - traceparent: Option, - ) -> crate::Result { - let span = info_span!("prisma:engine:itx_query_builder", user_facing = true); - - #[cfg(feature = "metrics")] - set_span_link_from_traceparent(&span, traceparent.clone()); - - let conn = self.cached_tx.as_open()?; - execute_single_operation( - self.query_schema.clone(), - conn.as_connection_like(), - operation, - traceparent, - ) - .instrument(span) - .await - } - - async fn execute_batch( - &mut self, - operations: &[Operation], - traceparent: Option, - ) -> crate::Result>> { - let span = info_span!("prisma:engine:itx_execute", user_facing = true); - - let conn = self.cached_tx.as_open()?; - execute_many_operations( - self.query_schema.clone(), - conn.as_connection_like(), - operations, - traceparent, - ) - .instrument(span) - .await - } - - pub(crate) async fn commit(&mut self) -> crate::Result<()> { - if let CachedTx::Open(_) = self.cached_tx { - let open_tx = self.cached_tx.as_open()?; - trace!("[{}] committing.", self.id.to_string()); - open_tx.commit().await?; - self.cached_tx = CachedTx::Committed; - } - - Ok(()) - } - - pub(crate) async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { - debug!("[{}] rolling back, was timed out = {was_timeout}", self.name()); - if let CachedTx::Open(_) = self.cached_tx { - let open_tx = self.cached_tx.as_open()?; - open_tx.rollback().await?; - if was_timeout { - trace!("[{}] Expired Rolling back", self.id.to_string()); - self.cached_tx = CachedTx::Expired; - } else { - self.cached_tx = CachedTx::RolledBack; - trace!("[{}] Rolling back", self.id.to_string()); - } - } - - Ok(()) - } - - pub(crate) fn name(&self) -> String { - format!("itx-{:?}", self.id.to_string()) - } -} - -#[derive(Clone)] -pub struct ITXClient { - send: Sender, - tx_id: TxId, -} - -impl ITXClient { - pub(crate) async fn commit(&self) -> crate::Result<()> { - let msg = self.send_and_receive(TxOpRequestMsg::Commit).await?; - - if let TxOpResponse::Committed(resp) = msg { - debug!("[{}] COMMITTED {:?}", self.tx_id, resp); - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub(crate) async fn rollback(&self) -> crate::Result<()> { - let msg = self.send_and_receive(TxOpRequestMsg::Rollback).await?; - - if let TxOpResponse::RolledBack(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub async fn execute(&self, operation: Operation, traceparent: Option) -> crate::Result { - let msg_req = TxOpRequestMsg::Single(operation, traceparent); - let msg = self.send_and_receive(msg_req).await?; - - if let TxOpResponse::Single(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - pub(crate) async fn batch_execute( - &self, - operations: Vec, - traceparent: Option, - ) -> crate::Result>> { - let msg_req = TxOpRequestMsg::Batch(operations, traceparent); - - let msg = self.send_and_receive(msg_req).await?; - - if let TxOpResponse::Batch(resp) = msg { - resp - } else { - Err(self.handle_error(msg).into()) - } - } - - async fn send_and_receive(&self, msg: TxOpRequestMsg) -> Result { - let (receiver, req) = self.create_receive_and_req(msg); - if let Err(err) = self.send.send(req).await { - debug!("channel send error {err}"); - return Err(TransactionError::Closed { - reason: "Could not perform operation".to_string(), - } - .into()); - } - - match receiver.await { - Ok(resp) => Ok(resp), - Err(_err) => Err(TransactionError::Closed { - reason: "Could not perform 
operation".to_string(), - } - .into()), - } - } - - fn create_receive_and_req(&self, msg: TxOpRequestMsg) -> (oneshot::Receiver, TxOpRequest) { - let (send, rx) = oneshot::channel::(); - let request = TxOpRequest { msg, respond_to: send }; - (rx, request) - } - - fn handle_error(&self, msg: TxOpResponse) -> TransactionError { - match msg { - TxOpResponse::Committed(..) => { - let reason = "Transaction is no longer valid. Last state: 'Committed'".to_string(); - TransactionError::Closed { reason } - } - TxOpResponse::RolledBack(..) => { - let reason = "Transaction is no longer valid. Last state: 'RolledBack'".to_string(); - TransactionError::Closed { reason } - } - other => { - error!("Unexpected iTx response, {}", other); - let reason = format!("response '{other}'"); - TransactionError::Closed { reason } - } - } - } -} - -#[allow(clippy::too_many_arguments)] -pub(crate) async fn spawn_itx_actor( - query_schema: QuerySchemaRef, - tx_id: TxId, - mut conn: Box, - isolation_level: Option, - timeout: Duration, - channel_size: usize, - send_done: Sender<(TxId, Option)>, - engine_protocol: EngineProtocol, -) -> crate::Result { - let span = Span::current(); - let tx_id_str = tx_id.to_string(); - span.record("itx_id", tx_id_str.as_str()); - let dispatcher = crate::get_current_dispatcher(); - - let (tx_to_server, rx_from_client) = channel::(channel_size); - let client = ITXClient { - send: tx_to_server, - tx_id: tx_id.clone(), - }; - let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - - spawn( - crate::executor::with_request_context(engine_protocol, async move { - // We match on the result in order to send the error to the parent task and abort this - // task, on error. This is a separate task (actor), not a function where we can just bubble up the - // result. - let c_tx = match conn.start_transaction(isolation_level).await { - Ok(c_tx) => { - open_transaction_send.send(Ok(())).unwrap(); - c_tx - } - Err(err) => { - open_transaction_send.send(Err(err)).unwrap(); - return; - } - }; - - let mut server = ITXServer::new( - tx_id.clone(), - CachedTx::Open(c_tx), - timeout, - rx_from_client, - query_schema, - ); - - let start_time = ElapsedTimeCounter::start(); - let sleep = crosstarget_utils::time::sleep(timeout); - tokio::pin!(sleep); - - loop { - tokio::select! { - _ = &mut sleep => { - trace!("[{}] interactive transaction timed out", server.id.to_string()); - let _ = server.rollback(true).await; - break; - } - msg = server.receive.recv() => { - if let Some(op) = msg { - let run_state = server.process_msg(op).await; - - if run_state == RunState::Finished { - break - } - } else { - break; - } - } - } - } - - trace!("[{}] completed with {}", server.id.to_string(), server.cached_tx); - - let _ = send_done - .send(( - server.id.clone(), - server.cached_tx.to_closed(start_time, server.timeout), - )) - .await; - - trace!("[{}] has stopped with {}", server.id.to_string(), server.cached_tx); - }) - .instrument(span) - .with_subscriber(dispatcher), - ); - - open_transaction_rcv.await.unwrap()?; - - Ok(client) -} - -/// Spawn the client list clear actor -/// It waits for messages from completed ITXServers and removes -/// the ITXClient from the clients hashmap - -/* A future improvement to this would be to change this to keep a queue of - clients to remove from the list and then periodically remove them. This - would be a nice optimization because we would take less write locks on the - hashmap. 
- - The downside to consider is that we can introduce a race condition where the - ITXServer has stopped running but the client hasn't been removed from the hashmap - yet. When the client tries to send a message to the ITXServer there will be a - send error. This isn't a huge obstacle but something to handle correctly. - And example implementation for this would be: - - ``` - let mut queue: Vec = Vec::new(); - - let sleep_duration = Duration::from_millis(100); - let clear_sleeper = time::sleep(sleep_duration); - tokio::pin!(clear_sleeper); - - loop { - tokio::select! { - _ = &mut clear_sleeper => { - let mut list = clients.write().await; - for id in queue.drain(..) { - trace!("removing {} from client list", id); - list.remove(&id); - } - clear_sleeper.as_mut().reset(Instant::now() + sleep_duration); - } - msg = rx.recv() => { - if let Some(id) = msg { - queue.push(id); - } - } - } - } - ``` -*/ -pub(crate) fn spawn_client_list_clear_actor( - clients: Arc>>, - closed_txs: Arc>>>, - mut rx: Receiver<(TxId, Option)>, -) -> JoinHandle<()> { - // Note: tasks implemented via loops cannot be cancelled implicitly, so we need to spawn them in a - // "controlled" way, via `spawn_controlled`. - // The `rx_exit` receiver is used to signal the loop to exit, and that signal is emitted whenever - // the task is aborted (likely, due to the engine shutting down and cleaning up the allocated resources). - spawn_controlled(Box::new( - |mut rx_exit: tokio::sync::broadcast::Receiver<()>| async move { - loop { - tokio::select! { - result = rx.recv() => { - match result { - Some((id, closed_tx)) => { - trace!("removing {} from client list", id); - - let mut clients_guard = clients.write().await; - - clients_guard.remove(&id); - drop(clients_guard); - - closed_txs.write().await.put(id, closed_tx); - } - None => { - // the `rx` channel is closed. 
- tracing::error!("rx channel is closed!"); - break; - } - } - }, - _ = rx_exit.recv() => { - break; - }, - } - } - }, - )) -} diff --git a/query-engine/core/src/interactive_transactions/error.rs b/query-engine/core/src/interactive_transactions/error.rs index 8189e2ce742..146d69f103b 100644 --- a/query-engine/core/src/interactive_transactions/error.rs +++ b/query-engine/core/src/interactive_transactions/error.rs @@ -1,10 +1,5 @@ use thiserror::Error; -use crate::{ - response_ir::{Item, Map}, - CoreError, -}; - #[derive(Debug, Error, PartialEq)] pub enum TransactionError { #[error("Unable to start a transaction in the given time.")] @@ -22,27 +17,3 @@ pub enum TransactionError { #[error("Unexpected response: {reason}.")] Unknown { reason: String }, } - -#[derive(Debug, serde::Serialize, PartialEq)] -pub struct ExtendedTransactionUserFacingError { - #[serde(flatten)] - user_facing_error: user_facing_errors::Error, - - #[serde(skip_serializing_if = "indexmap::IndexMap::is_empty")] - extensions: Map, -} - -impl ExtendedTransactionUserFacingError { - pub fn set_extension(&mut self, key: String, val: serde_json::Value) { - self.extensions.entry(key).or_insert(Item::Json(val)); - } -} - -impl From for ExtendedTransactionUserFacingError { - fn from(error: CoreError) -> Self { - ExtendedTransactionUserFacingError { - user_facing_error: error.into(), - extensions: Default::default(), - } - } -} diff --git a/query-engine/core/src/interactive_transactions/manager.rs b/query-engine/core/src/interactive_transactions/manager.rs new file mode 100644 index 00000000000..d9873c4383a --- /dev/null +++ b/query-engine/core/src/interactive_transactions/manager.rs @@ -0,0 +1,192 @@ +use crate::{ClosedTransaction, InteractiveTransaction, Operation, ResponseData}; +use connector::Connection; +use lru::LruCache; +use once_cell::sync::Lazy; +use schema::QuerySchemaRef; +use std::{collections::HashMap, sync::Arc}; +use telemetry::helpers::TraceParent; +use tokio::{ + sync::{ + mpsc::{unbounded_channel, UnboundedSender}, + Mutex, RwLock, + }, + time::Duration, +}; +use tracing_futures::WithSubscriber; + +#[cfg(not(feature = "metrics"))] +use crate::metrics::MetricsInstrumentationStub; +#[cfg(feature = "metrics")] +use prisma_metrics::WithMetricsInstrumentation; + +use super::{TransactionError, TxId}; + +pub static CLOSED_TX_CACHE_SIZE: Lazy = Lazy::new(|| match std::env::var("CLOSED_TX_CACHE_SIZE") { + Ok(size) => size.parse().unwrap_or(100), + Err(_) => 100, +}); + +pub struct ItxManager { + /// Stores all current transactions (some of them might be already committed/expired/rolled back). + /// + /// There are two tiers of locks here: + /// 1. Lock on the entire hashmap. This *must* be taken only for short periods of time - for + /// example to insert/delete transaction or to clone transaction inside. + /// 2. Lock on the individual transactions. This one can be taken for prolonged periods of time - for + /// example to perform an I/O operation. + /// + /// The rationale behind this design is to make shared path (lock on the entire hashmap) as free + /// from contention as possible. Individual transactions are not capable of concurrency, so + /// taking a lock on them to serialise operations is acceptable. + /// + /// Note that since we clone transaction from the shared hashmap to perform operations on it, it + /// is possible to end up in a situation where we cloned the transaction, but it was then + /// immediately removed by the background task from the common hashmap. 
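+    /// (removal only drops the map's reference to the transaction; our `Arc` clone keeps it alive).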
In this case, either
+    /// our operation will be first or the background cleanup task will be first. Both cases are
+    /// an acceptable outcome.
+    transactions: Arc<RwLock<HashMap<TxId, Arc<Mutex<InteractiveTransaction>>>>>,
+
+    /// Cache of closed transactions. We keep the last N closed transactions in memory to
+    /// return better error messages if operations are performed on closed transactions.
+    closed_txs: Arc<RwLock<LruCache<TxId, ClosedTransaction>>>,
+
+    /// Sender part of the channel to which the transaction id is sent when the timeout of the
+    /// transaction expires.
+    timeout_sender: UnboundedSender<TxId>,
+}
+
+impl ItxManager {
+    pub fn new() -> Self {
+        let transactions = Arc::new(RwLock::new(HashMap::<_, Arc<Mutex<InteractiveTransaction>>>::default()));
+        let closed_txs = Arc::new(RwLock::new(LruCache::new(*CLOSED_TX_CACHE_SIZE)));
+        let (timeout_sender, mut timeout_receiver) = unbounded_channel();
+
+        // This task rolls back and removes any open transactions with expired timeouts from
+        // `self.transactions`. It also removes any closed transactions to avoid `self.transactions`
+        // growing infinitely in size over time.
+        // Note that this task automatically exits when all transactions finish and the `ItxManager`
+        // is dropped, because that causes the `timeout_receiver` to become closed.
+        crosstarget_utils::task::spawn({
+            let transactions = Arc::clone(&transactions);
+            let closed_txs = Arc::clone(&closed_txs);
+            async move {
+                while let Some(tx_id) = timeout_receiver.recv().await {
+                    let transaction_entry = match transactions.write().await.remove(&tx_id) {
+                        Some(transaction_entry) => transaction_entry,
+                        None => {
+                            // Transaction was committed or rolled back already.
+                            continue;
+                        }
+                    };
+                    let mut transaction = transaction_entry.lock().await;
+
+                    // If the transaction was already committed, rollback will error.
+                    let _ = transaction.rollback(true).await;
+
+                    let closed_tx = transaction
+                        .as_closed()
+                        .expect("transaction must be closed after rollback");
+
+                    closed_txs.write().await.put(tx_id, closed_tx);
+                }
+            }
+            .with_current_subscriber()
+            .with_current_recorder()
+        });
+
+        Self {
+            transactions,
+            closed_txs,
+            timeout_sender,
+        }
+    }
+
+    pub async fn create_tx(
+        &self,
+        query_schema: QuerySchemaRef,
+        tx_id: TxId,
+        conn: Box<dyn Connection + Send + Sync>,
+        isolation_level: Option<String>,
+        timeout: Duration,
+    ) -> crate::Result<()> {
+        // This task notifies the task spawned in the `new()` method that the timeout for this
+        // transaction has expired.
+        crosstarget_utils::task::spawn({
+            let timeout_sender = self.timeout_sender.clone();
+            let tx_id = tx_id.clone();
+            async move {
+                crosstarget_utils::time::sleep(timeout).await;
+                timeout_sender.send(tx_id).expect("receiver must exist");
+            }
+        });
+
+        let transaction =
+            InteractiveTransaction::new(tx_id.clone(), conn, timeout, query_schema, isolation_level).await?;
+
+        self.transactions
+            .write()
+            .await
+            .insert(tx_id, Arc::new(Mutex::new(transaction)));
+        Ok(())
+    }
+
+    async fn get_transaction(
+        &self,
+        tx_id: &TxId,
+        from_operation: &str,
+    ) -> crate::Result<Arc<Mutex<InteractiveTransaction>>> {
+        if let Some(transaction) = self.transactions.read().await.get(tx_id) {
+            Ok(Arc::clone(transaction))
+        } else {
+            Err(if let Some(closed_tx) = self.closed_txs.read().await.peek(tx_id) {
+                TransactionError::Closed {
+                    reason: closed_tx.error_message_for(from_operation),
+                }
+                .into()
+            } else {
+                TransactionError::NotFound.into()
+            })
+        }
+    }
+
+    pub async fn execute(
+        &self,
+        tx_id: &TxId,
+        operation: Operation,
+        traceparent: Option<TraceParent>,
+    ) -> crate::Result<ResponseData> {
+        self.get_transaction(tx_id, "query")
+            .await?
+ .lock() + .await + .execute_single(&operation, traceparent) + .await + } + + pub async fn batch_execute( + &self, + tx_id: &TxId, + operations: Vec, + traceparent: Option, + ) -> crate::Result>> { + self.get_transaction(tx_id, "batch query") + .await? + .lock() + .await + .execute_batch(&operations, traceparent) + .await + } + + pub async fn commit_tx(&self, tx_id: &TxId) -> crate::Result<()> { + self.get_transaction(tx_id, "commit").await?.lock().await.commit().await + } + + pub async fn rollback_tx(&self, tx_id: &TxId) -> crate::Result<()> { + self.get_transaction(tx_id, "rollback") + .await? + .lock() + .await + .rollback(false) + .await + } +} diff --git a/query-engine/core/src/interactive_transactions/messages.rs b/query-engine/core/src/interactive_transactions/messages.rs deleted file mode 100644 index 0dba2c096a8..00000000000 --- a/query-engine/core/src/interactive_transactions/messages.rs +++ /dev/null @@ -1,46 +0,0 @@ -use crate::{Operation, ResponseData}; -use std::fmt::Display; -use tokio::sync::oneshot; - -#[derive(Debug)] -pub enum TxOpRequestMsg { - Commit, - Rollback, - Single(Operation, Option), - Batch(Vec, Option), -} - -pub struct TxOpRequest { - pub msg: TxOpRequestMsg, - pub respond_to: oneshot::Sender, -} - -impl Display for TxOpRequest { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self.msg { - TxOpRequestMsg::Commit => write!(f, "Commit"), - TxOpRequestMsg::Rollback => write!(f, "Rollback"), - TxOpRequestMsg::Single(..) => write!(f, "Single"), - TxOpRequestMsg::Batch(..) => write!(f, "Batch"), - } - } -} - -#[derive(Debug)] -pub enum TxOpResponse { - Committed(crate::Result<()>), - RolledBack(crate::Result<()>), - Single(crate::Result), - Batch(crate::Result>>), -} - -impl Display for TxOpResponse { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Committed(..) => write!(f, "Committed"), - Self::RolledBack(..) => write!(f, "RolledBack"), - Self::Single(..) => write!(f, "Single"), - Self::Batch(..) => write!(f, "Batch"), - } - } -} diff --git a/query-engine/core/src/interactive_transactions/mod.rs b/query-engine/core/src/interactive_transactions/mod.rs index a0aed069a87..009cab37ccf 100644 --- a/query-engine/core/src/interactive_transactions/mod.rs +++ b/query-engine/core/src/interactive_transactions/mod.rs @@ -1,49 +1,19 @@ -use crate::CoreError; -use connector::Transaction; -use crosstarget_utils::time::ElapsedTimeCounter; +use derive_more::Display; use serde::Deserialize; -use std::fmt::Display; -use tokio::time::Duration; -mod actor_manager; -mod actors; mod error; -mod messages; +mod manager; +mod transaction; pub use error::*; -pub(crate) use actor_manager::*; -pub(crate) use actors::*; -pub(crate) use messages::*; +pub(crate) use manager::*; +pub(crate) use transaction::*; -/// How Interactive Transactions work -/// The Interactive Transactions (iTx) follow an actor model design. Where each iTx is created in its own process. -/// When a prisma client requests to start a new transaction, the Transaction Actor Manager spawns a new ITXServer. The ITXServer runs in its own -/// process and waits for messages to arrive via its receive channel to process. -/// The Transaction Actor Manager will also create an ITXClient and add it to hashmap managed by an RwLock. The ITXClient is the only way to communicate -/// with the ITXServer. -/// -/// Once Prisma Client receives the iTx Id it can perform database operations using that iTx id. 
When an operation request is received by the -/// TransactionActorManager, it looks for the client in the hashmap and passes the operation to the client. The ITXClient sends a message to the -/// ITXServer and waits for a response. The ITXServer will then perform the operation and return the result. The ITXServer will perform one -/// operation at a time. All other operations will sit in the message queue waiting to be processed. -/// -/// The ITXServer will handle all messages until: -/// - It transitions state, e.g "rollback" or "commit" -/// - It exceeds its timeout, in which case the iTx is rolledback and the connection to the database is closed. -/// -/// Once the ITXServer is done handling messages from the iTx Client, it sends a last message to the Background Client list Actor to say that it is completed and then shuts down. -/// The Background Client list Actor removes the client from the list of active clients and keeps in cache the iTx id of the closed transaction. -/// -/// We keep a list of closed transactions so that if any further messages are received for this iTx id, -/// the TransactionActorManager can reply with a helpful error message which explains that no operation can be performed on a closed transaction -/// rather than an error message stating that the transaction does not exist. - -#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize)] +#[derive(Debug, Clone, Hash, Eq, PartialEq, Deserialize, Display)] +#[display(fmt = "{}", _0)] pub struct TxId(String); -const MINIMUM_TX_ID_LENGTH: usize = 24; - impl Default for TxId { fn default() -> Self { #[allow(deprecated)] @@ -56,9 +26,11 @@ where T: Into, { fn from(s: T) -> Self { + const MINIMUM_TX_ID_LENGTH: usize = 24; + let contents = s.into(); // This postcondition is to ensure that the TxId is long enough as to be able to derive - // a TraceId from it. + // a TraceId from it. See `TxTraceExt` trait for more details. assert!( contents.len() >= MINIMUM_TX_ID_LENGTH, "minimum length for a TxId ({}) is {}, but was {}", @@ -69,57 +41,3 @@ where Self(contents) } } - -impl Display for TxId { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - Display::fmt(&self.0, f) - } -} - -pub enum CachedTx<'a> { - Open(Box), - Committed, - RolledBack, - Expired, -} - -impl Display for CachedTx<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - CachedTx::Open(_) => f.write_str("Open"), - CachedTx::Committed => f.write_str("Committed"), - CachedTx::RolledBack => f.write_str("Rolled back"), - CachedTx::Expired => f.write_str("Expired"), - } - } -} - -impl<'a> CachedTx<'a> { - /// Requires this cached TX to be `Open`, else an error will be raised that it is no longer valid. - pub(crate) fn as_open(&mut self) -> crate::Result<&mut Box> { - if let Self::Open(ref mut otx) = self { - Ok(otx) - } else { - let reason = format!("Transaction is no longer valid. 
Last state: '{self}'"); - Err(CoreError::from(TransactionError::Closed { reason })) - } - } - - pub(crate) fn to_closed(&self, start_time: ElapsedTimeCounter, timeout: Duration) -> Option { - match self { - CachedTx::Open(_) => None, - CachedTx::Committed => Some(ClosedTx::Committed), - CachedTx::RolledBack => Some(ClosedTx::RolledBack), - CachedTx::Expired => Some(ClosedTx::Expired { start_time, timeout }), - } - } -} - -pub(crate) enum ClosedTx { - Committed, - RolledBack, - Expired { - start_time: ElapsedTimeCounter, - timeout: Duration, - }, -} diff --git a/query-engine/core/src/interactive_transactions/transaction.rs b/query-engine/core/src/interactive_transactions/transaction.rs new file mode 100644 index 00000000000..4e84155ad78 --- /dev/null +++ b/query-engine/core/src/interactive_transactions/transaction.rs @@ -0,0 +1,253 @@ +#![allow(unsafe_code)] + +use std::pin::Pin; + +use crate::{ + execute_many_operations, execute_single_operation, CoreError, Operation, ResponseData, TransactionError, TxId, +}; +use connector::{Connection, Transaction}; +use crosstarget_utils::time::ElapsedTimeCounter; +use schema::QuerySchemaRef; +use telemetry::helpers::TraceParent; +use tokio::time::Duration; +use tracing::Span; +use tracing_futures::Instrument; + +// Note: it's important to maintain the correct state of the transaction throughout execution. If +// the transaction is ever left in the `Open` state after rollback or commit operations, it means +// that the corresponding connection will never be returned to the connection pool. +enum TransactionState { + Open { + // Note: field order is important here because fields are dropped in the declaration order. + // First, we drop the `tx`, which may reference `_conn`. Only after that we drop `_conn`. + tx: Box, + _conn: Pin>, + }, + Committed, + RolledBack, + Expired { + start_time: ElapsedTimeCounter, + timeout: Duration, + }, +} + +pub enum ClosedTransaction { + Committed, + RolledBack, + Expired { + start_time: ElapsedTimeCounter, + timeout: Duration, + }, +} + +impl ClosedTransaction { + pub fn error_message_for(&self, operation: &str) -> String { + match self { + ClosedTransaction::Committed => { + format!("A {operation} cannot be executed on a committed transaction") + } + ClosedTransaction::RolledBack => { + format!("A {operation} cannot be executed on a transaction that was rolled back") + } + ClosedTransaction::Expired { start_time, timeout } => { + format!( + "A {operation} cannot be executed on an expired transaction. \ + The timeout for this transaction was {} ms, however {} ms passed since the start \ + of the transaction. Consider increasing the interactive transaction timeout \ + or doing less work in the transaction", + timeout.as_millis(), + start_time.elapsed_time().as_millis(), + ) + } + } + } +} + +impl TransactionState { + async fn start_transaction( + conn: Box, + isolation_level: Option, + ) -> crate::Result { + // Note: This method creates a self-referential struct, which is why we need unsafe. Field + // `tx` is referencing field `conn` in the `Self::Open` variant. + let mut conn = Box::into_pin(conn); + + // SAFETY: We do not move out of `conn`. + let conn_mut: &mut (dyn Connection + Send + Sync) = unsafe { conn.as_mut().get_unchecked_mut() }; + + // This creates a transaction, which borrows from the connection. + let tx_borrowed_from_conn: Box = conn_mut.start_transaction(isolation_level).await?; + + // SAFETY: This transmute only erases the lifetime from `conn_mut`. 
Normally, borrow checker + // guarantees that the borrowed value is not dropped. In this case, we guarantee ourselves + // through the use of `Pin` on the connection. + let tx_with_erased_lifetime: Box = + unsafe { std::mem::transmute(tx_borrowed_from_conn) }; + + Ok(Self::Open { + tx: tx_with_erased_lifetime, + _conn: conn, + }) + } + + fn as_open(&mut self, from_operation: &str) -> crate::Result<&mut Box> { + match self { + Self::Open { tx, .. } => Ok(tx), + tx => Err(CoreError::from(TransactionError::Closed { + reason: tx.as_closed().unwrap().error_message_for(from_operation), + })), + } + } + + fn as_closed(&self) -> Option { + match self { + Self::Open { .. } => None, + Self::Committed => Some(ClosedTransaction::Committed), + Self::RolledBack => Some(ClosedTransaction::RolledBack), + Self::Expired { start_time, timeout } => Some(ClosedTransaction::Expired { + start_time: *start_time, + timeout: *timeout, + }), + } + } +} + +pub struct InteractiveTransaction { + id: TxId, + state: TransactionState, + start_time: ElapsedTimeCounter, + timeout: Duration, + query_schema: QuerySchemaRef, +} + +/// This macro executes the future until it's ready or the transaction's timeout expires. +macro_rules! tx_timeout { + ($self:expr, $operation:expr, $fut:expr) => {{ + let remaining_time = $self + .timeout + .checked_sub($self.start_time.elapsed_time()) + .unwrap_or(Duration::ZERO); + tokio::select! { + biased; + _ = crosstarget_utils::time::sleep(remaining_time) => { + let _ = $self.rollback(true).await; + Err(TransactionError::Closed { + reason: $self.as_closed().unwrap().error_message_for($operation), + }.into()) + } + result = $fut => { + result + } + } + }}; +} + +impl InteractiveTransaction { + pub async fn new( + id: TxId, + conn: Box, + timeout: Duration, + query_schema: QuerySchemaRef, + isolation_level: Option, + ) -> crate::Result { + Span::current().record("itx_id", id.to_string()); + + Ok(Self { + id, + state: TransactionState::start_transaction(conn, isolation_level).await?, + start_time: ElapsedTimeCounter::start(), + timeout, + query_schema, + }) + } + + pub async fn execute_single( + &mut self, + operation: &Operation, + traceparent: Option, + ) -> crate::Result { + tx_timeout!(self, "query", async { + let conn = self.state.as_open("query")?; + execute_single_operation( + self.query_schema.clone(), + conn.as_connection_like(), + operation, + traceparent, + ) + .instrument(info_span!("prisma:engine:itx_execute_single", user_facing = true)) + .await + }) + } + + pub async fn execute_batch( + &mut self, + operations: &[Operation], + traceparent: Option, + ) -> crate::Result>> { + tx_timeout!(self, "batch query", async { + let conn = self.state.as_open("batch query")?; + execute_many_operations( + self.query_schema.clone(), + conn.as_connection_like(), + operations, + traceparent, + ) + .instrument(info_span!("prisma:engine:itx_execute_batch", user_facing = true)) + .await + }) + } + + pub async fn commit(&mut self) -> crate::Result<()> { + tx_timeout!(self, "commit", async { + let name = self.name(); + let conn = self.state.as_open("commit")?; + let span = info_span!("prisma:engine:itx_commit", user_facing = true); + + if let Err(err) = conn.commit().instrument(span).await { + error!(?err, ?name, "transaction failed to commit"); + // We don't know if the transaction was committed or not. Because of that, we cannot + // leave it in "open" state. We attempt to rollback to get the transaction into a + // known state. 
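+                // The rollback result is deliberately discarded: the original commit error below is
+                // the one surfaced to the caller, and `rollback(false)` moves the state out of
+                // `Open` either way.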
+ let _ = self.rollback(false).await; + Err(err.into()) + } else { + debug!(?name, "transaction committed"); + self.state = TransactionState::Committed; + Ok(()) + } + }) + } + + pub async fn rollback(&mut self, was_timeout: bool) -> crate::Result<()> { + let name = self.name(); + let conn = self.state.as_open("rollback")?; + let span = info_span!("prisma:engine:itx_rollback", user_facing = true); + + let result = conn.rollback().instrument(span).await; + if let Err(err) = &result { + error!(?err, ?was_timeout, ?name, "transaction failed to roll back"); + } else { + debug!(?was_timeout, ?name, "transaction rolled back"); + } + + // Ensure that the transaction isn't left in the "open" state after the rollback. + if was_timeout { + self.state = TransactionState::Expired { + start_time: self.start_time, + timeout: self.timeout, + }; + } else { + self.state = TransactionState::RolledBack; + } + + result.map_err(<_>::into) + } + + pub fn as_closed(&self) -> Option { + self.state.as_closed() + } + + pub fn name(&self) -> String { + format!("itx-{}", self.id) + } +} diff --git a/query-engine/core/src/interpreter/interpreter_impl.rs b/query-engine/core/src/interpreter/interpreter_impl.rs index 012bbc953b1..e25d157b7ad 100644 --- a/query-engine/core/src/interpreter/interpreter_impl.rs +++ b/query-engine/core/src/interpreter/interpreter_impl.rs @@ -8,6 +8,7 @@ use connector::ConnectionLike; use futures::future::BoxFuture; use query_structure::prelude::*; use std::{collections::HashMap, fmt}; +use telemetry::helpers::TraceParent; use tracing::Instrument; #[derive(Debug, Clone)] @@ -178,7 +179,7 @@ impl<'conn> QueryInterpreter<'conn> { exp: Expression, env: Env, level: usize, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { match exp { Expression::Func { func } => { @@ -186,7 +187,7 @@ impl<'conn> QueryInterpreter<'conn> { Box::pin(async move { self.log_line(level, || "execute {"); - let result = self.interpret(expr?, env, level + 1, trace_id).await; + let result = self.interpret(expr?, env, level + 1, traceparent).await; self.log_line(level, || "}"); result }) @@ -204,7 +205,7 @@ impl<'conn> QueryInterpreter<'conn> { let mut results = Vec::with_capacity(seq.len()); for expr in seq { - results.push(self.interpret(expr, env.clone(), level + 1, trace_id.clone()).await?); + results.push(self.interpret(expr, env.clone(), level + 1, traceparent).await?); self.log_line(level + 1, || ","); } @@ -227,7 +228,7 @@ impl<'conn> QueryInterpreter<'conn> { self.log_line(level + 1, || format!("{} = {{", &binding.name)); let result = self - .interpret(binding.expr, env.clone(), level + 2, trace_id.clone()) + .interpret(binding.expr, env.clone(), level + 2, traceparent) .await?; inner_env.insert(binding.name, result); @@ -242,7 +243,7 @@ impl<'conn> QueryInterpreter<'conn> { }; self.log_line(level, || "in {"); - let result = self.interpret(next_expression, inner_env, level + 1, trace_id).await; + let result = self.interpret(next_expression, inner_env, level + 1, traceparent).await; self.log_line(level, || "}"); result }) @@ -253,7 +254,7 @@ impl<'conn> QueryInterpreter<'conn> { Query::Read(read) => { self.log_line(level, || format!("readExecute {read}")); let span = info_span!("prisma:engine:read-execute"); - Ok(read::execute(self.conn, read, None, trace_id) + Ok(read::execute(self.conn, read, None, traceparent) .instrument(span) .await .map(ExpressionResult::Query)?) 
@@ -262,7 +263,7 @@ impl<'conn> QueryInterpreter<'conn> { Query::Write(write) => { self.log_line(level, || format!("writeExecute {write}")); let span = info_span!("prisma:engine:write-execute"); - Ok(write::execute(self.conn, write, trace_id) + Ok(write::execute(self.conn, write, traceparent) .instrument(span) .await .map(ExpressionResult::Query)?) @@ -297,10 +298,10 @@ impl<'conn> QueryInterpreter<'conn> { self.log_line(level, || format!("if = {predicate} {{")); let result = if predicate { - self.interpret(Expression::Sequence { seq: then }, env, level + 1, trace_id) + self.interpret(Expression::Sequence { seq: then }, env, level + 1, traceparent) .await } else { - self.interpret(Expression::Sequence { seq: elze }, env, level + 1, trace_id) + self.interpret(Expression::Sequence { seq: elze }, env, level + 1, traceparent) .await }; self.log_line(level, || "}"); diff --git a/query-engine/core/src/interpreter/query_interpreters/nested_read.rs b/query-engine/core/src/interpreter/query_interpreters/nested_read.rs index 790728104fd..95e5945c18d 100644 --- a/query-engine/core/src/interpreter/query_interpreters/nested_read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/nested_read.rs @@ -3,12 +3,13 @@ use crate::{interpreter::InterpretationResult, query_ast::*}; use connector::ConnectionLike; use query_structure::*; use std::collections::HashMap; +use telemetry::helpers::TraceParent; pub(crate) async fn m2m( tx: &mut dyn ConnectionLike, query: &mut RelatedRecordsQuery, parent_result: Option<&ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let processor = InMemoryRecordProcessor::new_from_query_args(&mut query.args); @@ -31,7 +32,7 @@ pub(crate) async fn m2m( } let ids = tx - .get_related_m2m_record_ids(&query.parent_field, &parent_ids, trace_id.clone()) + .get_related_m2m_record_ids(&query.parent_field, &parent_ids, traceparent) .await?; if ids.is_empty() { return Ok(ManyRecords::empty(&query.selected_fields)); @@ -70,7 +71,7 @@ pub(crate) async fn m2m( args, &query.selected_fields, RelationLoadStrategy::Query, - trace_id.clone(), + traceparent, ) .await? }; @@ -137,7 +138,7 @@ pub async fn one2m( parent_result: Option<&ManyRecords>, mut query_args: QueryArguments, selected_fields: &FieldSelection, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let parent_model_id = parent_field.model().primary_identifier(); let parent_link_id = parent_field.linking_fields(); @@ -208,7 +209,7 @@ pub async fn one2m( args, selected_fields, RelationLoadStrategy::Query, - trace_id, + traceparent, ) .await? 
}; diff --git a/query-engine/core/src/interpreter/query_interpreters/read.rs b/query-engine/core/src/interpreter/query_interpreters/read.rs index 7e194993b75..d79f4fd5c99 100644 --- a/query-engine/core/src/interpreter/query_interpreters/read.rs +++ b/query-engine/core/src/interpreter/query_interpreters/read.rs @@ -4,20 +4,21 @@ use connector::{error::ConnectorError, ConnectionLike}; use futures::future::{BoxFuture, FutureExt}; use psl::can_support_relation_load_strategy; use query_structure::{ManyRecords, RelationLoadStrategy, RelationSelection}; +use telemetry::helpers::TraceParent; use user_facing_errors::KnownError; pub(crate) fn execute<'conn>( tx: &'conn mut dyn ConnectionLike, query: ReadQuery, parent_result: Option<&'conn ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'conn, InterpretationResult> { let fut = async move { match query { - ReadQuery::RecordQuery(q) => read_one(tx, q, trace_id).await, - ReadQuery::ManyRecordsQuery(q) => read_many(tx, q, trace_id).await, - ReadQuery::RelatedRecordsQuery(q) => read_related(tx, q, parent_result, trace_id).await, - ReadQuery::AggregateRecordsQuery(q) => aggregate(tx, q, trace_id).await, + ReadQuery::RecordQuery(q) => read_one(tx, q, traceparent).await, + ReadQuery::ManyRecordsQuery(q) => read_many(tx, q, traceparent).await, + ReadQuery::RelatedRecordsQuery(q) => read_related(tx, q, parent_result, traceparent).await, + ReadQuery::AggregateRecordsQuery(q) => aggregate(tx, q, traceparent).await, } }; @@ -28,7 +29,7 @@ pub(crate) fn execute<'conn>( fn read_one( tx: &mut dyn ConnectionLike, query: RecordQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { let fut = async move { let model = query.model; @@ -39,7 +40,7 @@ fn read_one( &filter, &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -97,18 +98,18 @@ fn read_one( fn read_many( tx: &mut dyn ConnectionLike, query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { match query.relation_load_strategy { - RelationLoadStrategy::Join => read_many_by_joins(tx, query, trace_id), - RelationLoadStrategy::Query => read_many_by_queries(tx, query, trace_id), + RelationLoadStrategy::Join => read_many_by_joins(tx, query, traceparent), + RelationLoadStrategy::Query => read_many_by_queries(tx, query, traceparent), } } fn read_many_by_queries( tx: &mut dyn ConnectionLike, mut query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { let processor = if query.args.requires_inmemory_processing() { Some(InMemoryRecordProcessor::new_from_query_args(&mut query.args)) @@ -123,7 +124,7 @@ fn read_many_by_queries( query.args.clone(), &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -156,7 +157,7 @@ fn read_many_by_queries( fn read_many_by_joins( tx: &mut dyn ConnectionLike, query: ManyRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'_, InterpretationResult> { if !can_support_relation_load_strategy() { unreachable!() @@ -168,7 +169,7 @@ fn read_many_by_joins( query.args.clone(), &query.selected_fields, query.relation_load_strategy, - trace_id, + traceparent, ) .await?; @@ -209,13 +210,13 @@ fn read_related<'conn>( tx: &'conn mut dyn ConnectionLike, mut query: RelatedRecordsQuery, parent_result: Option<&'conn ManyRecords>, - trace_id: Option, + traceparent: Option, ) -> BoxFuture<'conn, InterpretationResult> { let fut = async move { let 
relation = query.parent_field.relation(); let records = if relation.is_many_to_many() { - nested_read::m2m(tx, &mut query, parent_result, trace_id).await? + nested_read::m2m(tx, &mut query, parent_result, traceparent).await? } else { nested_read::one2m( tx, @@ -224,7 +225,7 @@ fn read_related<'conn>( parent_result, query.args.clone(), &query.selected_fields, - trace_id, + traceparent, ) .await? }; @@ -248,7 +249,7 @@ fn read_related<'conn>( async fn aggregate( tx: &mut dyn ConnectionLike, query: AggregateRecordsQuery, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let selection_order = query.selection_order; @@ -259,7 +260,7 @@ async fn aggregate( query.selectors, query.group_by, query.having, - trace_id, + traceparent, ) .await?; diff --git a/query-engine/core/src/interpreter/query_interpreters/write.rs b/query-engine/core/src/interpreter/query_interpreters/write.rs index d3146c38363..45396436980 100644 --- a/query-engine/core/src/interpreter/query_interpreters/write.rs +++ b/query-engine/core/src/interpreter/query_interpreters/write.rs @@ -7,24 +7,25 @@ use crate::{ }; use connector::{ConnectionLike, DatasourceFieldName, NativeUpsert, WriteArgs}; use query_structure::{ManyRecords, Model, RawJson}; +use telemetry::helpers::TraceParent; pub(crate) async fn execute( tx: &mut dyn ConnectionLike, write_query: WriteQuery, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { match write_query { - WriteQuery::CreateRecord(q) => create_one(tx, q, trace_id).await, - WriteQuery::CreateManyRecords(q) => create_many(tx, q, trace_id).await, - WriteQuery::UpdateRecord(q) => update_one(tx, q, trace_id).await, - WriteQuery::DeleteRecord(q) => delete_one(tx, q, trace_id).await, - WriteQuery::UpdateManyRecords(q) => update_many(tx, q, trace_id).await, - WriteQuery::DeleteManyRecords(q) => delete_many(tx, q, trace_id).await, - WriteQuery::ConnectRecords(q) => connect(tx, q, trace_id).await, - WriteQuery::DisconnectRecords(q) => disconnect(tx, q, trace_id).await, + WriteQuery::CreateRecord(q) => create_one(tx, q, traceparent).await, + WriteQuery::CreateManyRecords(q) => create_many(tx, q, traceparent).await, + WriteQuery::UpdateRecord(q) => update_one(tx, q, traceparent).await, + WriteQuery::DeleteRecord(q) => delete_one(tx, q, traceparent).await, + WriteQuery::UpdateManyRecords(q) => update_many(tx, q, traceparent).await, + WriteQuery::DeleteManyRecords(q) => delete_many(tx, q, traceparent).await, + WriteQuery::ConnectRecords(q) => connect(tx, q, traceparent).await, + WriteQuery::DisconnectRecords(q) => disconnect(tx, q, traceparent).await, WriteQuery::ExecuteRaw(q) => execute_raw(tx, q).await, WriteQuery::QueryRaw(q) => query_raw(tx, q).await, - WriteQuery::Upsert(q) => native_upsert(tx, q, trace_id).await, + WriteQuery::Upsert(q) => native_upsert(tx, q, traceparent).await, } } @@ -46,9 +47,11 @@ async fn execute_raw(tx: &mut dyn ConnectionLike, q: RawQuery) -> Interpretation async fn create_one( tx: &mut dyn ConnectionLike, q: CreateRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.create_record(&q.model, q.args, q.selected_fields, trace_id).await?; + let res = tx + .create_record(&q.model, q.args, q.selected_fields, traceparent) + .await?; Ok(QueryResult::RecordSelection(Some(Box::new(RecordSelection { name: q.name, @@ -63,15 +66,15 @@ async fn create_one( async fn create_many( tx: &mut dyn ConnectionLike, q: CreateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { if q.split_by_shape { 
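+        // Rows whose write args define different sets of fields are grouped by that "shape" and
+        // inserted per group by `create_many_split_by_shape` below.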
- return create_many_split_by_shape(tx, q, trace_id).await; + return create_many_split_by_shape(tx, q, traceparent).await; } if let Some(selected_fields) = q.selected_fields { let records = tx - .create_records_returning(&q.model, q.args, q.skip_duplicates, selected_fields.fields, trace_id) + .create_records_returning(&q.model, q.args, q.skip_duplicates, selected_fields.fields, traceparent) .await?; let nested: Vec = super::read::process_nested(tx, selected_fields.nested, Some(&records)).await?; @@ -87,7 +90,9 @@ async fn create_many( Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) } else { - let affected_records = tx.create_records(&q.model, q.args, q.skip_duplicates, trace_id).await?; + let affected_records = tx + .create_records(&q.model, q.args, q.skip_duplicates, traceparent) + .await?; Ok(QueryResult::Count(affected_records)) } @@ -100,7 +105,7 @@ async fn create_many( async fn create_many_split_by_shape( tx: &mut dyn ConnectionLike, q: CreateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let mut args_by_shape: HashMap> = Default::default(); let model = &q.model; @@ -121,7 +126,7 @@ async fn create_many_split_by_shape( args, q.skip_duplicates, selected_fields.fields.clone(), - trace_id.clone(), + traceparent, ) .await?; @@ -139,7 +144,7 @@ async fn create_many_split_by_shape( result } else { // Empty result means that the list of arguments was empty as well. - tx.create_records_returning(&q.model, vec![], q.skip_duplicates, selected_fields.fields, trace_id) + tx.create_records_returning(&q.model, vec![], q.skip_duplicates, selected_fields.fields, traceparent) .await? }; @@ -161,7 +166,7 @@ async fn create_many_split_by_shape( for args in args_by_shape.into_values() { let affected_records = tx - .create_records(&q.model, args, q.skip_duplicates, trace_id.clone()) + .create_records(&q.model, args, q.skip_duplicates, traceparent) .await?; result += affected_records; } @@ -205,7 +210,7 @@ fn create_many_shape(write_args: &WriteArgs, model: &Model) -> CreateManyShape { async fn update_one( tx: &mut dyn ConnectionLike, q: UpdateRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { let res = tx .update_record( @@ -213,7 +218,7 @@ async fn update_one( q.record_filter().clone(), q.args().clone(), q.selected_fields(), - trace_id, + traceparent, ) .await?; @@ -245,9 +250,9 @@ async fn update_one( async fn native_upsert( tx: &mut dyn ConnectionLike, query: NativeUpsert, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let scalars = tx.native_upsert_record(query.clone(), trace_id).await?; + let scalars = tx.native_upsert_record(query.clone(), traceparent).await?; Ok(RecordSelection { name: query.name().to_string(), @@ -263,7 +268,7 @@ async fn native_upsert( async fn delete_one( tx: &mut dyn ConnectionLike, q: DeleteRecord, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { // We need to ensure that we have a record finder, else we delete everything (conversion to empty filter). 
let filter = match q.record_filter { @@ -276,7 +281,7 @@ async fn delete_one( if let Some(selected_fields) = q.selected_fields { let record = tx - .delete_record(&q.model, filter, selected_fields.fields, trace_id) + .delete_record(&q.model, filter, selected_fields.fields, traceparent) .await?; let selection = RecordSelection { name: q.name, @@ -289,7 +294,7 @@ async fn delete_one( Ok(QueryResult::RecordSelection(Some(Box::new(selection)))) } else { - let result = tx.delete_records(&q.model, filter, trace_id).await?; + let result = tx.delete_records(&q.model, filter, traceparent).await?; Ok(QueryResult::Count(result)) } } @@ -297,9 +302,11 @@ async fn delete_one( async fn update_many( tx: &mut dyn ConnectionLike, q: UpdateManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.update_records(&q.model, q.record_filter, q.args, trace_id).await?; + let res = tx + .update_records(&q.model, q.record_filter, q.args, traceparent) + .await?; Ok(QueryResult::Count(res)) } @@ -307,9 +314,9 @@ async fn update_many( async fn delete_many( tx: &mut dyn ConnectionLike, q: DeleteManyRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { - let res = tx.delete_records(&q.model, q.record_filter, trace_id).await?; + let res = tx.delete_records(&q.model, q.record_filter, traceparent).await?; Ok(QueryResult::Count(res)) } @@ -317,13 +324,13 @@ async fn delete_many( async fn connect( tx: &mut dyn ConnectionLike, q: ConnectRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { tx.m2m_connect( &q.relation_field, &q.parent_id.expect("Expected parent record ID to be set for connect"), &q.child_ids, - trace_id, + traceparent, ) .await?; @@ -333,13 +340,13 @@ async fn connect( async fn disconnect( tx: &mut dyn ConnectionLike, q: DisconnectRecords, - trace_id: Option, + traceparent: Option, ) -> InterpretationResult { tx.m2m_disconnect( &q.relation_field, &q.parent_id.expect("Expected parent record ID to be set for disconnect"), &q.child_ids, - trace_id, + traceparent, ) .await?; diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index bf993d6bce1..7e1868cc017 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -10,13 +10,11 @@ pub mod query_document; pub mod query_graph_builder; pub mod relation_load_strategy; pub mod response_ir; -pub mod telemetry; -pub use self::telemetry::*; pub use self::{ - error::{CoreError, FieldConversionError}, + error::{CoreError, ExtendedUserFacingError, FieldConversionError}, executor::{QueryExecutor, TransactionOptions}, - interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, + interactive_transactions::{TransactionError, TxId}, query_document::*, }; @@ -28,6 +26,7 @@ pub use connector::{ mod error; mod interactive_transactions; mod interpreter; +mod metrics; mod query_ast; mod query_graph; mod result_ast; diff --git a/query-engine/core/src/metrics.rs b/query-engine/core/src/metrics.rs new file mode 100644 index 00000000000..736096634ad --- /dev/null +++ b/query-engine/core/src/metrics.rs @@ -0,0 +1,13 @@ +/// When the `metrics` feature is disabled, we don't compile the `prisma-metrics` crate and +/// thus can't use the metrics instrumentation. To avoid the boilerplate of putting every +/// `with_current_recorder` call behind `#[cfg]`, we use this stub trait that does nothing but +/// allows the code that relies on `WithMetricsInstrumentation` trait to be in scope compile. 
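+///
+/// With this stub in scope, call sites such as the `.with_current_subscriber().with_current_recorder()`
+/// chain in `ItxManager::new` keep compiling with the feature off; the stubbed call simply returns
+/// `self` unchanged.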
+#[cfg(not(feature = "metrics"))] +pub(crate) trait MetricsInstrumentationStub: Sized { + fn with_current_recorder(self) -> Self { + self + } +} + +#[cfg(not(feature = "metrics"))] +impl MetricsInstrumentationStub for T {} diff --git a/query-engine/core/src/telemetry/capturing/tx_ext.rs b/query-engine/core/src/telemetry/capturing/tx_ext.rs deleted file mode 100644 index 6b1b4905ab5..00000000000 --- a/query-engine/core/src/telemetry/capturing/tx_ext.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::collections::HashMap; - -pub trait TxTraceExt { - fn into_trace_id(self) -> opentelemetry::trace::TraceId; - fn into_trace_context(self) -> opentelemetry::Context; - fn as_traceparent(&self) -> String; -} - -impl TxTraceExt for crate::TxId { - // in order to convert a TxId (a 48 bytes cuid) into a TraceId (16 bytes), we remove the first byte, - // (always 'c') and get the next 16 bytes, which are random enough to be used as a trace id. - // this is a typical cuid: "c-lct0q6ma-0004-rb04-h6en1roa" - // - // - first letter is always the same - // - next 7-8 byte are random a timestamp. There's more entropy in the least significative bytes - // - next 4 bytes are a counter since the server started - // - next 4 bytes are a system fingerprint, invariant for the same server instance - // - least significative 8 bytes. Totally random. - // - // We want the most entropic slice of 16 bytes that's deterministicly determined - fn into_trace_id(self) -> opentelemetry::trace::TraceId { - let mut buffer = [0; 16]; - let str = self.to_string(); - let tx_id_bytes = str.as_bytes(); - let len = tx_id_bytes.len(); - - // bytes [len-20 to len-12): least significative 4 bytes of the timestamp + 4 bytes counter - for (i, source_idx) in (len - 20..len - 12).enumerate() { - buffer[i] = tx_id_bytes[source_idx]; - } - // bytes [len-8 to len): the random blocks - for (i, source_idx) in (len - 8..len).enumerate() { - buffer[i + 8] = tx_id_bytes[source_idx]; - } - - opentelemetry::trace::TraceId::from_bytes(buffer) - } - // This is a bit of a hack, but it's the only way to have a default trace span for a whole - // transaction when no traceparent is propagated from the client. - // - // This is done so we can capture traces happening accross the different queries in a - // transaction. Otherwise, if a traceparent is not propagated from the client, each query in - // the transaction will run within a span that has already been generated at the begining of the - // transaction, and held active in the actor in charge of running the queries. Thus, making - // impossible to capture traces happening in the individual queries, as they won't be aware of - // the transaction they are part of. - // - // By generating this "fake" traceparent based on the transaction id, we can have a common - // trace_id for all transaction operations. 
- fn into_trace_context(self) -> opentelemetry::Context { - let extractor: HashMap = - HashMap::from_iter(vec![("traceparent".to_string(), self.as_traceparent())]); - opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&extractor)) - } - - fn as_traceparent(&self) -> String { - let trace_id = self.clone().into_trace_id(); - format!("00-{trace_id}-0000000000000001-01") - } -} - -// tests for txid into traits -#[cfg(test)] -mod test { - use super::*; - use crate::TxId; - - #[test] - fn test_txid_into_traceid() { - let fixture = vec![ - ("clct0q6ma0000rb04768tiqbj", "71366d6130303030373638746971626a"), - // counter changed, trace id changed: - ("clct0q6ma0002rb04cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // fingerprint changed, trace id did not change, as that chunk is ignored: - ("clct0q6ma00020000cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // first 5 bytes changed, trace id did not change, as that chunk is ignored: - ("00000q6ma00020000cpa6zkmx", "71366d6130303032637061367a6b6d78"), - // 6 th byte changed, trace id changed, as that chunk is part of the lsb of the timestamp - ("0000006ma00020000cpa6zkmx", "30366d6130303032637061367a6b6d78"), - ]; - - for (txid, expected_trace_id) in fixture { - let txid: TxId = txid.into(); - let trace_id: opentelemetry::trace::TraceId = txid.into_trace_id(); - assert_eq!(trace_id.to_string(), expected_trace_id); - } - } -} diff --git a/query-engine/core/src/telemetry/helpers.rs b/query-engine/core/src/telemetry/helpers.rs deleted file mode 100644 index 30c63ed6693..00000000000 --- a/query-engine/core/src/telemetry/helpers.rs +++ /dev/null @@ -1,128 +0,0 @@ -use super::models::TraceSpan; -use once_cell::sync::Lazy; -use opentelemetry::sdk::export::trace::SpanData; -use opentelemetry::trace::{TraceContextExt, TraceId}; -use opentelemetry::Context; -use serde_json::{json, Value}; -use std::collections::HashMap; -use tracing::{Metadata, Span}; -use tracing_opentelemetry::OpenTelemetrySpanExt; -use tracing_subscriber::EnvFilter; - -pub static SHOW_ALL_TRACES: Lazy = Lazy::new(|| match std::env::var("PRISMA_SHOW_ALL_TRACES") { - Ok(enabled) => enabled.eq_ignore_ascii_case("true"), - Err(_) => false, -}); - -pub fn spans_to_json(spans: Vec) -> String { - let json_spans: Vec = spans.into_iter().map(|span| json!(TraceSpan::from(span))).collect(); - let span_result = json!({ - "span": true, - "spans": json_spans - }); - serde_json::to_string(&span_result).unwrap_or_default() -} - -// set the parent context and return the traceparent -pub fn set_parent_context_from_json_str(span: &Span, trace: &str) -> Option { - let trace: HashMap = serde_json::from_str(trace).unwrap_or_default(); - let trace_id = trace.get("traceparent").map(String::from); - let cx = opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)); - span.set_parent(cx); - trace_id -} - -pub fn set_span_link_from_traceparent(span: &Span, traceparent: Option) { - if let Some(traceparent) = traceparent { - let trace: HashMap = HashMap::from([("traceparent".to_string(), traceparent)]); - let cx = opentelemetry::global::get_text_map_propagator(|propagator| propagator.extract(&trace)); - let context_span = cx.span(); - span.add_link(context_span.span_context().clone()); - } -} - -pub fn get_trace_parent_from_span(span: &Span) -> String { - let cx = span.context(); - let binding = cx.span(); - let span_context = binding.span_context(); - - format!("00-{}-{}-01", span_context.trace_id(), span_context.span_id()) -} - -pub fn 
get_trace_id_from_span(span: &Span) -> TraceId { - let cx = span.context(); - get_trace_id_from_context(&cx) -} - -pub fn get_trace_id_from_context(context: &Context) -> TraceId { - let context_span = context.span(); - context_span.span_context().trace_id() -} - -pub fn get_trace_id_from_traceparent(traceparent: Option<&str>) -> TraceId { - traceparent - .unwrap_or("0-0-0-0") - .split('-') - .nth(1) - .map(|id| TraceId::from_hex(id).unwrap_or(TraceId::INVALID)) - .unwrap() -} - -pub enum QueryEngineLogLevel { - FromEnv, - Override(String), -} - -impl QueryEngineLogLevel { - fn level(self) -> Option { - match self { - Self::FromEnv => std::env::var("QE_LOG_LEVEL").ok(), - Self::Override(l) => Some(l), - } - } -} - -#[rustfmt::skip] -pub fn env_filter(log_queries: bool, qe_log_level: QueryEngineLogLevel) -> EnvFilter { - let mut filter = EnvFilter::from_default_env() - .add_directive("tide=error".parse().unwrap()) - .add_directive("tonic=error".parse().unwrap()) - .add_directive("h2=error".parse().unwrap()) - .add_directive("hyper=error".parse().unwrap()) - .add_directive("tower=error".parse().unwrap()); - - if let Some(level) = qe_log_level.level() { - filter = filter - .add_directive(format!("query_engine={}", &level).parse().unwrap()) - .add_directive(format!("query_core={}", &level).parse().unwrap()) - .add_directive(format!("query_connector={}", &level).parse().unwrap()) - .add_directive(format!("sql_query_connector={}", &level).parse().unwrap()) - .add_directive(format!("mongodb_query_connector={}", &level).parse().unwrap()); - } - - if log_queries { - filter = filter - .add_directive("quaint[{is_query}]=trace".parse().unwrap()) - .add_directive("mongodb_query_connector=debug".parse().unwrap()); - } - - filter -} - -pub fn user_facing_span_only_filter(meta: &Metadata<'_>) -> bool { - if !meta.is_span() { - return false; - } - - if *SHOW_ALL_TRACES { - return true; - } - - if meta.fields().iter().any(|f| f.name() == "user_facing") { - return true; - } - - // spans describing a quaint query. - // TODO: should this span be made user_facing in quaint? 
- meta.target() == "quaint::connector::metrics" && meta.name() == "quaint:query" -} diff --git a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz index f65efcc2c5a..9190b361820 100644 Binary files a/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz and b/query-engine/dmmf/src/tests/test-schemas/snapshots/odoo.snapshot.json.gz differ diff --git a/query-engine/driver-adapters/Cargo.toml b/query-engine/driver-adapters/Cargo.toml index 39a0314a476..606b33e9642 100644 --- a/query-engine/driver-adapters/Cargo.toml +++ b/query-engine/driver-adapters/Cargo.toml @@ -10,18 +10,17 @@ postgresql = ["quaint/postgresql"] [dependencies] async-trait.workspace = true +futures.workspace = true once_cell = "1.15" +prisma-metrics.path = "../../libs/metrics" serde.workspace = true serde_json.workspace = true tracing.workspace = true tracing-core = "0.1" -metrics = "0.18" uuid.workspace = true -pin-project = "1" +pin-project.workspace = true serde_repr.workspace = true -futures = "0.3" - [dev-dependencies] expect-test = "1" tokio = { version = "1", features = ["macros", "time", "sync"] } diff --git a/query-engine/driver-adapters/executor/src/recording.ts b/query-engine/driver-adapters/executor/src/recording.ts index 88b9d369bc2..5ac0f52b4cb 100644 --- a/query-engine/driver-adapters/executor/src/recording.ts +++ b/query-engine/driver-adapters/executor/src/recording.ts @@ -21,7 +21,7 @@ function recorder(adapter: DriverAdapter, recordings: Recordings) { return { provider: adapter.provider, adapterName: adapter.adapterName, - startTransaction: () => { + transactionContext: () => { throw new Error("Not implemented"); }, getConnectionInfo: () => { @@ -43,7 +43,7 @@ function replayer(adapter: DriverAdapter, recordings: Recordings) { provider: adapter.provider, adapterName: adapter.adapterName, recordings: recordings, - startTransaction: () => { + transactionContext: () => { throw new Error("Not implemented"); }, getConnectionInfo: () => { diff --git a/query-engine/driver-adapters/src/lib.rs b/query-engine/driver-adapters/src/lib.rs index 55c7de41eb8..137df06d731 100644 --- a/query-engine/driver-adapters/src/lib.rs +++ b/query-engine/driver-adapters/src/lib.rs @@ -117,14 +117,14 @@ mod arch { pub(crate) fn get_named_property(object: &::napi::JsObject, name: &str) -> JsResult where - T: ::napi::bindgen_prelude::FromNapiValue, + T: ::napi::bindgen_prelude::FromNapiValue + ::napi::bindgen_prelude::ValidateNapiValue, { object.get_named_property(name) } pub(crate) fn get_optional_named_property(object: &::napi::JsObject, name: &str) -> JsResult> where - T: ::napi::bindgen_prelude::FromNapiValue, + T: ::napi::bindgen_prelude::FromNapiValue + ::napi::bindgen_prelude::ValidateNapiValue, { if has_named_property(object, name)? 
{ Ok(Some(get_named_property(object, name)?)) diff --git a/query-engine/driver-adapters/src/napi/adapter_method.rs b/query-engine/driver-adapters/src/napi/adapter_method.rs index dd7399d86fa..658c2003e9e 100644 --- a/query-engine/driver-adapters/src/napi/adapter_method.rs +++ b/query-engine/driver-adapters/src/napi/adapter_method.rs @@ -79,3 +79,24 @@ where Self::from_threadsafe_function(threadsafe_fn, env) } } + +impl ValidateNapiValue for AdapterMethod +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ +} + +impl TypeName for AdapterMethod +where + ArgType: ToNapiValue + 'static, + ReturnType: FromNapiValue + 'static, +{ + fn type_name() -> &'static str { + "AdapterMethod" + } + + fn value_type() -> ValueType { + ValueType::Function + } +} diff --git a/query-engine/driver-adapters/src/napi/result.rs b/query-engine/driver-adapters/src/napi/result.rs index 529455bf9a0..466658df329 100644 --- a/query-engine/driver-adapters/src/napi/result.rs +++ b/query-engine/driver-adapters/src/napi/result.rs @@ -1,5 +1,8 @@ use crate::error::DriverAdapterError; -use napi::{bindgen_prelude::FromNapiValue, Env, JsUnknown, NapiValue}; +use napi::{ + bindgen_prelude::{FromNapiValue, TypeName, ValidateNapiValue}, + Env, JsUnknown, NapiValue, +}; impl FromNapiValue for DriverAdapterError { unsafe fn from_napi_value(napi_env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> napi::Result { @@ -9,6 +12,18 @@ impl FromNapiValue for DriverAdapterError { } } +impl ValidateNapiValue for DriverAdapterError {} + +impl TypeName for DriverAdapterError { + fn type_name() -> &'static str { + "DriverAdapterError" + } + + fn value_type() -> napi::ValueType { + napi::ValueType::Object + } +} + /// Wrapper for JS-side result type. /// This Napi-specific implementation has the same shape and API as the Wasm implementation, /// but it asks for a `FromNapiValue` bound on the generic type. diff --git a/query-engine/driver-adapters/src/proxy.rs b/query-engine/driver-adapters/src/proxy.rs index 8e1d39138cb..cf78a4cbb88 100644 --- a/query-engine/driver-adapters/src/proxy.rs +++ b/query-engine/driver-adapters/src/proxy.rs @@ -1,13 +1,13 @@ -use crate::send_future::UnsafeFuture; use crate::types::JsConnectionInfo; pub use crate::types::{JSResultSet, Query, TransactionOptions}; use crate::{ from_js_value, get_named_property, get_optional_named_property, to_rust_str, AdapterMethod, JsObject, JsResult, JsString, JsTransaction, }; +use crate::{send_future::UnsafeFuture, transaction::JsTransactionContext}; use futures::Future; -use metrics::increment_gauge; +use prisma_metrics::gauge; use std::sync::atomic::{AtomicBool, Ordering}; /// Proxy is a struct wrapping a javascript object that exhibits basic primitives for @@ -28,8 +28,19 @@ pub(crate) struct CommonProxy { /// This is a JS proxy for accessing the methods specific to top level /// JS driver objects pub(crate) struct DriverProxy { - start_transaction: AdapterMethod<(), JsTransaction>, + /// Retrieve driver-specific info, such as the maximum number of query parameters get_connection_info: Option>, + + /// Provide a transaction context, in which raw commands are guaranteed to be executed in + /// the same scope as a future transaction, which can be spawned by via + /// [`driver_adapters::transaction::JsTransactionContext::start_transaction`]. + /// This was first introduced for supporting Isolation Levels in PlanetScale. 
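+    /// On the TypeScript side this maps to the driver adapter's `transactionContext()` method,
+    /// which takes over transaction creation from the previous top-level `startTransaction()`.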
+ transaction_context: AdapterMethod<(), JsTransactionContext>, +} + +/// This is a JS proxy for accessing the methods specific to JS transaction contexts. +pub(crate) struct TransactionContextProxy { + start_transaction: AdapterMethod<(), JsTransaction>, } /// This a JS proxy for accessing the methods, specific @@ -48,6 +59,7 @@ pub(crate) struct TransactionProxy { closed: AtomicBool, } +// TypeScript: Queryable impl CommonProxy { pub fn new(object: &JsObject) -> JsResult { let provider: JsString = get_named_property(object, "provider")?; @@ -68,11 +80,12 @@ impl CommonProxy { } } +// TypeScript: DriverAdapter impl DriverProxy { pub fn new(object: &JsObject) -> JsResult { Ok(Self { - start_transaction: get_named_property(object, "startTransaction")?, get_connection_info: get_optional_named_property(object, "getConnectionInfo")?, + transaction_context: get_named_property(object, "transactionContext")?, }) } @@ -87,6 +100,20 @@ impl DriverProxy { .await } + pub async fn transaction_context(&self) -> quaint::Result { + let ctx = self.transaction_context.call_as_async(()).await?; + + Ok(ctx) + } +} + +impl TransactionContextProxy { + pub fn new(object: &JsObject) -> JsResult { + let start_transaction = get_named_property(object, "startTransaction")?; + + Ok(Self { start_transaction }) + } + async fn start_transaction_inner(&self) -> quaint::Result> { let tx = self.start_transaction.call_as_async(()).await?; @@ -94,11 +121,11 @@ impl DriverProxy { // Previously, it was done in JsTransaction::new, similar to the native Transaction. // However, correct Dispatcher is lost there and increment does not register, so we moved // it here instead. - increment_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").increment(1.0); Ok(Box::new(tx)) } - pub fn start_transaction(&self) -> UnsafeFuture>> + '_> { + pub fn start_transaction(&self) -> impl Future>> + '_ { UnsafeFuture(self.start_transaction_inner()) } } @@ -184,6 +211,8 @@ macro_rules! impl_send_sync_on_wasm { // Assume the proxy object will not be sent to service workers, we can unsafe impl Send + Sync. 
impl_send_sync_on_wasm!(TransactionProxy); +impl_send_sync_on_wasm!(JsTransaction); +impl_send_sync_on_wasm!(TransactionContextProxy); +impl_send_sync_on_wasm!(JsTransactionContext); impl_send_sync_on_wasm!(DriverProxy); impl_send_sync_on_wasm!(CommonProxy); -impl_send_sync_on_wasm!(JsTransaction); diff --git a/query-engine/driver-adapters/src/queryable.rs b/query-engine/driver-adapters/src/queryable.rs index 4e47e9c5163..8aa7579762f 100644 --- a/query-engine/driver-adapters/src/queryable.rs +++ b/query-engine/driver-adapters/src/queryable.rs @@ -30,12 +30,18 @@ use tracing::{info_span, Instrument}; pub(crate) struct JsBaseQueryable { pub(crate) proxy: CommonProxy, pub provider: AdapterFlavour, + pub(crate) db_system_name: &'static str, } impl JsBaseQueryable { pub(crate) fn new(proxy: CommonProxy) -> Self { let provider: AdapterFlavour = proxy.provider.parse().unwrap(); - Self { proxy, provider } + let db_system_name = provider.db_system_name(); + Self { + proxy, + provider, + db_system_name, + } } /// visit a quaint query AST according to the provider of the JS connector @@ -84,7 +90,7 @@ impl QuaintQueryable for JsBaseQueryable { } async fn query_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.query_raw", sql, params, move || async move { + metrics::query("js.query_raw", self.db_system_name, sql, params, move || async move { self.do_query_raw(sql, params).await }) .await @@ -104,7 +110,7 @@ impl QuaintQueryable for JsBaseQueryable { } async fn execute_raw(&self, sql: &str, params: &[quaint::Value<'_>]) -> quaint::Result { - metrics::query("js.execute_raw", sql, params, move || async move { + metrics::query("js.execute_raw", self.db_system_name, sql, params, move || async move { self.do_execute_raw(sql, params).await }) .await @@ -116,7 +122,7 @@ impl QuaintQueryable for JsBaseQueryable { async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { let params = &[]; - metrics::query("js.raw_cmd", cmd, params, move || async move { + metrics::query("js.raw_cmd", self.db_system_name, cmd, params, move || async move { self.do_execute_raw(cmd, params).await?; Ok(()) }) @@ -174,7 +180,7 @@ impl JsBaseQueryable { let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); let query = self.build_query(sql, params).instrument(serialization_span).await?; - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let sql_span = info_span!("js:query:sql", user_facing = true, "db.system" = %self.db_system_name, "db.statement" = %sql, "otel.kind" = "client"); let result_set = self.proxy.query_raw(query).instrument(sql_span).await?; let len = result_set.len(); @@ -196,7 +202,7 @@ impl JsBaseQueryable { let serialization_span = info_span!("js:query:args", user_facing = true, "length" = %len); let query = self.build_query(sql, params).instrument(serialization_span).await?; - let sql_span = info_span!("js:query:sql", user_facing = true, "db.statement" = %sql); + let sql_span = info_span!("js:query:sql", user_facing = true, "db.system" = %self.db_system_name, "db.statement" = %sql, "otel.kind" = "client"); let affected_rows = self.proxy.execute_raw(query).instrument(sql_span).await?; Ok(affected_rows as u64) @@ -301,25 +307,32 @@ impl QuaintQueryable for JsQueryable { } } -#[async_trait] -impl TransactionCapable for JsQueryable { - async fn start_transaction<'a>( +impl JsQueryable { + async fn start_transaction_inner<'a>( &'a self, isolation: Option, ) -> quaint::Result> { - let tx = 
self.driver_proxy.start_transaction().await?; + // 1. Obtain a transaction context from the driver. + // Any command run on this context is guaranteed to be part of the same session + // as the transaction spawned from it. + let tx_ctx = self.driver_proxy.transaction_context().await?; - let isolation_first = tx.requires_isolation_first(); + let requires_isolation_first = tx_ctx.requires_isolation_first(); - if isolation_first { + // 2. Set the isolation level (if specified) if the provider requires it to be set before + // creating the transaction. + if requires_isolation_first { if let Some(isolation) = isolation { - tx.set_tx_isolation_level(isolation).await?; + tx_ctx.set_tx_isolation_level(isolation).await?; } } - let begin_stmt = tx.begin_statement(); + // 3. Spawn a transaction from the context. + let tx = tx_ctx.start_transaction().await?; + let begin_stmt = tx.begin_statement(); let tx_opts = tx.options(); + if tx_opts.use_phantom_query { let begin_stmt = JsBaseQueryable::phantom_query_message(begin_stmt); tx.raw_phantom_cmd(begin_stmt.as_str()).await?; @@ -327,7 +340,8 @@ impl TransactionCapable for JsQueryable { tx.raw_cmd(begin_stmt).await?; } - if !isolation_first { + // 4. Set the isolation level (if specified) if we didn't do it before. + if !requires_isolation_first { if let Some(isolation) = isolation { tx.set_tx_isolation_level(isolation).await?; } @@ -339,6 +353,16 @@ impl TransactionCapable for JsQueryable { } } +#[async_trait] +impl TransactionCapable for JsQueryable { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> quaint::Result> { + UnsafeFuture(self.start_transaction_inner(isolation)).await + } +} + pub fn from_js(driver: JsObject) -> JsQueryable { let common = CommonProxy::new(&driver).unwrap(); let driver_proxy = DriverProxy::new(&driver).unwrap(); diff --git a/query-engine/driver-adapters/src/transaction.rs b/query-engine/driver-adapters/src/transaction.rs index b3dd6463089..3a1167159ae 100644 --- a/query-engine/driver-adapters/src/transaction.rs +++ b/query-engine/driver-adapters/src/transaction.rs @@ -1,15 +1,85 @@ +use std::future::Future; + use async_trait::async_trait; -use metrics::decrement_gauge; +use prisma_metrics::gauge; use quaint::{ connector::{DescribedQuery, IsolationLevel, Transaction as QuaintTransaction}, prelude::{Query as QuaintQuery, Queryable, ResultSet}, Value, }; -use crate::proxy::{TransactionOptions, TransactionProxy}; +use crate::proxy::{TransactionContextProxy, TransactionOptions, TransactionProxy}; use crate::{proxy::CommonProxy, queryable::JsBaseQueryable, send_future::UnsafeFuture}; use crate::{JsObject, JsResult}; +pub(crate) struct JsTransactionContext { + tx_ctx_proxy: TransactionContextProxy, + inner: JsBaseQueryable, +} + +// Wrapper around JS transaction context objects that implements Queryable. 
Can be used in place of quaint transaction, +// context, but delegates most operations to JS +impl JsTransactionContext { + pub(crate) fn new(inner: JsBaseQueryable, tx_ctx_proxy: TransactionContextProxy) -> Self { + Self { inner, tx_ctx_proxy } + } + + pub fn start_transaction(&self) -> impl Future>> + '_ { + UnsafeFuture(self.tx_ctx_proxy.start_transaction()) + } +} + +#[async_trait] +impl Queryable for JsTransactionContext { + async fn query(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.query(q).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw(sql, params).await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.query_raw_typed(sql, params).await + } + + async fn execute(&self, q: QuaintQuery<'_>) -> quaint::Result { + self.inner.execute(q).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw(sql, params).await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> quaint::Result { + self.inner.execute_raw_typed(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> quaint::Result<()> { + self.inner.raw_cmd(cmd).await + } + + async fn version(&self) -> quaint::Result> { + self.inner.version().await + } + + async fn describe_query(&self, sql: &str) -> quaint::Result { + self.inner.describe_query(sql).await + } + + fn is_healthy(&self) -> bool { + self.inner.is_healthy() + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> quaint::Result<()> { + self.inner.set_tx_isolation_level(isolation_level).await + } + + fn requires_isolation_first(&self) -> bool { + self.inner.requires_isolation_first() + } +} + // Wrapper around JS transaction objects that implements Queryable // and quaint::Transaction. 
Can be used in place of quaint transaction, // but delegates most operations to JS @@ -29,7 +99,14 @@ impl JsTransaction { pub async fn raw_phantom_cmd(&self, cmd: &str) -> quaint::Result<()> { let params = &[]; - quaint::connector::metrics::query("js.raw_phantom_cmd", cmd, params, move || async move { Ok(()) }).await + quaint::connector::metrics::query( + "js.raw_phantom_cmd", + self.inner.db_system_name, + cmd, + params, + move || async move { Ok(()) }, + ) + .await } } @@ -37,7 +114,7 @@ impl JsTransaction { impl QuaintTransaction for JsTransaction { async fn commit(&self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").decrement(1.0); let commit_stmt = "COMMIT"; @@ -53,7 +130,7 @@ impl QuaintTransaction for JsTransaction { async fn rollback(&self) -> quaint::Result<()> { // increment of this gauge is done in DriverProxy::startTransaction - decrement_gauge!("prisma_client_queries_active", 1.0); + gauge!("prisma_client_queries_active").decrement(1.0); let rollback_stmt = "ROLLBACK"; @@ -149,3 +226,30 @@ impl ::napi::bindgen_prelude::FromNapiValue for JsTransaction { Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_proxy)) } } + +#[cfg(target_arch = "wasm32")] +impl super::wasm::FromJsValue for JsTransactionContext { + fn from_js_value(value: wasm_bindgen::prelude::JsValue) -> JsResult { + use wasm_bindgen::JsCast; + + let object = value.dyn_into::()?; + let common_proxy = CommonProxy::new(&object)?; + let base = JsBaseQueryable::new(common_proxy); + let tx_ctx_proxy = TransactionContextProxy::new(&object)?; + + Ok(Self::new(base, tx_ctx_proxy)) + } +} + +/// Implementing unsafe `from_napi_value` allows retrieving a threadsafe `JsTransactionContext` in `DriverProxy` +/// while keeping derived futures `Send`. 
+#[cfg(not(target_arch = "wasm32"))] +impl ::napi::bindgen_prelude::FromNapiValue for JsTransactionContext { + unsafe fn from_napi_value(env: napi::sys::napi_env, napi_val: napi::sys::napi_value) -> JsResult { + let object = JsObject::from_napi_value(env, napi_val)?; + let common_proxy = CommonProxy::new(&object)?; + let tx_ctx_proxy = TransactionContextProxy::new(&object)?; + + Ok(Self::new(JsBaseQueryable::new(common_proxy), tx_ctx_proxy)) + } +} diff --git a/query-engine/driver-adapters/src/types.rs b/query-engine/driver-adapters/src/types.rs index 9e5f1eae149..03f9c5d6325 100644 --- a/query-engine/driver-adapters/src/types.rs +++ b/query-engine/driver-adapters/src/types.rs @@ -25,6 +25,19 @@ pub enum AdapterFlavour { Sqlite, } +impl AdapterFlavour { + pub fn db_system_name(&self) -> &'static str { + match self { + #[cfg(feature = "mysql")] + Self::Mysql => "mysql", + #[cfg(feature = "postgresql")] + Self::Postgres => "postgresql", + #[cfg(feature = "sqlite")] + Self::Sqlite => "sqlite", + } + } +} + impl FromStr for AdapterFlavour { type Err = String; diff --git a/query-engine/metrics/src/recorder.rs b/query-engine/metrics/src/recorder.rs deleted file mode 100644 index 94d2c050f60..00000000000 --- a/query-engine/metrics/src/recorder.rs +++ /dev/null @@ -1,122 +0,0 @@ -use std::sync::Arc; - -use metrics::KeyName; -use metrics::{Counter, CounterFn, Gauge, GaugeFn, Histogram, HistogramFn, Key, Recorder, Unit}; -use tracing::trace; - -use super::common::KeyLabels; -use super::{METRIC_COUNTER, METRIC_DESCRIPTION, METRIC_GAUGE, METRIC_HISTOGRAM, METRIC_TARGET}; - -#[derive(Default)] -pub(crate) struct MetricRecorder; - -impl MetricRecorder { - fn register_description(&self, name: &str, description: &str) { - trace!( - target: METRIC_TARGET, - name = name, - metric_type = METRIC_DESCRIPTION, - description = description - ); - } -} - -impl Recorder for MetricRecorder { - fn describe_counter(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn describe_gauge(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn describe_histogram(&self, key_name: KeyName, _unit: Option, description: &'static str) { - self.register_description(key_name.as_str(), description); - } - - fn register_counter(&self, key: &Key) -> Counter { - Counter::from_arc(Arc::new(MetricHandle(key.clone()))) - } - - fn register_gauge(&self, key: &Key) -> Gauge { - Gauge::from_arc(Arc::new(MetricHandle(key.clone()))) - } - - fn register_histogram(&self, key: &Key) -> Histogram { - Histogram::from_arc(Arc::new(MetricHandle(key.clone()))) - } -} - -pub(crate) struct MetricHandle(Key); - -impl CounterFn for MetricHandle { - fn increment(&self, value: u64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_COUNTER, - increment = value, - ); - } - - fn absolute(&self, value: u64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_COUNTER, - absolute = value, - ); - } -} - -impl GaugeFn for MetricHandle { - fn increment(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = 
serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_inc = value, - ); - } - - fn decrement(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_dec = value, - ); - } - - fn set(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_GAUGE, - gauge_set = value, - ); - } -} - -impl HistogramFn for MetricHandle { - fn record(&self, value: f64) { - let keylabels: KeyLabels = self.0.clone().into(); - let json_string = serde_json::to_string(&keylabels).unwrap(); - trace!( - target: METRIC_TARGET, - key_labels = json_string.as_str(), - metric_type = METRIC_HISTOGRAM, - hist_record = value, - ); - } -} diff --git a/query-engine/query-engine-c-abi/Cargo.toml b/query-engine/query-engine-c-abi/Cargo.toml index d130d33be68..5542ba38a43 100644 --- a/query-engine/query-engine-c-abi/Cargo.toml +++ b/query-engine/query-engine-c-abi/Cargo.toml @@ -17,6 +17,7 @@ request-handlers = { path = "../request-handlers", features = [ ] } query-connector = { path = "../connectors/query-connector" } query-engine-common = { path = "../../libs/query-engine-common" } +telemetry = { path = "../../libs/telemetry" } user-facing-errors = { path = "../../libs/user-facing-errors" } psl = { workspace = true, features = ["sqlite"] } sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector" } @@ -36,13 +37,14 @@ indoc.workspace = true tracing = "0.1" tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } tokio.workspace = true -futures = "0.3" +futures.workspace = true once_cell = "1.19.0" [build-dependencies] cbindgen = "0.24.0" +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine-c-abi/build.rs b/query-engine/query-engine-c-abi/build.rs index 0739d31bf25..b8f3fcdbaff 100644 --- a/query-engine/query-engine-c-abi/build.rs +++ b/query-engine/query-engine-c-abi/build.rs @@ -1,13 +1,4 @@ -extern crate cbindgen; - use std::env; -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} fn generate_c_headers() { let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); @@ -28,6 +19,6 @@ fn main() { // Tell Cargo that if the given file changes, to rerun this build script. 
println!("cargo:rerun-if-changed=src/engine.rs"); // println!("✅ Running build.rs"); - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); generate_c_headers(); } diff --git a/query-engine/query-engine-c-abi/src/engine.rs b/query-engine/query-engine-c-abi/src/engine.rs index 4d3d8163675..e0dff1ca1da 100644 --- a/query-engine/query-engine-c-abi/src/engine.rs +++ b/query-engine/query-engine-c-abi/src/engine.rs @@ -9,7 +9,7 @@ use once_cell::sync::Lazy; use query_core::{ protocol::EngineProtocol, schema::{self}, - telemetry, TransactionOptions, TxId, + TransactionOptions, TxId, }; use request_handlers::{load_executor, RequestBody, RequestHandler}; use serde_json::json; @@ -20,11 +20,13 @@ use std::{ ptr::null_mut, sync::Arc, }; +use telemetry::helpers::TraceParent; use tokio::{ runtime::{self, Runtime}, sync::RwLock, }; -use tracing::{field, instrument::WithSubscriber, level_filters::LevelFilter, Instrument}; +use tracing::{instrument::WithSubscriber, level_filters::LevelFilter, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use query_engine_common::Result; use query_engine_common::{ @@ -201,8 +203,9 @@ impl QueryEngine { let trace_string = get_cstr_safe(trace).expect("Connect trace is missing"); - let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace_string); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace_string); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -238,7 +241,7 @@ impl QueryEngine { let conn_span = tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); connector.get_connection().instrument(conn_span).await?; @@ -293,12 +296,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; - let span = tracing::info_span!("prisma:engine", user_facing = true); - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); async move { let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); - let response = handler.handle(query, tx_id.map(TxId::from), trace_id).await; + let response = handler.handle(query, tx_id.map(TxId::from), traceparent).await; let serde_span = tracing::info_span!("prisma:engine:response_json_serialization", user_facing = true); Ok(serde_span.in_scope(|| serde_json::to_string(&response))?) 
@@ -315,8 +320,9 @@ impl QueryEngine { let trace = get_cstr_safe(trace_str).expect("Trace is needed"); let dispatcher = self.logger.dispatcher(); async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async { let mut inner = self.inner.write().await; @@ -393,8 +399,9 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine @@ -412,15 +419,20 @@ impl QueryEngine { } // If connected, attempts to commit a transaction with id `tx_id` in the core. - pub async fn commit_transaction(&self, tx_id_str: *const c_char, _trace: *const c_char) -> Result { + pub async fn commit_transaction(&self, tx_id_str: *const c_char, trace: *const c_char) -> Result { let tx_id = get_cstr_safe(tx_id_str).expect("Input string missing"); + let trace = get_cstr_safe(trace).expect("trace is required in transactions"); let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); async move { - match engine.executor().commit_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().commit_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } @@ -430,15 +442,19 @@ impl QueryEngine { } // If connected, attempts to roll back a transaction with id `tx_id` in the core. 
- pub async fn rollback_transaction(&self, tx_id_str: *const c_char, _trace: *const c_char) -> Result { + pub async fn rollback_transaction(&self, tx_id_str: *const c_char, trace: *const c_char) -> Result { let tx_id = get_cstr_safe(tx_id_str).expect("Input string missing"); - // let trace = get_cstr_safe(trace_str).expect("trace is required in transactions"); + let trace = get_cstr_safe(trace).expect("trace is required in transactions"); let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + match engine.executor().rollback_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), diff --git a/query-engine/query-engine-c-abi/src/logger.rs b/query-engine/query-engine-c-abi/src/logger.rs index 3585b94e14a..fbb69c38e67 100644 --- a/query-engine/query-engine-c-abi/src/logger.rs +++ b/query-engine/query-engine-c-abi/src/logger.rs @@ -1,7 +1,6 @@ use core::fmt; -use query_core::telemetry; + use query_engine_common::logger::StringCallback; -// use query_engine_metrics::MetricRegistry; use serde_json::Value; use std::sync::Arc; use std::{collections::BTreeMap, fmt::Display}; @@ -20,7 +19,6 @@ pub(crate) type LogCallback = Box; pub(crate) struct Logger { dispatcher: Dispatch, - // metrics: Option, } impl Logger { @@ -58,26 +56,14 @@ impl Logger { let layer = CallbackLayer::new(log_callback).with_filter(filters); - // let metrics = if enable_metrics { - // query_engine_metrics::setup(); - // Some(MetricRegistry::new()) - // } else { - // None - // }; - Self { dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer)), - // metrics, } } pub fn dispatcher(&self) -> Dispatch { self.dispatcher.clone() } - - // pub fn metrics(&self) -> Option { - // self.metrics.clone() - // } } pub struct JsonVisitor<'a> { diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index cbe4f455b58..b4ec9eb5f36 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -24,6 +24,7 @@ request-handlers = { path = "../request-handlers", features = ["all"] } query-connector = { path = "../connectors/query-connector" } query-engine-common = { path = "../../libs/query-engine-common" } user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry" } psl = { workspace = true, features = ["all"] } sql-connector = { path = "../connectors/sql-query-connector", package = "sql-query-connector", features = [ "all-native", @@ -45,14 +46,15 @@ serde.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } quaint.workspace = true tokio.workspace = true -futures = "0.3" -query-engine-metrics = { path = "../metrics" } +futures.workspace = true +prisma-metrics.path = "../../libs/metrics" [build-dependencies] napi-build = "1" +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine-node-api/build.rs b/query-engine/query-engine-node-api/build.rs index 2ed42a66137..eb0c9b2fe74 100644 --- a/query-engine/query-engine-node-api/build.rs +++ 
b/query-engine/query-engine-node-api/build.rs @@ -1,14 +1,4 @@ -extern crate napi_build; - -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); napi_build::setup() } diff --git a/query-engine/query-engine-node-api/src/engine.rs b/query-engine/query-engine-node-api/src/engine.rs index 01c78b6e2c1..7e9515e4d22 100644 --- a/query-engine/query-engine-node-api/src/engine.rs +++ b/query-engine/query-engine-node-api/src/engine.rs @@ -2,20 +2,22 @@ use crate::{error::ApiError, logger::Logger}; use futures::FutureExt; use napi::{threadsafe_function::ThreadSafeCallContext, Env, JsFunction, JsObject, JsUnknown}; use napi_derive::napi; +use prisma_metrics::{MetricFormat, WithMetricsInstrumentation}; use psl::PreviewFeature; use quaint::connector::ExternalConnector; -use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, telemetry, TransactionOptions, TxId}; +use query_core::{protocol::EngineProtocol, relation_load_strategy, schema, TransactionOptions, TxId}; use query_engine_common::engine::{ map_known_error, stringify_env_values, ConnectedEngine, ConnectedEngineNative, ConstructorOptions, ConstructorOptionsNative, EngineBuilder, EngineBuilderNative, Inner, }; -use query_engine_metrics::MetricFormat; use request_handlers::{load_executor, render_graphql_schema, ConnectorKind, RequestBody, RequestHandler}; use serde::Deserialize; use serde_json::json; use std::{collections::HashMap, future::Future, marker::PhantomData, panic::AssertUnwindSafe, sync::Arc}; +use telemetry::helpers::TraceParent; use tokio::sync::RwLock; -use tracing::{field, instrument::WithSubscriber, Instrument, Span}; +use tracing::{instrument::WithSubscriber, Instrument}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::filter::LevelFilter; use user_facing_errors::Error; @@ -71,12 +73,12 @@ impl QueryEngine { native, } = napi_env.from_js_value(options).expect( r###" - Failed to deserialize constructor options. - - This usually happens when the javascript object passed to the constructor is missing + Failed to deserialize constructor options. + + This usually happens when the javascript object passed to the constructor is missing properties for the ConstructorOptions fields that must have some value. 
- - If you set some of these in javascript trough environment variables, make sure there are + + If you set some of these in javascript through environment variables, make sure there are values for data_model, log_level, and any field that is not Option "###, ); @@ -149,21 +151,6 @@ impl QueryEngine { let log_level = log_level.parse::().unwrap(); let logger = Logger::new(log_queries, log_level, log_callback, enable_metrics, enable_tracing); - // Describe metrics adds all the descriptions and default values for our metrics - // this needs to run once our metrics pipeline has been configured and it needs to - // use the correct logging subscriber(our dispatch) so that the metrics recorder recieves - // it - if enable_metrics { - napi_env.execute_tokio_future( - async { - query_engine_metrics::initialize_metrics(); - Ok(()) - } - .with_subscriber(logger.dispatcher()), - |&mut _env, _data| Ok(()), - )?; - } - Ok(Self { connector_mode, inner: RwLock::new(Inner::Builder(builder)), @@ -175,10 +162,12 @@ impl QueryEngine { #[napi] pub async fn connect(&self, trace: String) -> napi::Result<()> { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -224,7 +213,7 @@ impl QueryEngine { let conn_span = tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = connector.get_connection().instrument(conn_span).await?; @@ -268,6 +257,7 @@ impl QueryEngine { Ok(()) }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await?; Ok(()) @@ -277,10 +267,12 @@ impl QueryEngine { #[napi] pub async fn disconnect(&self, trace: String) -> napi::Result<()> { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); // TODO: when using Node Drivers, we need to call Driver::close() here. 
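Every async entry point of the Node-API engine now follows the same instrumentation shape: the logger hands out an optional `MetricRecorder` alongside its tracing `Dispatch`, and the method's future is wrapped with both. A condensed sketch of that pattern, using only the types imported in this diff (engine internals elided):

```rust
use prisma_metrics::{MetricRecorder, WithMetricsInstrumentation};
use tracing::{instrument::WithSubscriber, Dispatch};

// Sketch: run an engine future under the logger's dispatcher and, when metrics
// are enabled, under its metrics recorder as well. This mirrors the
// `.with_subscriber(..)` / `.with_optional_recorder(..)` calls added above.
async fn run_instrumented<F, T>(work: F, dispatcher: Dispatch, recorder: Option<MetricRecorder>) -> T
where
    F: std::future::Future<Output = T>,
{
    work.with_subscriber(dispatcher)
        .with_optional_recorder(recorder)
        .await
}
```

Because the recorder is optional, metrics remain a no-op unless `enable_metrics` was passed to the constructor, while tracing spans keep working either way.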
@@ -305,6 +297,7 @@ impl QueryEngine { .await }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } @@ -312,6 +305,7 @@ impl QueryEngine { #[napi] pub async fn query(&self, body: String, trace: String, tx_id: Option) -> napi::Result { let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async_panic_to_js_error(async { let inner = self.inner.read().await; @@ -319,17 +313,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; - let span = if tx_id.is_none() { - tracing::info_span!("prisma:engine", user_facing = true) - } else { - Span::none() - }; - - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); async move { let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); - let response = handler.handle(query, tx_id.map(TxId::from), trace_id).await; + let response = handler.handle(query, tx_id.map(TxId::from), traceparent).await; let serde_span = tracing::info_span!("prisma:engine:response_json_serialization", user_facing = true); Ok(serde_span.in_scope(|| serde_json::to_string(&response))?) @@ -338,55 +329,65 @@ impl QueryEngine { .await }) .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } /// If connected, attempts to start a transaction in the core and returns its ID. #[napi] pub async fn start_transaction(&self, input: String, trace: String) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; + let tx_opts: TransactionOptions = serde_json::from_str(&input)?; - let dispatcher = self.logger.dispatcher(); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - telemetry::helpers::set_parent_context_from_json_str(&span, &trace); - - let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine .executor() .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) - .instrument(span) .await { Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), Err(err) => Ok(map_known_error(err)?), } } - .with_subscriber(dispatcher) + .instrument(span) .await }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } /// If connected, attempts to commit a transaction with id `tx_id` in the core. 
#[napi] - pub async fn commit_transaction(&self, tx_id: String, _trace: String) -> napi::Result { + pub async fn commit_transaction(&self, tx_id: String, trace: String) -> napi::Result { async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async move { - match engine.executor().commit_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().commit_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await }) .await @@ -394,20 +395,26 @@ impl QueryEngine { /// If connected, attempts to roll back a transaction with id `tx_id` in the core. #[napi] - pub async fn rollback_transaction(&self, tx_id: String, _trace: String) -> napi::Result { + pub async fn rollback_transaction(&self, tx_id: String, trace: String) -> napi::Result { async_panic_to_js_error(async { let inner = self.inner.read().await; let engine = inner.as_engine()?; let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); async move { - match engine.executor().rollback_tx(TxId::from(tx_id)).await { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); + + match engine.executor().rollback_tx(TxId::from(tx_id)).instrument(span).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), } } .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await }) .await @@ -416,17 +423,25 @@ impl QueryEngine { /// Loads the query schema. Only available when connected. 
#[napi] pub async fn sdl_schema(&self) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async move { let inner = self.inner.read().await; let engine = inner.as_engine()?; Ok(render_graphql_schema(engine.query_schema())) }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } #[napi] pub async fn metrics(&self, json_options: String) -> napi::Result { + let dispatcher = self.logger.dispatcher(); + let recorder = self.logger.recorder(); + async_panic_to_js_error(async move { let inner = self.inner.read().await; let engine = inner.as_engine()?; @@ -447,6 +462,8 @@ impl QueryEngine { .into()) } }) + .with_subscriber(dispatcher) + .with_optional_recorder(recorder) .await } } diff --git a/query-engine/query-engine-node-api/src/logger.rs b/query-engine/query-engine-node-api/src/logger.rs index b86343bb4a9..bd0fd6dd8b3 100644 --- a/query-engine/query-engine-node-api/src/logger.rs +++ b/query-engine/query-engine-node-api/src/logger.rs @@ -1,8 +1,7 @@ use core::fmt; use napi::threadsafe_function::{ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode}; -use query_core::telemetry; +use prisma_metrics::{MetricRecorder, MetricRegistry}; use query_engine_common::logger::StringCallback; -use query_engine_metrics::MetricRegistry; use serde_json::Value; use std::{collections::BTreeMap, fmt::Display}; use tracing::{ @@ -21,6 +20,7 @@ pub(crate) type LogCallback = ThreadsafeFunction; pub(crate) struct Logger { dispatcher: Dispatch, metrics: Option, + recorder: Option, } impl Logger { @@ -63,16 +63,18 @@ impl Logger { let layer = log_callback.with_filter(filters); - let metrics = if enable_metrics { - query_engine_metrics::setup(); - Some(MetricRegistry::new()) + let (metrics, recorder) = if enable_metrics { + let registry = MetricRegistry::new(); + let recorder = MetricRecorder::new(registry.clone()).with_initialized_prisma_metrics(); + (Some(registry), Some(recorder)) } else { - None + (None, None) }; Self { - dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer).with(metrics.clone())), + dispatcher: Dispatch::new(Registry::default().with(telemetry).with(layer)), metrics, + recorder, } } @@ -83,6 +85,10 @@ impl Logger { pub fn metrics(&self) -> Option { self.metrics.clone() } + + pub fn recorder(&self) -> Option { + self.recorder.clone() + } } pub struct JsonVisitor<'a> { diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index 21ba05e0fe8..40017c0270b 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -44,6 +44,7 @@ request-handlers = { path = "../request-handlers", default-features = false, fea ] } query-core = { path = "../core" } driver-adapters = { path = "../driver-adapters" } +telemetry = { path = "../../libs/telemetry" } quaint.workspace = true connection-string.workspace = true js-sys.workspace = true @@ -58,14 +59,17 @@ thiserror = "1" url.workspace = true serde.workspace = true tokio = { version = "1", features = ["macros", "sync", "io-util", "time"] } -futures = "0.3" +futures.workspace = true tracing.workspace = true tracing-subscriber = { version = "0.3" } -tracing-futures = "0.2" +tracing-futures.workspace = true tracing-opentelemetry = "0.17.3" opentelemetry = { version = "0.17" } +[build-dependencies] +build-utils.path = "../../libs/build-utils" + [package.metadata.wasm-pack.profile.release] wasm-opt = false # use wasm-opt explicitly in `./build.sh` diff --git 
a/query-engine/query-engine-wasm/build.rs b/query-engine/query-engine-wasm/build.rs index 2e8fe20c050..33aded23a4a 100644 --- a/query-engine/query-engine-wasm/build.rs +++ b/query-engine/query-engine-wasm/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/query-engine/query-engine-wasm/pnpm-lock.yaml b/query-engine/query-engine-wasm/pnpm-lock.yaml deleted file mode 100644 index 89591aef986..00000000000 --- a/query-engine/query-engine-wasm/pnpm-lock.yaml +++ /dev/null @@ -1,130 +0,0 @@ -lockfileVersion: '6.0' - -settings: - autoInstallPeers: true - excludeLinksFromLockfile: false - -dependencies: - '@neondatabase/serverless': - specifier: 0.6.0 - version: 0.6.0 - '@prisma/adapter-neon': - specifier: 5.6.0 - version: 5.6.0(@neondatabase/serverless@0.6.0) - '@prisma/driver-adapter-utils': - specifier: 5.6.0 - version: 5.6.0 - -packages: - - /@neondatabase/serverless@0.6.0: - resolution: {integrity: sha512-qXxBRYN0m2v8kVQBfMxbzNGn2xFAhTXFibzQlE++NfJ56Shz3m7+MyBBtXDlEH+3Wfa6lToDXf1MElocY4sJ3w==} - dependencies: - '@types/pg': 8.6.6 - dev: false - - /@prisma/adapter-neon@5.6.0(@neondatabase/serverless@0.6.0): - resolution: {integrity: sha512-IUkIE5NKyP2wCXMMAByM78fizfaJl7YeWDEajvyqQafXgRwmxl+2HhxsevvHly8jT4RlELdhjK6IP1eciGvXVA==} - peerDependencies: - '@neondatabase/serverless': ^0.6.0 - dependencies: - '@neondatabase/serverless': 0.6.0 - '@prisma/driver-adapter-utils': 5.6.0 - postgres-array: 3.0.2 - transitivePeerDependencies: - - supports-color - dev: false - - /@prisma/driver-adapter-utils@5.6.0: - resolution: {integrity: sha512-/TSrfCGLAQghNf+bwg5/e8iKAgecCYU/gMN0IyNra3183/VTQJneLFgbacuSK9bBXiIRUmpbuUIrJ6dhENzfjA==} - dependencies: - debug: 4.3.4 - transitivePeerDependencies: - - supports-color - dev: false - - /@types/node@20.9.1: - resolution: {integrity: sha512-HhmzZh5LSJNS5O8jQKpJ/3ZcrrlG6L70hpGqMIAoM9YVD0YBRNWYsfwcXq8VnSjlNpCpgLzMXdiPo+dxcvSmiA==} - dependencies: - undici-types: 5.26.5 - dev: false - - /@types/pg@8.6.6: - resolution: {integrity: sha512-O2xNmXebtwVekJDD+02udOncjVcMZQuTEQEMpKJ0ZRf5E7/9JJX3izhKUcUifBkyKpljyUM6BTgy2trmviKlpw==} - dependencies: - '@types/node': 20.9.1 - pg-protocol: 1.6.0 - pg-types: 2.2.0 - dev: false - - /debug@4.3.4: - resolution: {integrity: sha512-PRWFHuSU3eDtQJPvnNY7Jcket1j0t5OuOsFzPPzsekD52Zl8qUfFIPEiswXqIvHWGVHOgX+7G/vCNNhehwxfkQ==} - engines: {node: '>=6.0'} - peerDependencies: - supports-color: '*' - peerDependenciesMeta: - supports-color: - optional: true - dependencies: - ms: 2.1.2 - dev: false - - /ms@2.1.2: - resolution: {integrity: sha512-sGkPx+VjMtmA6MX27oA4FBFELFCZZ4S4XqeGOXCv68tT+jb3vk/RyaKWP0PTKyWtmLSM0b+adUTEvbs1PEaH2w==} - dev: false - - /pg-int8@1.0.1: - resolution: {integrity: sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==} - engines: {node: '>=4.0.0'} - dev: false - - /pg-protocol@1.6.0: - resolution: {integrity: sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==} - dev: false - - /pg-types@2.2.0: - resolution: {integrity: sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==} - engines: {node: '>=4'} - dependencies: - pg-int8: 1.0.1 - postgres-array: 2.0.0 - 
postgres-bytea: 1.0.0 - postgres-date: 1.0.7 - postgres-interval: 1.2.0 - dev: false - - /postgres-array@2.0.0: - resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} - engines: {node: '>=4'} - dev: false - - /postgres-array@3.0.2: - resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} - engines: {node: '>=12'} - dev: false - - /postgres-bytea@1.0.0: - resolution: {integrity: sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==} - engines: {node: '>=0.10.0'} - dev: false - - /postgres-date@1.0.7: - resolution: {integrity: sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==} - engines: {node: '>=0.10.0'} - dev: false - - /postgres-interval@1.2.0: - resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} - engines: {node: '>=0.10.0'} - dependencies: - xtend: 4.0.2 - dev: false - - /undici-types@5.26.5: - resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} - dev: false - - /xtend@4.0.2: - resolution: {integrity: sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ==} - engines: {node: '>=0.4'} - dev: false diff --git a/query-engine/query-engine-wasm/rust-toolchain.toml b/query-engine/query-engine-wasm/rust-toolchain.toml index 5048fd2e74a..44e38c0b870 100644 --- a/query-engine/query-engine-wasm/rust-toolchain.toml +++ b/query-engine/query-engine-wasm/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-05-25" +channel = "nightly-2024-09-01" components = ["clippy", "rustfmt", "rust-src"] targets = [ "wasm32-unknown-unknown", diff --git a/query-engine/query-engine-wasm/src/wasm/engine.rs b/query-engine/query-engine-wasm/src/wasm/engine.rs index ae6fe40f872..837160e1bb0 100644 --- a/query-engine/query-engine-wasm/src/wasm/engine.rs +++ b/query-engine/query-engine-wasm/src/wasm/engine.rs @@ -13,15 +13,17 @@ use query_core::{ protocol::EngineProtocol, relation_load_strategy, schema::{self}, - telemetry, TransactionOptions, TxId, + TransactionOptions, TxId, }; use query_engine_common::engine::{map_known_error, ConnectedEngine, ConstructorOptions, EngineBuilder, Inner}; use request_handlers::ConnectorKind; use request_handlers::{load_executor, RequestBody, RequestHandler}; use serde_json::json; use std::{marker::PhantomData, sync::Arc}; +use telemetry::helpers::TraceParent; use tokio::sync::RwLock; -use tracing::{field, instrument::WithSubscriber, Instrument, Level, Span}; +use tracing::{instrument::WithSubscriber, Instrument, Level}; +use tracing_opentelemetry::OpenTelemetrySpanExt; use tracing_subscriber::filter::LevelFilter; use wasm_bindgen::prelude::wasm_bindgen; @@ -89,7 +91,8 @@ impl QueryEngine { async { let span = tracing::info_span!("prisma:engine:connect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); let mut inner = self.inner.write().await; let builder = inner.as_builder()?; @@ -111,7 +114,7 @@ impl QueryEngine { let conn_span = tracing::info_span!( "prisma:engine:connection", user_facing = true, - "db.type" = connector.name(), + "db.system" = connector.name(), ); let conn = 
connector.get_connection().instrument(conn_span).await?; @@ -149,8 +152,9 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async { - let span = tracing::info_span!("prisma:engine:disconnect"); - let _ = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:disconnect", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + span.set_parent(parent_context); async { let mut inner = self.inner.write().await; @@ -189,17 +193,14 @@ impl QueryEngine { let query = RequestBody::try_from_str(&body, engine.engine_protocol())?; async move { - let span = if tx_id.is_none() { - tracing::info_span!("prisma:engine", user_facing = true) - } else { - Span::none() - }; - - let trace_id = telemetry::helpers::set_parent_context_from_json_str(&span, &trace); + let span = tracing::info_span!("prisma:engine:query", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); let handler = RequestHandler::new(engine.executor(), engine.query_schema(), engine.engine_protocol()); let response = handler - .handle(query, tx_id.map(TxId::from), trace_id) + .handle(query, tx_id.map(TxId::from), traceparent) .instrument(span) .await; @@ -219,13 +220,15 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { - let span = tracing::info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); + let span = tracing::info_span!("prisma:engine:start_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); let tx_opts: TransactionOptions = serde_json::from_str(&input)?; match engine .executor() .start_tx(engine.query_schema().clone(), engine.engine_protocol(), tx_opts) - .instrument(span) .await { Ok(tx_id) => Ok(json!({ "id": tx_id.to_string() }).to_string()), @@ -245,6 +248,11 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:commit_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + match engine.executor().commit_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), @@ -263,6 +271,11 @@ impl QueryEngine { let dispatcher = self.logger.dispatcher(); async move { + let span = tracing::info_span!("prisma:engine:rollback_transaction", user_facing = true); + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(&trace); + let traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + match engine.executor().rollback_tx(TxId::from(tx_id)).await { Ok(_) => Ok("{}".to_string()), Err(err) => Ok(map_known_error(err)?), diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index e3fd4768ed7..7d56c891972 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -32,11 +32,15 @@ tracing-opentelemetry = "0.17.3" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } opentelemetry = { 
version = "0.17.0", features = ["rt-tokio"] } opentelemetry-otlp = { version = "0.10", features = ["tls", "tls-roots"] } -query-engine-metrics = { path = "../metrics" } +prisma-metrics.path = "../../libs/metrics" user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry", features = ["metrics"] } [dev-dependencies] serial_test = "*" quaint.workspace = true indoc.workspace = true + +[build-dependencies] +build-utils.path = "../../libs/build-utils" diff --git a/query-engine/query-engine/build.rs b/query-engine/query-engine/build.rs index 2e8fe20c050..33aded23a4a 100644 --- a/query-engine/query-engine/build.rs +++ b/query-engine/query-engine/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/query-engine/query-engine/src/context.rs b/query-engine/query-engine/src/context.rs index 7a1138c411e..f6e3896c17a 100644 --- a/query-engine/query-engine/src/context.rs +++ b/query-engine/query-engine/src/context.rs @@ -1,6 +1,7 @@ use crate::features::{EnabledFeatures, Feature}; use crate::{logger::Logger, opt::PrismaOpt}; use crate::{PrismaError, PrismaResult}; +use prisma_metrics::{MetricRecorder, MetricRegistry}; use psl::PreviewFeature; use query_core::{ protocol::EngineProtocol, @@ -8,11 +9,10 @@ use query_core::{ schema::{self, QuerySchemaRef}, QueryExecutor, }; -use query_engine_metrics::setup as metric_setup; -use query_engine_metrics::MetricRegistry; use request_handlers::{load_executor, ConnectorKind}; use std::{env, fmt, sync::Arc}; use tracing::Instrument; +use tracing_opentelemetry::OpenTelemetrySpanExt; /// Prisma request context containing all immutable state of the process. /// There is usually only one context initialized per process. 
@@ -49,7 +49,8 @@ impl PrismaContext { // Construct query schema schema::build(arced_schema, enabled_features.contains(Feature::RawQueries)) }); - let executor_fut = tokio::spawn(async move { + + let executor_fut = async move { let config = &arced_schema_2.configuration; let preview_features = config.preview_features(); @@ -62,14 +63,22 @@ impl PrismaContext { let url = datasource.load_url(|key| env::var(key).ok())?; // Load executor let executor = load_executor(ConnectorKind::Rust { url, datasource }, preview_features).await?; - let conn = executor.primary_connector().get_connection().await?; + let connector = executor.primary_connector(); + + let conn_span = tracing::info_span!( + "prisma:engine:connection", + user_facing = true, + "db.system" = connector.name(), + ); + + let conn = connector.get_connection().instrument(conn_span).await?; let db_version = conn.version().await; PrismaResult::<_>::Ok((executor, db_version)) - }); + }; let (query_schema, executor_with_db_version) = tokio::join!(query_schema_fut, executor_fut); - let (executor, db_version) = executor_with_db_version.unwrap()?; + let (executor, db_version) = executor_with_db_version?; let query_schema = query_schema.unwrap().with_db_version_supports_join_strategy( relation_load_strategy::db_version_supports_joins_strategy(db_version)?, @@ -103,29 +112,29 @@ impl PrismaContext { } } -pub async fn setup( - opts: &PrismaOpt, - install_logger: bool, - metrics: Option, -) -> PrismaResult> { - let metrics = metrics.unwrap_or_default(); - - if install_logger { - Logger::new("prisma-engine-http", Some(metrics.clone()), opts) - .install() - .unwrap(); - } +pub async fn setup(opts: &PrismaOpt) -> PrismaResult> { + Logger::new("prisma-engine-http", opts).install().unwrap(); - if opts.enable_metrics || opts.dataproxy_metric_override { - metric_setup(); - } + let metrics = if opts.enable_metrics || opts.dataproxy_metric_override { + let metrics = MetricRegistry::new(); + let recorder = MetricRecorder::new(metrics.clone()); + recorder.install_globally().expect("setup must be called only once"); + recorder.init_prisma_metrics(); + Some(metrics) + } else { + None + }; let datamodel = opts.schema(false)?; let config = &datamodel.configuration; let protocol = opts.engine_protocol(); config.validate_that_one_datasource_is_provided()?; - let span = tracing::info_span!("prisma:engine:connect"); + let span = tracing::info_span!("prisma:engine:connect", user_facing = true); + if let Some(trace_context) = opts.trace_context.as_ref() { + let parent_context = telemetry::helpers::restore_remote_context_from_json_str(trace_context); + span.set_parent(parent_context); + } let mut features = EnabledFeatures::from(opts); @@ -133,7 +142,7 @@ pub async fn setup( features |= Feature::Metrics } - let cx = PrismaContext::new(datamodel, protocol, features, Some(metrics)) + let cx = PrismaContext::new(datamodel, protocol, features, metrics) .instrument(span) .await?; diff --git a/query-engine/query-engine/src/logger.rs b/query-engine/query-engine/src/logger.rs index 10f6ced58b8..2bf2566fe96 100644 --- a/query-engine/query-engine/src/logger.rs +++ b/query-engine/query-engine/src/logger.rs @@ -3,8 +3,6 @@ use opentelemetry::{ KeyValue, }; use opentelemetry_otlp::WithExportConfig; -use query_core::telemetry; -use query_engine_metrics::MetricRegistry; use tracing::{dispatcher::SetGlobalDefaultError, subscriber}; use tracing_subscriber::{filter::filter_fn, layer::SubscriberExt, Layer}; @@ -19,7 +17,6 @@ pub(crate) struct Logger { log_format: LogFormat, log_queries: 
bool, tracing_config: TracingConfig, - metrics: Option, } // TracingConfig specifies how tracing will be exposed by the logger facility @@ -38,7 +35,7 @@ enum TracingConfig { impl Logger { /// Initialize a new global logger installer. - pub fn new(service_name: &'static str, metrics: Option, opts: &PrismaOpt) -> Self { + pub fn new(service_name: &'static str, opts: &PrismaOpt) -> Self { let enable_telemetry = opts.enable_open_telemetry; let enable_capturing = opts.enable_telemetry_in_response; let endpoint = if opts.open_telemetry_endpoint.is_empty() { @@ -58,7 +55,6 @@ impl Logger { service_name, log_format: opts.log_format(), log_queries: opts.log_queries(), - metrics, tracing_config, } } @@ -81,9 +77,7 @@ impl Logger { } }; - let subscriber = tracing_subscriber::registry() - .with(fmt_layer) - .with(self.metrics.clone()); + let subscriber = tracing_subscriber::registry().with(fmt_layer); match self.tracing_config { TracingConfig::Captured => { diff --git a/query-engine/query-engine/src/main.rs b/query-engine/query-engine/src/main.rs index 7c3a6f7a1db..17900c4ad74 100644 --- a/query-engine/query-engine/src/main.rs +++ b/query-engine/query-engine/src/main.rs @@ -1,8 +1,5 @@ #![allow(clippy::upper_case_acronyms)] -#[macro_use] -extern crate tracing; - use query_engine::cli::CliCommand; use query_engine::context; use query_engine::error::PrismaError; @@ -11,14 +8,13 @@ use query_engine::server; use query_engine::LogFormat; use std::{error::Error, process}; use structopt::StructOpt; -use tracing::Instrument; type AnyError = Box; #[tokio::main] async fn main() -> Result<(), AnyError> { return main().await.map_err(|err| { - info!("Encountered error during initialization:"); + tracing::info!("Encountered error during initialization:"); err.render_as_json().expect("error rendering"); process::exit(1) }); @@ -29,8 +25,7 @@ async fn main() -> Result<(), AnyError> { match CliCommand::from_opt(&opts)? { Some(cmd) => cmd.execute().await?, None => { - let span = tracing::info_span!("prisma:engine:connect"); - let cx = context::setup(&opts, true, None).instrument(span).await?; + let cx = context::setup(&opts).await?; set_panic_hook(opts.log_format()); server::listen(cx, &opts).await?; } diff --git a/query-engine/query-engine/src/opt.rs b/query-engine/query-engine/src/opt.rs index 83ee4bb7fdc..d2d9441f87f 100644 --- a/query-engine/query-engine/src/opt.rs +++ b/query-engine/query-engine/src/opt.rs @@ -119,6 +119,11 @@ pub struct PrismaOpt { #[structopt(long, env = "PRISMA_ENGINE_PROTOCOL")] pub engine_protocol: Option, + /// The trace context (https://www.w3.org/TR/trace-context) for the engine initialization + /// as a JSON object with properties corresponding to the headers (e.g. `traceparent`). 
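For illustration, the value passed via the new `trace_context` option is a JSON object whose keys mirror the W3C trace-context headers, and it is consumed in `context::setup` through the `restore_remote_context_from_json_str` helper shown earlier in this diff. A hypothetical sketch (the traceparent value below is the W3C example ID, not a real trace):

let span = tracing::info_span!("prisma:engine:connect", user_facing = true);
let trace_context = r#"{"traceparent":"00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"}"#;
let parent_context = telemetry::helpers::restore_remote_context_from_json_str(trace_context);
// Attaching the remote context requires tracing_opentelemetry::OpenTelemetrySpanExt in scope.
span.set_parent(parent_context);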
+ #[structopt(long, env)] + pub trace_context: Option, + #[structopt(subcommand)] pub subcommand: Option, } diff --git a/query-engine/query-engine/src/server/mod.rs b/query-engine/query-engine/src/server/mod.rs index 01b61a07b6b..e3950996233 100644 --- a/query-engine/query-engine/src/server/mod.rs +++ b/query-engine/query-engine/src/server/mod.rs @@ -3,19 +3,19 @@ use crate::features::Feature; use crate::{opt::PrismaOpt, PrismaResult}; use hyper::service::{make_service_fn, service_fn}; use hyper::{header::CONTENT_TYPE, Body, HeaderMap, Method, Request, Response, Server, StatusCode}; -use opentelemetry::trace::TraceContextExt; +use opentelemetry::trace::{TraceContextExt, TraceId}; use opentelemetry::{global, propagation::Extractor}; -use query_core::helpers::*; -use query_core::telemetry::capturing::TxTraceExt; -use query_core::{telemetry, ExtendedTransactionUserFacingError, TransactionOptions, TxId}; +use query_core::{ExtendedUserFacingError, TransactionOptions, TxId}; use request_handlers::{dmmf, render_graphql_schema, RequestBody, RequestHandler}; use serde::Serialize; use serde_json::json; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Instant; -use tracing::{field, Instrument, Span}; +use std::time::{Duration, Instant}; +use telemetry::capturing::Capturer; +use telemetry::helpers::TraceParent; +use tracing::{Instrument, Span}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// Starts up the graphql query engine server @@ -111,71 +111,22 @@ async fn request_handler(cx: Arc, req: Request) -> Result { let handler = RequestHandler::new(cx.executor(), cx.query_schema(), cx.engine_protocol()); let mut result = handler.handle(body, tx_id, traceparent).instrument(span).await; - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { + if let telemetry::capturing::Capturer::Enabled(capturer) = capturer { let telemetry = capturer.fetch_captures().await; if let Some(telemetry) = telemetry { result.set_extension("traces".to_owned(), json!(telemetry.traces)); @@ -202,7 +153,32 @@ async fn request_handler(cx: Arc, req: Request) -> Result tokio::time::sleep(timeout).await, + // Never return if timeout isn't set. + None => std::future::pending().await, + } + }; + + tokio::select! { + _ = query_timeout_fut => { + let captured_telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = capturer { + capturer.fetch_captures().await + } else { + None + }; + + // Note: this relies on the fact that client will rollback the transaction after the + // error. If the client continues using this transaction (and later commits it), data + // corruption might happen because some write queries (but not all of them) might be + // already executed by the database before the timeout is fired. + Ok(err_to_http_resp(query_core::CoreError::QueryTimeout, captured_telemetry)) + } + result = work => { + result + } + } } /// Expose the GraphQL playground if enabled. 
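The request-level headers consumed by `request_handler` are spread across several hunks (`setup_telemetry` and `query_timeout` appear later in this diff), so here is a hypothetical request that exercises all of them. Header names are taken from this diff; the values are made up, and the exact value format accepted by `telemetry::capturing::Settings` is assumed rather than documented here:

let mut headers = hyper::HeaderMap::new();
// Standard W3C trace-context header; picked up by setup_telemetry via the global propagator.
headers.insert("traceparent", "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01".parse().unwrap());
// Read by try_get_transaction_id to route the query to an open interactive transaction.
headers.insert("X-transaction-id", "some-tx-id".parse().unwrap());
// Parsed by query_timeout as milliseconds; the request future is raced against this deadline.
headers.insert("X-query-timeout", "5000".parse().unwrap());
// Parsed into telemetry::capturing::Settings (value format assumed).
headers.insert("X-capture-telemetry", "logs".parse().unwrap());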
@@ -227,10 +203,7 @@ async fn metrics_handler(cx: Arc, req: Request) -> Result = match serde_json::from_slice(full_body.as_ref()) { - Ok(map) => map, - Err(_e) => HashMap::new(), - }; + let global_labels: HashMap = serde_json::from_slice(full_body.as_ref()).unwrap_or_default(); let response = if requested_json { let metrics = cx.metrics.to_json(global_labels); @@ -281,46 +254,22 @@ async fn transaction_start_handler(cx: Arc, req: Request) - let body_start = req.into_body(); let full_body = hyper::body::to_bytes(body_start).await?; - let mut tx_opts: TransactionOptions = serde_json::from_slice(full_body.as_ref()).unwrap(); - let tx_id = tx_opts.with_new_transaction_id(); - - // This is the span we use to instrument the execution of a transaction. This span will be open - // during the tx execution, and held in the ITXServer for that transaction (see ITXServer]) - let span = info_span!("prisma:engine:itx_runner", user_facing = true, itx_id = field::Empty); - - // If telemetry needs to be captured, we use the span trace_id to correlate the logs happening - // during the different operations within a transaction. The trace_id is propagated in the - // traceparent header, but if it's not present, we need to synthetically create one for the - // transaction. This is needed, in case the client is interested in capturing logs and not - // traces, because: - // - The client won't send a traceparent header - // - A transaction initial span is created here (prisma:engine:itx_runner) and stored in the - // ITXServer for that transaction - // - When a query comes in, the graphql handler process it, but we need to tell the capturer - // to start capturing logs, and for that we need a trace_id. There are two places were we - // could get that information from: - // - First, it's the traceparent, but the client didn't send it, because they are only - // interested in logs. - // - Second, it's the root span for the transaction, but it's not in scope but instead - // stored in the ITXServer, in a different tokio task. - // - // For the above reasons, we need to create a trace_id that we can predict and use accross the - // different operations happening within a transaction. So we do it by converting the tx_id - // into a trace_id, leaning on the fact that the tx_id has more entropy, and there's no - // information loss. 
- let capture_settings = capture_settings(&headers); - let traceparent = traceparent(&headers); - if traceparent.is_none() && capture_settings.logs_enabled() { - span.set_parent(tx_id.into_trace_context()) - } else { - span.set_parent(get_parent_span_context(&headers)) - } - let trace_id = span.context().span().span_context().trace_id(); - let capture_config = telemetry::capturing::capturer(trace_id, capture_settings); - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } + let tx_opts = match serde_json::from_slice::(full_body.as_ref()) { + Ok(opts) => opts.with_new_transaction_id(), + Err(_) => { + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(Body::from("Invalid transaction options")) + .unwrap()) + } + }; + + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:start_transaction", user_facing = true), + &headers, + ) + .await; let result = cx .executor @@ -328,12 +277,7 @@ async fn transaction_start_handler(cx: Arc, req: Request) - .instrument(span) .await; - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; - + let telemetry = capturer.try_fetch_captures().await; match result { Ok(tx_id) => { let result = if let Some(telemetry) = telemetry { @@ -355,20 +299,15 @@ async fn transaction_commit_handler( req: Request, tx_id: TxId, ) -> Result, hyper::Error> { - let capture_config = capture_config(req.headers(), tx_id.clone()); - - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:commit_transaction", user_facing = true), + req.headers(), + ) + .await; - let result = cx.executor.commit_tx(tx_id).await; - - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; + let result = cx.executor.commit_tx(tx_id).instrument(span).await; + let telemetry = capturer.try_fetch_captures().await; match result { Ok(_) => Ok(empty_json_to_http_resp(telemetry)), Err(err) => Ok(err_to_http_resp(err, telemetry)), @@ -380,20 +319,15 @@ async fn transaction_rollback_handler( req: Request, tx_id: TxId, ) -> Result, hyper::Error> { - let capture_config = capture_config(req.headers(), tx_id.clone()); + let (span, _traceparent, capturer) = setup_telemetry( + info_span!("prisma:engine:rollback_transaction", user_facing = true), + req.headers(), + ) + .await; - if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.start_capturing().await; - } - - let result = cx.executor.rollback_tx(tx_id).await; - - let telemetry = if let telemetry::capturing::Capturer::Enabled(capturer) = &capture_config { - capturer.fetch_captures().await - } else { - None - }; + let result = cx.executor.rollback_tx(tx_id).instrument(span).await; + let telemetry = capturer.try_fetch_captures().await; match result { Ok(_) => Ok(empty_json_to_http_resp(telemetry)), Err(err) => Ok(err_to_http_resp(err, telemetry)), @@ -457,11 +391,13 @@ fn err_to_http_resp( query_core::TransactionError::Unknown { reason: _ } => StatusCode::INTERNAL_SERVER_ERROR, }, + query_core::CoreError::QueryTimeout => StatusCode::REQUEST_TIMEOUT, + // All other errors are treated as 500s, most of these paths should never be hit, only connector errors may occur. 
_ => StatusCode::INTERNAL_SERVER_ERROR, }; - let mut err: ExtendedTransactionUserFacingError = err.into(); + let mut err: ExtendedUserFacingError = err.into(); if let Some(telemetry) = captured_telemetry { err.set_extension("traces".to_owned(), json!(telemetry.traces)); err.set_extension("logs".to_owned(), json!(telemetry.logs)); @@ -470,57 +406,86 @@ fn err_to_http_resp( build_json_response(status, &err) } -fn capture_config(headers: &HeaderMap, tx_id: TxId) -> telemetry::capturing::Capturer { - let capture_settings = capture_settings(headers); - let mut traceparent = traceparent(headers); - - if traceparent.is_none() && capture_settings.is_enabled() { - traceparent = Some(tx_id.as_traceparent()) - } - - let trace_id = get_trace_id_from_traceparent(traceparent.as_deref()); +async fn setup_telemetry(span: Span, headers: &HeaderMap) -> (Span, Option, Capturer) { + let capture_settings = { + let settings = headers + .get("X-capture-telemetry") + .and_then(|value| value.to_str().ok()) + .unwrap_or_default(); + telemetry::capturing::Settings::from(settings) + }; - telemetry::capturing::capturer(trace_id, capture_settings) -} + // Parse parent trace_id and span_id from `traceparent` header and attach them to the current + // context. Internally, this relies on the fact that global text map propagator was installed that + // can handle `traceparent` header (for example, `TraceContextPropagator`). + let parent_context = { + let extractor = HeaderExtractor(headers); + let context = global::get_text_map_propagator(|propagator| propagator.extract(&extractor)); + if context.span().span_context().is_valid() { + Some(context) + } else { + None + } + }; -#[allow(clippy::bind_instead_of_map)] -fn capture_settings(headers: &HeaderMap) -> telemetry::capturing::Settings { - const CAPTURE_TELEMETRY_HEADER: &str = "X-capture-telemetry"; - let s = if let Some(hv) = headers.get(CAPTURE_TELEMETRY_HEADER) { - hv.to_str().unwrap_or("") + let traceparent = if let Some(parent_context) = parent_context { + let requester_traceparent = TraceParent::from_remote_context(&parent_context); + span.set_parent(parent_context); + requester_traceparent + } else if capture_settings.is_enabled() { + // If tracing is disabled on the client but capturing the logs is enabled, we construct an + // artificial traceparent. Although the span corresponding to this traceparent doesn't + // actually exist, it is okay because no spans will be returned to the client in this case + // anyway, so they don't have to be valid. The reason we need this is because capturing the + // logs requires a trace ID to correlate the events with. This is not the right design: it + // seems to be based on the wrong idea that trace ID uniquely identifies a request (which + // is not the case in reality), and it is prone to race conditions and losing spans and + // logs when there are multiple concurrent Prisma operations in a single trace. Ironically, + // this makes log capturing more reliable in the code path with the fake traceparent hack + // than when a real traceparent is present. The drawback is that the fake traceparent leaks + // to the query logs. 
We could of course add a custom flag to `TraceParent` to indicate + // that it is synthetic (we can't use the sampled trace flag for it as it would prevent it + // from being processed by the `SpanProcessor`) and check it when adding the traceparent + // comment if we wanted a quick fix for that, but this problem existed for as long as + // capturing was implemented, and the `DataProxyEngine` works around it by stripping the + // phony traceparent comments before emitting the logs on the `PrismaClient` instance. So + // instead, we will fix the root cause of the problem by reworking the capturer to collect + // all spans and events which have the `span` created above as an ancestor and not rely on + // trace IDs at all. This will happen in a follow-up PR as part of Tracing GA work. + let traceparent = { + #[allow(deprecated)] + TraceParent::new_random() + }; + span.set_parent(traceparent.to_remote_context()); + Some(traceparent) } else { - "" + None }; - telemetry::capturing::Settings::from(s) -} - -fn traceparent(headers: &HeaderMap) -> Option { - const TRACEPARENT_HEADER: &str = "traceparent"; + let trace_id = traceparent + .as_ref() + .map(TraceParent::trace_id) + .unwrap_or(TraceId::INVALID); - let value = headers - .get(TRACEPARENT_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_owned()); + let capturer = telemetry::capturing::capturer(trace_id, capture_settings); + capturer.try_start_capturing().await; - let is_valid_traceparent = |s: &String| s.split_terminator('-').count() >= 4; - - value.filter(is_valid_traceparent) + (span, traceparent, capturer) } -fn transaction_id(headers: &HeaderMap) -> Option { - const TRANSACTION_ID_HEADER: &str = "X-transaction-id"; +fn try_get_transaction_id(headers: &HeaderMap) -> Option { headers - .get(TRANSACTION_ID_HEADER) + .get("X-transaction-id") .and_then(|h| h.to_str().ok()) .map(TxId::from) } -/// If the client sends us a trace and span id, extracting a new context if the -/// headers are set. If not, returns current context. 
-fn get_parent_span_context(headers: &HeaderMap) -> opentelemetry::Context { - let extractor = HeaderExtractor(headers); - global::get_text_map_propagator(|propagator| propagator.extract(&extractor)) +fn query_timeout(headers: &HeaderMap) -> Option { + headers + .get("X-query-timeout") + .and_then(|h| h.to_str().ok()) + .and_then(|value| value.parse::().ok()) + .map(Duration::from_millis) } fn build_json_response(status_code: StatusCode, value: &T) -> Response diff --git a/query-engine/query-engine/src/tests/dmmf.rs b/query-engine/query-engine/src/tests/dmmf.rs index 8151f25bf17..443b2c81a19 100644 --- a/query-engine/query-engine/src/tests/dmmf.rs +++ b/query-engine/query-engine/src/tests/dmmf.rs @@ -96,6 +96,7 @@ fn test_dmmf_cli_command(schema: &str) -> PrismaResult<()> { enable_telemetry_in_response: false, dataproxy_metric_override: false, engine_protocol: None, + trace_context: None, }; let cli_cmd = CliCommand::from_opt(&prisma_opt)?.unwrap(); diff --git a/query-engine/query-engine/src/tracer.rs b/query-engine/query-engine/src/tracer.rs index 8763ba892f4..75f6630931b 100644 --- a/query-engine/query-engine/src/tracer.rs +++ b/query-engine/query-engine/src/tracer.rs @@ -8,7 +8,7 @@ use opentelemetry::{ }, trace::TracerProvider, }; -use query_core::telemetry; + use std::io::{stdout, Stdout}; use std::{fmt::Debug, io::Write}; diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index e23d5927c55..8d7e8b4e222 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -8,13 +8,14 @@ psl.workspace = true query-structure = { path = "../query-structure" } query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +telemetry = { path = "../../libs/telemetry" } quaint.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser", optional = true } serde.workspace = true serde_json.workspace = true -futures = "0.3" +futures.workspace = true indexmap.workspace = true bigdecimal = "0.3" thiserror = "1" diff --git a/query-engine/request-handlers/src/handler.rs b/query-engine/request-handlers/src/handler.rs index 570a109a15f..123af6541c4 100644 --- a/query-engine/request-handlers/src/handler.rs +++ b/query-engine/request-handlers/src/handler.rs @@ -13,6 +13,7 @@ use query_core::{ }; use query_structure::{parse_datetime, stringify_datetime, PrismaValue}; use std::{collections::HashMap, fmt, panic::AssertUnwindSafe, str::FromStr}; +use telemetry::helpers::TraceParent; type ArgsToResult = (HashMap, IndexMap); @@ -41,24 +42,34 @@ impl<'a> RequestHandler<'a> { } } - pub async fn handle(&self, body: RequestBody, tx_id: Option, trace_id: Option) -> PrismaResponse { + pub async fn handle( + &self, + body: RequestBody, + tx_id: Option, + traceparent: Option, + ) -> PrismaResponse { tracing::debug!("Incoming GraphQL query: {:?}", &body); match body.into_doc(self.query_schema) { - Ok(QueryDocument::Single(query)) => self.handle_single(query, tx_id, trace_id).await, + Ok(QueryDocument::Single(query)) => self.handle_single(query, tx_id, traceparent).await, Ok(QueryDocument::Multi(batch)) => match batch.compact(self.query_schema) { BatchDocument::Multi(batch, transaction) => { - self.handle_batch(batch, transaction, tx_id, trace_id).await + self.handle_batch(batch, transaction, tx_id, traceparent).await } - BatchDocument::Compact(compacted) => self.handle_compacted(compacted, tx_id, 
trace_id).await, + BatchDocument::Compact(compacted) => self.handle_compacted(compacted, tx_id, traceparent).await, }, Err(err) => PrismaResponse::Single(GQLError::from_handler_error(err).into()), } } - async fn handle_single(&self, query: Operation, tx_id: Option, trace_id: Option) -> PrismaResponse { - let gql_response = match AssertUnwindSafe(self.handle_request(query, tx_id, trace_id)) + async fn handle_single( + &self, + query: Operation, + tx_id: Option, + traceparent: Option, + ) -> PrismaResponse { + let gql_response = match AssertUnwindSafe(self.handle_request(query, tx_id, traceparent)) .catch_unwind() .await { @@ -75,14 +86,14 @@ impl<'a> RequestHandler<'a> { queries: Vec, transaction: Option, tx_id: Option, - trace_id: Option, + traceparent: Option, ) -> PrismaResponse { match AssertUnwindSafe(self.executor.execute_all( tx_id, queries, transaction, self.query_schema.clone(), - trace_id, + traceparent, self.engine_protocol, )) .catch_unwind() @@ -108,7 +119,7 @@ impl<'a> RequestHandler<'a> { &self, document: CompactedDocument, tx_id: Option, - trace_id: Option, + traceparent: Option, ) -> PrismaResponse { let plural_name = document.plural_name(); let singular_name = document.single_name(); @@ -117,7 +128,7 @@ impl<'a> RequestHandler<'a> { let arguments = document.arguments; let nested_selection = document.nested_selection; - match AssertUnwindSafe(self.handle_request(document.operation, tx_id, trace_id)) + match AssertUnwindSafe(self.handle_request(document.operation, tx_id, traceparent)) .catch_unwind() .await { @@ -200,14 +211,14 @@ impl<'a> RequestHandler<'a> { &self, query_doc: Operation, tx_id: Option, - trace_id: Option, + traceparent: Option, ) -> query_core::Result { self.executor .execute( tx_id, query_doc, self.query_schema.clone(), - trace_id, + traceparent, self.engine_protocol, ) .await diff --git a/query-engine/schema/src/identifier_type.rs b/query-engine/schema/src/identifier_type.rs index d4cf309a299..0aa77140139 100644 --- a/query-engine/schema/src/identifier_type.rs +++ b/query-engine/schema/src/identifier_type.rs @@ -186,7 +186,12 @@ impl std::fmt::Display for IdentifierType { IdentifierType::ToOneRelationFilterInput(related_model, arity) => { let nullable = if arity.is_optional() { "Nullable" } else { "" }; - write!(f, "{}{}RelationFilter", capitalize(related_model.name()), nullable) + write!( + f, + "{}{}ScalarRelationFilter", + capitalize(related_model.name()), + nullable + ) } IdentifierType::ToOneCompositeFilterInput(ct, arity) => { let nullable = if arity.is_optional() { "Nullable" } else { "" }; diff --git a/query-engine/schema/test-schemas/odoo.prisma b/query-engine/schema/test-schemas/odoo.prisma index a7410606f4c..7da843cf2fd 100644 --- a/query-engine/schema/test-schemas/odoo.prisma +++ b/query-engine/schema/test-schemas/odoo.prisma @@ -1,5 +1,5 @@ datasource db { - provider = "postgresql" + provider = "postgres" url = env("DB_URL") } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index e48263a1387..8f2d5ed3466 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.80.1" +channel = "1.82.0" components = ["clippy", "rustfmt", "rust-src"] targets = [ # WASM target for serverless and edge environments. 
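To make the `ToOneRelationFilterInput` rename above concrete, here is a hypothetical sketch (assuming the module's existing `capitalize` helper) of the identifier now generated for an optional to-one relation to a model named `user`:

// Illustrative only: mirrors the Display impl changed above.
let nullable = "Nullable"; // arity.is_optional() == true
let name = format!("{}{}ScalarRelationFilter", capitalize("user"), nullable);
assert_eq!(name, "UserNullableScalarRelationFilter"); // previously "UserNullableRelationFilter"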
diff --git a/schema-engine/cli/Cargo.toml b/schema-engine/cli/Cargo.toml index bfb136f582d..65718f8c8cd 100644 --- a/schema-engine/cli/Cargo.toml +++ b/schema-engine/cli/Cargo.toml @@ -36,6 +36,9 @@ connection-string.workspace = true expect-test = "1.4.0" quaint = { workspace = true, features = ["all-native"] } +[build-dependencies] +build-utils.path = "../../libs/build-utils" + [[bin]] name = "schema-engine" path = "src/main.rs" diff --git a/schema-engine/cli/build.rs b/schema-engine/cli/build.rs index 2e8fe20c050..33aded23a4a 100644 --- a/schema-engine/cli/build.rs +++ b/schema-engine/cli/build.rs @@ -1,11 +1,3 @@ -use std::process::Command; - -fn store_git_commit_hash() { - let output = Command::new("git").args(["rev-parse", "HEAD"]).output().unwrap(); - let git_hash = String::from_utf8(output.stdout).unwrap(); - println!("cargo:rustc-env=GIT_HASH={git_hash}"); -} - fn main() { - store_git_commit_hash(); + build_utils::store_git_commit_hash_in_env(); } diff --git a/schema-engine/cli/tests/cli_tests.rs b/schema-engine/cli/tests/cli_tests.rs index 973345ac403..fd89b9bce74 100644 --- a/schema-engine/cli/tests/cli_tests.rs +++ b/schema-engine/cli/tests/cli_tests.rs @@ -51,7 +51,7 @@ where .or_else(|| panic_payload.downcast_ref::().map(|s| s.to_owned())) .unwrap_or_default(); - panic!("Error: '{}'", res) + panic!("Error: '{}'", res); } } } diff --git a/schema-engine/connectors/mongodb-schema-connector/Cargo.toml b/schema-engine/connectors/mongodb-schema-connector/Cargo.toml index 1978cc1f631..c5c8188b439 100644 --- a/schema-engine/connectors/mongodb-schema-connector/Cargo.toml +++ b/schema-engine/connectors/mongodb-schema-connector/Cargo.toml @@ -14,7 +14,7 @@ user-facing-errors = { path = "../../../libs/user-facing-errors", features = [ ] } enumflags2.workspace = true -futures = "0.3" +futures.workspace = true mongodb.workspace = true bson.workspace = true serde_json.workspace = true @@ -22,7 +22,7 @@ tokio.workspace = true tracing.workspace = true convert_case = "0.6.0" once_cell = "1.8.0" -regex = "1.7.3" +regex.workspace = true indoc.workspace = true [dev-dependencies] @@ -33,4 +33,4 @@ url.workspace = true expect-test = "1" names = { version = "0.12", default-features = false } itertools.workspace = true -indoc.workspace = true \ No newline at end of file +indoc.workspace = true diff --git a/schema-engine/connectors/sql-schema-connector/Cargo.toml b/schema-engine/connectors/sql-schema-connector/Cargo.toml index b5f70423a6e..a336495e743 100644 --- a/schema-engine/connectors/sql-schema-connector/Cargo.toml +++ b/schema-engine/connectors/sql-schema-connector/Cargo.toml @@ -34,10 +34,10 @@ chrono.workspace = true connection-string.workspace = true enumflags2.workspace = true once_cell = "1.3" -regex = "1" +regex.workspace = true serde_json.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true either = "1.6" sqlformat = "0.2.1" diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs index dca3b89f6f2..02752e491ee 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres.rs @@ -5,12 +5,17 @@ use self::connection::*; use crate::SqlFlavour; use enumflags2::BitFlags; use indoc::indoc; -use quaint::{connector::PostgresUrl, Value}; +use once_cell::sync::Lazy; +use quaint::{ + connector::{PostgresUrl, PostgresWebSocketUrl}, + 
prelude::NativeConnectionInfo, + Value, +}; use schema_connector::{ migrations_directory::MigrationDirectory, BoxFuture, ConnectorError, ConnectorParams, ConnectorResult, Namespaces, }; use sql_schema_describer::SqlSchema; -use std::{borrow::Cow, collections::HashMap, future, time}; +use std::{borrow::Cow, collections::HashMap, future, str::FromStr, time}; use url::Url; use user_facing_errors::{ common::{DatabaseAccessDenied, DatabaseDoesNotExist}, @@ -28,9 +33,70 @@ SET enable_experimental_alter_column_type_general = true; type State = super::State, Connection)>; +#[derive(Debug, Clone)] +struct MigratePostgresUrl(PostgresUrl); + +static MIGRATE_WS_BASE_URL: Lazy> = Lazy::new(|| { + std::env::var("PRISMA_SCHEMA_ENGINE_WS_BASE_URL") + .map(Cow::Owned) + .unwrap_or_else(|_| Cow::Borrowed("wss://migrations.prisma-data.net/websocket")) +}); + +impl MigratePostgresUrl { + const WEBSOCKET_SCHEME: &'static str = "prisma+postgres"; + const API_KEY_PARAM: &'static str = "api_key"; + const DBNAME_PARAM: &'static str = "dbname"; + + fn new(url: Url) -> ConnectorResult { + let postgres_url = if url.scheme() == Self::WEBSOCKET_SCHEME { + let ws_url = Url::from_str(&MIGRATE_WS_BASE_URL).map_err(ConnectorError::url_parse_error)?; + let Some((_, api_key)) = url.query_pairs().find(|(name, _)| name == Self::API_KEY_PARAM) else { + return Err(ConnectorError::url_parse_error( + "Required `api_key` query string parameter was not provided in a connection URL", + )); + }; + + let dbname_override = url.query_pairs().find(|(name, _)| name == Self::DBNAME_PARAM); + let mut ws_url = PostgresWebSocketUrl::new(ws_url, api_key.into_owned()); + if let Some((_, dbname_override)) = dbname_override { + ws_url.override_db_name(dbname_override.into_owned()); + } + + Ok(PostgresUrl::WebSocket(ws_url)) + } else { + PostgresUrl::new_native(url) + } + .map_err(ConnectorError::url_parse_error)?; + + Ok(Self(postgres_url)) + } + + pub(super) fn host(&self) -> &str { + self.0.host() + } + + pub(super) fn port(&self) -> u16 { + self.0.port() + } + + pub(super) fn dbname(&self) -> &str { + self.0.dbname() + } + + pub(super) fn schema(&self) -> &str { + self.0.schema() + } +} + +impl From for NativeConnectionInfo { + fn from(value: MigratePostgresUrl) -> Self { + NativeConnectionInfo::Postgres(value.0) + } +} + struct Params { connector_params: ConnectorParams, - url: PostgresUrl, + url: MigratePostgresUrl, } /// The specific provider that was requested by the user. 
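A hypothetical walk-through of the new `prisma+postgres` handling in `MigratePostgresUrl::new` above. The host and key below are placeholders; the only facts taken from this diff are the scheme, the `api_key` and `dbname` query parameters, the default `wss://migrations.prisma-data.net/websocket` base, and the `PRISMA_SCHEMA_ENGINE_WS_BASE_URL` override:

// Placeholder connection string: the scheme selects the websocket path, `api_key` is
// mandatory, and `dbname` optionally overrides the database (also used for shadow databases).
let url = url::Url::parse("prisma+postgres://example.host/?api_key=PLACEHOLDER&dbname=shadow_db").unwrap();
assert_eq!(url.scheme(), "prisma+postgres");
// MigratePostgresUrl::new(url) then builds a PostgresUrl::WebSocket pointing at
// wss://migrations.prisma-data.net/websocket (or PRISMA_SCHEMA_ENGINE_WS_BASE_URL),
// carrying the api_key and the dbname override.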
@@ -378,7 +444,7 @@ impl SqlFlavour for PostgresFlavour { .map_err(ConnectorError::url_parse_error)?; disable_postgres_statement_cache(&mut url)?; let connection_string = url.to_string(); - let url = PostgresUrl::new(url).map_err(ConnectorError::url_parse_error)?; + let url = MigratePostgresUrl::new(url)?; connector_params.connection_string = connection_string; let params = Params { connector_params, url }; self.state.set_params(params); @@ -460,7 +526,14 @@ impl SqlFlavour for PostgresFlavour { .connection_string .parse() .map_err(ConnectorError::url_parse_error)?; - shadow_database_url.set_path(&format!("/{shadow_database_name}")); + + if shadow_database_url.scheme() == MigratePostgresUrl::WEBSOCKET_SCHEME { + shadow_database_url + .query_pairs_mut() + .append_pair(MigratePostgresUrl::DBNAME_PARAM, &shadow_database_name); + } else { + shadow_database_url.set_path(&format!("/{shadow_database_name}")); + } let shadow_db_params = ConnectorParams { connection_string: shadow_database_url.to_string(), preview_features: params.connector_params.preview_features, @@ -510,7 +583,11 @@ impl SqlFlavour for PostgresFlavour { /// TL;DR, /// 1. pg >= 13 -> it works. /// 2. pg < 13 -> syntax error on WITH (FORCE), and then fail with db in use if pgbouncer is used. -async fn drop_db_try_force(conn: &mut Connection, url: &PostgresUrl, database_name: &str) -> ConnectorResult<()> { +async fn drop_db_try_force( + conn: &mut Connection, + url: &MigratePostgresUrl, + database_name: &str, +) -> ConnectorResult<()> { let drop_database = format!("DROP DATABASE IF EXISTS \"{database_name}\" WITH (FORCE)"); if let Err(err) = conn.raw_cmd(&drop_database, url).await { if let Some(msg) = err.message() { @@ -537,7 +614,7 @@ fn strip_schema_param_from_url(url: &mut Url) { /// Try to connect as an admin to a postgres database. We try to pick a default database from which /// we can create another database. -async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, PostgresUrl)> { +async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection, MigratePostgresUrl)> { // "postgres" is the default database on most postgres installations, // "template1" is guaranteed to exist, and "defaultdb" is the only working // option on DigitalOcean managed postgres databases. @@ -547,7 +624,7 @@ async fn create_postgres_admin_conn(mut url: Url) -> ConnectorResult<(Connection for database_name in CANDIDATE_DEFAULT_DATABASES { url.set_path(&format!("/{database_name}")); - let postgres_url = PostgresUrl::new(url.clone()).unwrap(); + let postgres_url = MigratePostgresUrl(PostgresUrl::new_native(url.clone()).unwrap()); match Connection::new(url.clone()).await { // If the database does not exist, try the next one. 
Err(err) => match &err.error_code() { diff --git a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs index 3ca9b673b0a..3a8f9fb6517 100644 --- a/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs +++ b/schema-engine/connectors/sql-schema-connector/src/flavour/postgres/connection.rs @@ -4,8 +4,8 @@ use enumflags2::BitFlags; use indoc::indoc; use psl::PreviewFeature; use quaint::{ - connector::{self, tokio_postgres::error::ErrorPosition, PostgresUrl}, - prelude::{ConnectionInfo, NativeConnectionInfo, Queryable}, + connector::{self, tokio_postgres::error::ErrorPosition, MakeTlsConnectorManager, PostgresUrl}, + prelude::{ConnectionInfo, Queryable}, }; use schema_connector::{ConnectorError, ConnectorResult, Namespaces}; use sql_schema_describer::{postgres::PostgresSchemaExt, SqlSchema}; @@ -13,19 +13,22 @@ use user_facing_errors::{schema_engine::ApplyMigrationError, schema_engine::Data use crate::sql_renderer::IteratorJoin; +use super::MigratePostgresUrl; + pub(super) struct Connection(connector::PostgreSql); impl Connection { pub(super) async fn new(url: url::Url) -> ConnectorResult { - let url = PostgresUrl::new(url).map_err(|err| { - ConnectorError::user_facing(user_facing_errors::common::InvalidConnectionString { - details: err.to_string(), - }) - })?; + let url = MigratePostgresUrl::new(url)?; - let quaint = connector::PostgreSql::new(url.clone()) - .await - .map_err(quaint_err(&url))?; + let quaint = match url.0 { + PostgresUrl::Native(ref native_url) => { + let tls_manager = MakeTlsConnectorManager::new(native_url.as_ref().clone()); + connector::PostgreSql::new(native_url.as_ref().clone(), &tls_manager).await + } + PostgresUrl::WebSocket(ref ws_url) => connector::PostgreSql::new_with_websocket(ws_url.clone()).await, + } + .map_err(quaint_err(&url))?; let version = quaint.version().await.map_err(quaint_err(&url))?; @@ -116,12 +119,12 @@ impl Connection { Ok(schema) } - pub(super) async fn raw_cmd(&mut self, sql: &str, url: &PostgresUrl) -> ConnectorResult<()> { + pub(super) async fn raw_cmd(&mut self, sql: &str, url: &MigratePostgresUrl) -> ConnectorResult<()> { tracing::debug!(query_type = "raw_cmd", sql); self.0.raw_cmd(sql).await.map_err(quaint_err(url)) } - pub(super) async fn version(&mut self, url: &PostgresUrl) -> ConnectorResult> { + pub(super) async fn version(&mut self, url: &MigratePostgresUrl) -> ConnectorResult> { tracing::debug!(query_type = "version"); self.0.version().await.map_err(quaint_err(url)) } @@ -129,7 +132,7 @@ impl Connection { pub(super) async fn query( &mut self, query: quaint::ast::Query<'_>, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { use quaint::visitor::Visitor; let (sql, params) = quaint::visitor::Postgres::build(query).unwrap(); @@ -140,7 +143,7 @@ impl Connection { &self, sql: &str, params: &[quaint::prelude::Value<'_>], - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "query_raw", sql, ?params); self.0.query_raw(sql, params).await.map_err(quaint_err(url)) @@ -149,7 +152,7 @@ impl Connection { pub(super) async fn describe_query( &self, sql: &str, - url: &PostgresUrl, + url: &MigratePostgresUrl, ) -> ConnectorResult { tracing::debug!(query_type = "describe_query", sql); self.0.describe_query(sql).await.map_err(quaint_err(url)) @@ -237,11 +240,6 @@ fn normalize_sql_schema(schema: &mut SqlSchema, preview_features: BitFlags impl 
(Fn(quaint::error::Error) -> ConnectorError) + '_ { - |err| { - crate::flavour::quaint_error_to_connector_error( - err, - &ConnectionInfo::Native(NativeConnectionInfo::Postgres(url.clone())), - ) - } +fn quaint_err(url: &MigratePostgresUrl) -> impl (Fn(quaint::error::Error) -> ConnectorError) + '_ { + |err| crate::flavour::quaint_error_to_connector_error(err, &ConnectionInfo::Native(url.clone().into())) } diff --git a/schema-engine/connectors/sql-schema-connector/src/lib.rs b/schema-engine/connectors/sql-schema-connector/src/lib.rs index f78ac9b60fe..3641eeb9633 100644 --- a/schema-engine/connectors/sql-schema-connector/src/lib.rs +++ b/schema-engine/connectors/sql-schema-connector/src/lib.rs @@ -23,7 +23,7 @@ use migration_pair::MigrationPair; use psl::{datamodel_connector::NativeTypeInstance, parser_database::ScalarType, ValidatedSchema}; use quaint::connector::DescribedQuery; use schema_connector::{migrations_directory::MigrationDirectory, *}; -use sql_doc_parser::parse_sql_doc; +use sql_doc_parser::{parse_sql_doc, sanitize_sql}; use sql_migration::{DropUserDefinedType, DropView, SqlMigration, SqlMigrationStep}; use sql_schema_describer as sql; use std::{future, sync::Arc}; @@ -362,11 +362,12 @@ impl SchemaConnector for SqlSchemaConnector { input: IntrospectSqlQueryInput, ) -> BoxFuture<'_, ConnectorResult> { Box::pin(async move { + let sanitized_sql = sanitize_sql(&input.source); let DescribedQuery { parameters, columns, enum_names, - } = self.flavour.describe_query(&input.source).await?; + } = self.flavour.describe_query(&sanitized_sql).await?; let enum_names = enum_names.unwrap_or_default(); let sql_source = input.source.clone(); let parsed_doc = parse_sql_doc(&sql_source, enum_names.as_slice())?; @@ -397,7 +398,7 @@ impl SchemaConnector for SqlSchemaConnector { Ok(IntrospectSqlQueryOutput { name: input.name, - source: input.source, + source: sanitized_sql, documentation: parsed_doc.description().map(ToOwned::to_owned), parameters, result_columns: columns, diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs index c52a3f09bbd..1bc8da8bacd 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_doc_parser.rs @@ -385,6 +385,15 @@ pub(crate) fn parse_sql_doc<'a>(sql: &'a str, enum_names: &'a [String]) -> Conne Ok(parsed_sql) } +/// Mysql-async poorly parses the sql input to support named parameters, which conflicts with our own syntax for overriding query parameters type and nullability. +/// This function removes all single-line comments from the sql input to avoid conflicts. 
+pub(crate) fn sanitize_sql(sql: &str) -> String { + sql.lines() + .map(|line| line.trim()) + .filter(|line| !line.starts_with("--")) + .join("\n") +} + #[cfg(test)] mod tests { use super::*; @@ -1196,4 +1205,29 @@ SELECT enum FROM "test_introspect_sql"."model" WHERE id = $1;"#, expected.assert_debug_eq(&res); } + + #[test] + fn sanitize_sql_test_1() { + use expect_test::expect; + + let sql = r#" + -- @description This query returns a user by it's id + -- @param {Int} $1:userId valid user identifier + -- @param {String} $2:parentId valid parent identifier + SELECT enum + FROM + "test_introspect_sql"."model" WHERE id = + $1; + "#; + + let expected = expect![[r#" + + SELECT enum + FROM + "test_introspect_sql"."model" WHERE id = + $1; + "#]]; + + expected.assert_eq(&sanitize_sql(sql)); + } } diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs index 257d79b36df..441960df8c5 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_differ/index.rs @@ -1,6 +1,11 @@ use sql_schema_describer::walkers::{IndexWalker, TableWalker}; pub(super) fn index_covers_fk(table: TableWalker<'_>, index: IndexWalker<'_>) -> bool { + // Only normal indexes can cover foreign keys. + if index.index_type() != sql_schema_describer::IndexType::Normal { + return false; + } + table.foreign_keys().any(|fk| { let fk_cols = fk.constrained_columns().map(|col| col.name()); let index_cols = index.column_names(); diff --git a/schema-engine/core/Cargo.toml b/schema-engine/core/Cargo.toml index 6814bf60ed2..6fb22d8a98c 100644 --- a/schema-engine/core/Cargo.toml +++ b/schema-engine/core/Cargo.toml @@ -21,7 +21,7 @@ serde_json.workspace = true tokio.workspace = true tracing.workspace = true tracing-subscriber = "0.3" -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true [build-dependencies] diff --git a/schema-engine/core/src/lib.rs b/schema-engine/core/src/lib.rs index b367ab0bfff..3c0a2bf6d6a 100644 --- a/schema-engine/core/src/lib.rs +++ b/schema-engine/core/src/lib.rs @@ -41,7 +41,7 @@ fn connector_for_connection_string( preview_features: BitFlags, ) -> CoreResult> { match connection_string.split(':').next() { - Some("postgres") | Some("postgresql") => { + Some("postgres") | Some("postgresql") | Some("prisma+postgres") => { let params = ConnectorParams { connection_string, preview_features, diff --git a/schema-engine/datamodel-renderer/Cargo.toml b/schema-engine/datamodel-renderer/Cargo.toml index ad1b0435d66..b74352589e1 100644 --- a/schema-engine/datamodel-renderer/Cargo.toml +++ b/schema-engine/datamodel-renderer/Cargo.toml @@ -6,7 +6,7 @@ edition = "2021" [dependencies] once_cell = "1.15.0" psl.workspace = true -regex = "1.6.0" +regex.workspace = true base64 = "0.13.1" [dev-dependencies] diff --git a/schema-engine/mongodb-schema-describer/Cargo.toml b/schema-engine/mongodb-schema-describer/Cargo.toml index d04c5ab9fb3..e8885e99082 100644 --- a/schema-engine/mongodb-schema-describer/Cargo.toml +++ b/schema-engine/mongodb-schema-describer/Cargo.toml @@ -8,5 +8,5 @@ edition = "2021" [dependencies] mongodb.workspace = true bson.workspace = true -futures = "0.3" +futures.workspace = true serde.workspace = true diff --git a/schema-engine/sql-introspection-tests/Cargo.toml b/schema-engine/sql-introspection-tests/Cargo.toml index 3d45c178f09..d0891b0bbaa 100644 --- 
a/schema-engine/sql-introspection-tests/Cargo.toml +++ b/schema-engine/sql-introspection-tests/Cargo.toml @@ -18,7 +18,7 @@ itertools.workspace = true enumflags2.workspace = true connection-string.workspace = true pretty_assertions = "1" -tracing-futures = "0.2" +tracing-futures.workspace = true tokio.workspace = true tracing.workspace = true indoc.workspace = true diff --git a/schema-engine/sql-migration-tests/Cargo.toml b/schema-engine/sql-migration-tests/Cargo.toml index 0c744699a26..6a345a36548 100644 --- a/schema-engine/sql-migration-tests/Cargo.toml +++ b/schema-engine/sql-migration-tests/Cargo.toml @@ -30,7 +30,7 @@ serde_json.workspace = true tempfile = "3.1.0" tokio.workspace = true tracing.workspace = true -tracing-futures = "0.2" +tracing-futures.workspace = true url.workspace = true quaint = { workspace = true, features = ["all-native"] } diff --git a/schema-engine/sql-migration-tests/src/assertions.rs b/schema-engine/sql-migration-tests/src/assertions.rs index e32554a451b..6a2598edfb3 100644 --- a/schema-engine/sql-migration-tests/src/assertions.rs +++ b/schema-engine/sql-migration-tests/src/assertions.rs @@ -191,13 +191,11 @@ impl SchemaAssertion { } fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } @@ -325,13 +323,11 @@ pub struct TableAssertion<'a> { impl<'a> TableAssertion<'a> { fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } diff --git a/schema-engine/sql-migration-tests/src/commands/schema_push.rs b/schema-engine/sql-migration-tests/src/commands/schema_push.rs index f7442b3a72c..f20121f7b3a 100644 --- a/schema-engine/sql-migration-tests/src/commands/schema_push.rs +++ b/schema-engine/sql-migration-tests/src/commands/schema_push.rs @@ -102,13 +102,11 @@ impl SchemaPushAssertion { } pub fn print_context(&self) { - match &self.context { - Some(context) => println!("Test failure with context <{}>", context.red()), - None => {} + if let Some(context) = &self.context { + println!("Test failure with context <{}>", context.red()) } - match &self.description { - Some(description) => println!("{}: {}", "Description".bold(), description.italic()), - None => {} + if let Some(description) = &self.description { + println!("{}: {}", "Description".bold(), description.italic()) } } diff --git a/schema-engine/sql-migration-tests/tests/migrations/diff.rs b/schema-engine/sql-migration-tests/tests/migrations/diff.rs index 0eadac39657..b9225bd7a22 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/diff.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/diff.rs @@ -7,6 +7,95 @@ use schema_core::{ use sql_migration_tests::{test_api::*, utils::to_schema_containers}; use std::sync::Arc; 
+#[test_connector(tags(Sqlite, Mysql, Postgres, CockroachDb, Mssql))] +fn from_unique_index_to_without(mut api: TestApi) { + let tempdir = tempfile::tempdir().unwrap(); + let host = Arc::new(TestConnectorHost::default()); + + api.connector.set_host(host.clone()); + + let from_schema = api.datamodel_with_provider( + r#" + model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? @unique + // ^^^^^^^ this will be removed later + } + + model User { + id Int @id + name String? + posts Post[] + } + "#, + ); + + let to_schema = api.datamodel_with_provider( + r#" + model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + } + + model User { + id Int @id + name String? + posts Post[] + } + "#, + ); + + let from_file = write_file_to_tmp(&from_schema, &tempdir, "from"); + let to_file = write_file_to_tmp(&to_schema, &tempdir, "to"); + + api.diff(DiffParams { + exit_code: None, + from: DiffTarget::SchemaDatamodel(SchemasContainer { + files: vec![SchemaContainer { + path: from_file.to_string_lossy().into_owned(), + content: from_schema.to_string(), + }], + }), + shadow_database_url: None, + to: DiffTarget::SchemaDatamodel(SchemasContainer { + files: vec![SchemaContainer { + path: to_file.to_string_lossy().into_owned(), + content: to_schema.to_string(), + }], + }), + script: true, + }) + .unwrap(); + + let expected_printed_messages = if api.is_mysql() { + expect![[r#" + [ + "-- DropIndex\nDROP INDEX `Post_authorId_key` ON `Post`;\n", + ] + "#]] + } else if api.is_sqlite() || api.is_postgres() || api.is_cockroach() { + expect![[r#" + [ + "-- DropIndex\nDROP INDEX \"Post_authorId_key\";\n", + ] + "#]] + } else if api.is_mssql() { + expect![[r#" + [ + "BEGIN TRY\n\nBEGIN TRAN;\n\n-- DropIndex\nDROP INDEX [Post_authorId_key] ON [dbo].[Post];\n\nCOMMIT TRAN;\n\nEND TRY\nBEGIN CATCH\n\nIF @@TRANCOUNT > 0\nBEGIN\n ROLLBACK TRAN;\nEND;\nTHROW\n\nEND CATCH\n", + ] + "#]] + } else { + unreachable!() + }; + + expected_printed_messages.assert_debug_eq(&host.printed_messages.lock().unwrap()); +} + #[test_connector(tags(Sqlite))] fn diffing_postgres_schemas_when_initialized_on_sqlite(mut api: TestApi) { // We should get a postgres diff. 
diff --git a/schema-engine/sql-migration-tests/tests/migrations/mssql.rs b/schema-engine/sql-migration-tests/tests/migrations/mssql.rs index fc543f99227..12e8996cec9 100644 --- a/schema-engine/sql-migration-tests/tests/migrations/mssql.rs +++ b/schema-engine/sql-migration-tests/tests/migrations/mssql.rs @@ -158,7 +158,7 @@ fn mssql_apply_migrations_error_output(api: TestApi) { .split_terminator(" 0: ") .next() .unwrap() - .trim_end_matches(|c| c == '\n' || c == ' '); + .trim_end_matches(['\n', ' ']); expectation.assert_eq(first_segment) } diff --git a/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs index 8eac2fa9cc1..0e0c334f959 100644 --- a/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs +++ b/schema-engine/sql-migration-tests/tests/query_introspection/docs.rs @@ -2,13 +2,13 @@ use super::utils::*; use sql_migration_tests::test_api::*; #[test_connector(tags(Postgres))] -fn parses_doc_complex(api: TestApi) { +fn parses_doc_complex_pg(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "\n -- @description some fancy query\n -- @param {Int} $1:myInt some integer\n -- @param {String}$2:myString? some string\n SELECT int FROM model WHERE int = $1 and string = $2;\n ", + source: "\nSELECT int FROM model WHERE int = $1 and string = $2;\n", documentation: Some( "some fancy query", ), @@ -50,6 +50,62 @@ fn parses_doc_complex(api: TestApi) { api.introspect_sql("test_1", sql).send_sync().expect_result(expected) } +#[test_connector(tags(Mysql))] +fn parses_doc_complex_mysql(api: TestApi) { + api.schema_push(SIMPLE_SCHEMA).send().assert_green(); + + let expected = expect![[r#" + IntrospectSqlQueryOutput { + name: "test_1", + source: "\nSELECT `int` FROM `model` WHERE `int` = ? and `string` = ?;\n", + documentation: Some( + "some fancy query", + ), + parameters: [ + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some integer", + ), + name: "myInt", + typ: "int", + nullable: false, + }, + IntrospectSqlQueryParameterOutput { + documentation: Some( + "some string", + ), + name: "myString", + typ: "string", + nullable: true, + }, + ], + result_columns: [ + IntrospectSqlQueryColumnOutput { + name: "int", + typ: "int", + nullable: false, + }, + ], + } + "#]]; + + let sql = r#" + -- @description some fancy query + -- @param {Int} $1:myInt some integer + -- @param {String}$2:myString? some string + SELECT `int` FROM `model` WHERE `int` = ? 
and `string` = ?; + "#; + + let res = api.introspect_sql("test_1", sql).send_sync(); + + res.expect_result(expected); + + api.query_raw( + &res.output.source, + &[quaint::Value::int32(1), quaint::Value::text("hello")], + ); +} + #[test_connector(tags(Sqlite))] fn parses_doc_no_position(api: TestApi) { api.schema_push(SIMPLE_SCHEMA).send().assert_green(); @@ -57,7 +113,7 @@ fn parses_doc_no_position(api: TestApi) { let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "\n -- @param {String} :myInt some integer\n SELECT int FROM model WHERE int = :myInt and string = ?;\n ", + source: "\nSELECT int FROM model WHERE int = :myInt and string = ?;\n", documentation: None, parameters: [ IntrospectSqlQueryParameterOutput { @@ -100,7 +156,7 @@ fn parses_doc_no_alias(api: TestApi) { let expected = expect![[r#" IntrospectSqlQueryOutput { name: "test_1", - source: "\n -- @param {String} $2 some string\n SELECT int FROM model WHERE int = $1 and string = $2;\n ", + source: "\nSELECT int FROM model WHERE int = $1 and string = $2;\n", documentation: None, parameters: [ IntrospectSqlQueryParameterOutput { diff --git a/schema-engine/sql-schema-describer/Cargo.toml b/schema-engine/sql-schema-describer/Cargo.toml index 514eac9daec..17b8eae6368 100644 --- a/schema-engine/sql-schema-describer/Cargo.toml +++ b/schema-engine/sql-schema-describer/Cargo.toml @@ -14,11 +14,11 @@ enumflags2.workspace = true indexmap.workspace = true indoc.workspace = true once_cell = "1.3" -regex = "1.2" +regex.workspace = true serde.workspace = true tracing.workspace = true tracing-error = "0.2" -tracing-futures = "0.2" +tracing-futures.workspace = true quaint = { workspace = true, features = [ "all-native", "pooled",