diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 80e477d701250..da994c4a04249 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -1562,12 +1562,15 @@ constexpr absl::string_view FailStreamResponseDetails = "wasm_fail_stream"; void Context::failStream(WasmStreamType stream_type) { switch (stream_type) { case WasmStreamType::Request: - if (decoder_callbacks_ && !local_reply_sent_) { - decoder_callbacks_->sendLocalReply(Envoy::Http::Code::ServiceUnavailable, "", nullptr, - Grpc::Status::WellKnownGrpcStatus::Unavailable, - FailStreamResponseDetails); - local_reply_sent_ = true; - } + // TODO(krinkinmu): currently, if Wasm crashes during the request processing + // we can't send local reply back to the client. The problem is that this + // local reply will be send through the filter chain and Wasm plugin will be + // called to process the response. However, because Wasm plugin crashed it will + // just ask Envoy to pause iteration and the local reply will never get sent. + // + // For now we are just closing the connection, because it's strictly better than + // stuck connection, even if it's not the best behavior we could think of. + closeStream(stream_type); break; case WasmStreamType::Response: if (encoder_callbacks_ && !local_reply_sent_) { @@ -1595,12 +1598,6 @@ void Context::failStream(WasmStreamType stream_type) { WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text, Pairs additional_headers, uint32_t grpc_status, std::string_view details) { - // This flag is used to avoid calling sendLocalReply() twice, even if wasm code has this - // logic. We can't reuse "local_reply_sent_" here because it can't avoid calling nested - // sendLocalReply() during encodeHeaders(). - if (local_reply_hold_) { - return WasmResult::BadArgument; - } // "additional_headers" is a collection of string_views. These will no longer // be valid when "modify_headers" is finally called below, so we must // make copies of all the headers. @@ -1625,11 +1622,6 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b modify_headers = std::move(modify_headers), grpc_status, details = StringUtil::replaceAllEmptySpace( absl::string_view(details.data(), details.size()))] { - // When the wasm vm fails, failStream() is called if the plugin is fail-closed, we need - // this flag to avoid calling sendLocalReply() twice. - if (local_reply_sent_) { - return; - } // C++, Rust and other SDKs use -1 (InvalidCode) as the default value if gRPC code is not set, // which should be mapped to nullopt in Envoy to prevent it from sending a grpc-status trailer // at all. @@ -1640,10 +1632,8 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b } decoder_callbacks_->sendLocalReply(static_cast(response_code), body_text, modify_headers, grpc_status_code, details); - local_reply_sent_ = true; }); } - local_reply_hold_ = true; return WasmResult::Ok; } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index 77dbba0e37727..58c57e6bbd2eb 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -431,7 +431,6 @@ class Context : public proxy_wasm::ContextBase, bool buffering_response_body_ = false; bool end_of_stream_ = false; bool local_reply_sent_ = false; - bool local_reply_hold_ = false; ProtobufWkt::Struct temporary_metadata_; // MB: must be a node-type map as we take persistent references to the entries. diff --git a/test/extensions/common/wasm/BUILD b/test/extensions/common/wasm/BUILD index 8b8531b46d46c..1bf7bc4659c55 100644 --- a/test/extensions/common/wasm/BUILD +++ b/test/extensions/common/wasm/BUILD @@ -55,6 +55,7 @@ envoy_cc_test( "//test/extensions/common/wasm/test_data:test_context_cpp_plugin", "//test/extensions/common/wasm/test_data:test_cpp_plugin", "//test/extensions/common/wasm/test_data:test_restriction_cpp_plugin", + "//test/mocks/local_reply:local_reply_mocks", "//test/mocks/server:server_mocks", "//test/test_common:environment_lib", "//test/test_common:registry_lib", diff --git a/test/extensions/common/wasm/test_data/test_context_cpp.cc b/test/extensions/common/wasm/test_data/test_context_cpp.cc index 0a18d36275d70..142c82bf1b127 100644 --- a/test/extensions/common/wasm/test_data/test_context_cpp.cc +++ b/test/extensions/common/wasm/test_data/test_context_cpp.cc @@ -94,22 +94,60 @@ FilterDataStatus DupReplyContext::onRequestBody(size_t, bool) { return FilterDataStatus::Continue; } -class PanicReplyContext : public Context { +class LocalReplyInRequestAndResponseContext : public Context { public: - explicit PanicReplyContext(uint32_t id, RootContext* root) : Context(id, root) {} + explicit LocalReplyInRequestAndResponseContext(uint32_t id, RootContext* root) : Context(id, root) {} + FilterHeadersStatus onRequestHeaders(uint32_t, bool) override; + FilterHeadersStatus onResponseHeaders(uint32_t, bool) override; +private: + EnvoyRootContext* root() { return static_cast(Context::root()); } +}; + +FilterHeadersStatus LocalReplyInRequestAndResponseContext::onRequestHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +FilterHeadersStatus LocalReplyInRequestAndResponseContext::onResponseHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +class PanicInRequestContext : public Context { +public: + explicit PanicInRequestContext(uint32_t id, RootContext* root) : Context(id, root) {} FilterDataStatus onRequestBody(size_t body_buffer_length, bool end_of_stream) override; private: EnvoyRootContext* root() { return static_cast(Context::root()); } }; -FilterDataStatus PanicReplyContext::onRequestBody(size_t, bool) { +FilterDataStatus PanicInRequestContext::onRequestBody(size_t, bool) { sendLocalResponse(200, "not send", "body", {}); - int* badptr = nullptr; - *badptr = 0; // NOLINT(clang-analyzer-core.NullDereference) + abort(); return FilterDataStatus::Continue; } +class PanicInResponseContext : public Context { +public: + explicit PanicInResponseContext(uint32_t id, RootContext* root) : Context(id, root) {} + FilterHeadersStatus onResponseHeaders(uint32_t, bool) override; + FilterHeadersStatus onRequestHeaders(uint32_t, bool) override; + +private: + EnvoyRootContext* root() { return static_cast(Context::root()); } +}; + +FilterHeadersStatus PanicInResponseContext::onRequestHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +FilterHeadersStatus PanicInResponseContext::onResponseHeaders(uint32_t, bool) { + abort(); + return FilterHeadersStatus::Continue; +} + class InvalidGrpcStatusReplyContext : public Context { public: explicit InvalidGrpcStatusReplyContext(uint32_t id, RootContext* root) : Context(id, root) {} @@ -127,9 +165,15 @@ FilterDataStatus InvalidGrpcStatusReplyContext::onRequestBody(size_t size, bool) static RegisterContextFactory register_DupReplyContext(CONTEXT_FACTORY(DupReplyContext), ROOT_FACTORY(EnvoyRootContext), "send local reply twice"); -static RegisterContextFactory register_PanicReplyContext(CONTEXT_FACTORY(PanicReplyContext), +static RegisterContextFactory register_LocalReplyInRequestAndResponseContext(CONTEXT_FACTORY(LocalReplyInRequestAndResponseContext), + ROOT_FACTORY(EnvoyRootContext), + "local reply in request and response"); +static RegisterContextFactory register_PanicInRequestContext(CONTEXT_FACTORY(PanicInRequestContext), + ROOT_FACTORY(EnvoyRootContext), + "panic during request processing"); +static RegisterContextFactory register_PanicInResponseContext(CONTEXT_FACTORY(PanicInResponseContext), ROOT_FACTORY(EnvoyRootContext), - "panic after sending local reply"); + "panic during response processing"); static RegisterContextFactory register_InvalidGrpcStatusReplyContext(CONTEXT_FACTORY(InvalidGrpcStatusReplyContext), ROOT_FACTORY(EnvoyRootContext), diff --git a/test/extensions/common/wasm/wasm_test.cc b/test/extensions/common/wasm/wasm_test.cc index e12e68eccf7e4..155e28ac7f65c 100644 --- a/test/extensions/common/wasm/wasm_test.cc +++ b/test/extensions/common/wasm/wasm_test.cc @@ -1,11 +1,15 @@ +#include "envoy/http/filter.h" +#include "envoy/http/filter_factory.h" #include "envoy/server/lifecycle_notifier.h" #include "source/common/common/hex.h" #include "source/common/event/dispatcher_impl.h" +#include "source/common/http/filter_manager.h" #include "source/common/stats/isolated_store_impl.h" #include "source/extensions/common/wasm/wasm.h" #include "test/extensions/common/wasm/wasm_runtime.h" +#include "test/mocks/local_reply/mocks.h" #include "test/mocks/server/mocks.h" #include "test/mocks/stats/mocks.h" #include "test/mocks/upstream/mocks.h" @@ -1310,7 +1314,6 @@ class WasmCommonContextTest : public Common::Wasm::WasmHttpFilterTestBase< return new TestContext(wasm, plugin); }); } - void setupContext() { setupFilterBase(); } TestContext& rootContext() { return *static_cast(root_context_); } @@ -1392,8 +1395,8 @@ TEST_P(WasmCommonContextTest, EmptyContext) { root_context_->validateConfiguration("", plugin_); } -// test that we don't send the local reply twice, even though it's specified in the wasm code -TEST_P(WasmCommonContextTest, DuplicateLocalReply) { +// test that in case -1 is send from wasm it propagate nullopt +TEST_P(WasmCommonContextTest, ProcessInvalidGRPCStatusCodeAsEmptyInLocalReply) { std::string code; if (std::get<0>(GetParam()) != "null") { code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( @@ -1404,92 +1407,200 @@ TEST_P(WasmCommonContextTest, DuplicateLocalReply) { } EXPECT_FALSE(code.empty()); - setup(code, "context", "send local reply twice"); + setup(code, "context", "send local reply grpc"); setupContext(); EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _)) .WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); }); - EXPECT_CALL(decoder_callbacks_, - sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, _, testing::Eq("ok"))); + EXPECT_CALL(decoder_callbacks_, sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, + testing::Eq(absl::nullopt), testing::Eq("ok"))); // Create in-VM context. context().onCreate(); EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(0, false)); } -// test that we don't send the local reply twice when the wasm code panics -TEST_P(WasmCommonContextTest, LocalReplyWhenPanic) { +// test that in case valid grpc status is send from wasm it propagate as it is +TEST_P(WasmCommonContextTest, ProcessValidGRPCStatusCodeAsEmptyInLocalReply) { std::string code; if (std::get<0>(GetParam()) != "null") { code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); } else { - // no need test the Null VM plugin. - return; + // The name of the Null VM plugin. + code = "CommonWasmTestContextCpp"; } EXPECT_FALSE(code.empty()); - setup(code, "context", "panic after sending local reply"); + setup(code, "context", "send local reply grpc"); setupContext(); - // In the case of VM failure, failStream is called, so we need to make sure that we don't send the - // local reply twice. + EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _)) + .WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); }); EXPECT_CALL(decoder_callbacks_, - sendLocalReply(Envoy::Http::Code::ServiceUnavailable, testing::Eq(""), _, - testing::Eq(Grpc::Status::WellKnownGrpcStatus::Unavailable), - testing::Eq("wasm_fail_stream"))); + sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, + testing::Eq(Grpc::Status::WellKnownGrpcStatus::PermissionDenied), + testing::Eq("ok"))); // Create in-VM context. context().onCreate(); - EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(0, false)); + EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(1, false)); } -// test that in case -1 is send from wasm it propagate nullopt -TEST_P(WasmCommonContextTest, ProcessInvalidGRPCStatusCodeAsEmptyInLocalReply) { +class WasmLocalReplyTest : public WasmCommonContextTest { +public: + WasmLocalReplyTest() = default; + + void setup(const std::string& code, std::string vm_configuration, std::string root_id = "") { + WasmCommonContextTest::setup(code, vm_configuration, root_id); + filter_manager_ = std::make_unique( + filter_manager_callbacks_, dispatcher_, connection_, 0, nullptr, true, 10000, + filter_factory_, local_reply_, protocol_, time_source_, filter_state_, overload_manager_); + request_headers_ = Http::RequestHeaderMapPtr{ + new Http::TestRequestHeaderMapImpl{{":path", "/"}, {":method", "GET"}}}; + request_data_ = Envoy::Buffer::OwnedImpl("body"); + } + + Http::StreamFilterSharedPtr filter() { return context_; } + + Http::FilterFactoryCb createWasmFilter() { + return [this](Http::FilterChainFactoryCallbacks& callbacks) { + callbacks.addStreamFilter(filter()); + }; + } + + void setupContext() { + WasmCommonContextTest::setupContext(); + ON_CALL(filter_factory_, createFilterChain(_)) + .WillByDefault(Invoke([this](Http::FilterChainManager& manager) -> bool { + auto factory = createWasmFilter(); + manager.applyFilterFactoryCb({}, factory); + return true; + })); + ON_CALL(filter_manager_callbacks_, requestHeaders()) + .WillByDefault(Return(makeOptRef(*request_headers_))); + filter_manager_->createFilterChain(); + filter_manager_->requestHeadersInitialized(); + } + + std::unique_ptr filter_manager_; + NiceMock filter_manager_callbacks_; + NiceMock dispatcher_; + NiceMock connection_; + NiceMock filter_factory_; + NiceMock local_reply_; + Http::Protocol protocol_{Http::Protocol::Http2}; + NiceMock time_source_; + StreamInfo::FilterStateSharedPtr filter_state_ = + std::make_shared(StreamInfo::FilterState::LifeSpan::Connection); + NiceMock overload_manager_; + Http::RequestHeaderMapPtr request_headers_; + Envoy::Buffer::OwnedImpl request_data_; +}; + +INSTANTIATE_TEST_SUITE_P(Runtimes, WasmLocalReplyTest, + Envoy::Extensions::Common::Wasm::runtime_and_cpp_values); + +TEST_P(WasmLocalReplyTest, DuplicateLocalReply) { std::string code; if (std::get<0>(GetParam()) != "null") { code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); } else { - // The name of the Null VM plugin. - code = "CommonWasmTestContextCpp"; + // Skip the Null plugin + return; } EXPECT_FALSE(code.empty()); - setup(code, "context", "send local reply grpc"); + setup(code, "context", "send local reply twice"); setupContext(); - EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _)) - .WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); }); - EXPECT_CALL(decoder_callbacks_, sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, - testing::Eq(absl::nullopt), testing::Eq("ok"))); - // Create in-VM context. - context().onCreate(); - EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(0, false)); + // Even if sendLocalReply is called multiple times it should only generate a single + // response to the client, so encodeHeaders should only be called once + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, false); + filter_manager_->destroyFilters(); } -// test that in case valid grpc status is send from wasm it propagate as it is -TEST_P(WasmCommonContextTest, ProcessValidGRPCStatusCodeAsEmptyInLocalReply) { +TEST_P(WasmLocalReplyTest, LocalReplyInRequestAndResponse) { std::string code; if (std::get<0>(GetParam()) != "null") { code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); } else { - // The name of the Null VM plugin. code = "CommonWasmTestContextCpp"; } EXPECT_FALSE(code.empty()); - setup(code, "context", "send local reply grpc"); + setup(code, "context", "local reply in request and response"); setupContext(); - EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _)) - .WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); }); - EXPECT_CALL(decoder_callbacks_, - sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, - testing::Eq(Grpc::Status::WellKnownGrpcStatus::PermissionDenied), - testing::Eq("ok"))); - // Create in-VM context. - context().onCreate(); - EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(1, false)); + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, false); + filter_manager_->destroyFilters(); +} + +TEST_P(WasmLocalReplyTest, PanicDuringRequest) { + std::string code; + if (std::get<0>(GetParam()) != "null") { + code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( + "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); + } else { + // Let's not cause crashes in Null VM + return; + } + EXPECT_FALSE(code.empty()); + + setup(code, "context", "panic during request processing"); + setupContext(); + + // When Wasm VM crashes during request processing it just resets the connection. + // + // NOTE: We cannot currently send a local reply in this case. If we did, this local + // reply would go through the filter chain and Wasm plugin would be called to process + // it. + // + // However, because the Wasm VM crashed, it can't meaningfully process response and + // the filter chain will be paused and connection will get stuck. That is unless Wasm + // is configured in the fail open mode. + // + // So instead of getting the connection stuck, we just reset the stream and that's + // what this test tries to confirm. + EXPECT_CALL(filter_manager_callbacks_, resetStream(_, _)); + + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, false); + filter_manager_->destroyFilters(); +} + +TEST_P(WasmLocalReplyTest, PanicDuringResponse) { + std::string code; + if (std::get<0>(GetParam()) != "null") { + code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( + "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); + } else { + // Let's not cause crashes in Null VM + return; + } + EXPECT_FALSE(code.empty()); + + setup(code, "context", "panic during response processing"); + setupContext(); + + // Unlike the case above, when we panic during the response processing, we actually + // end up sending the locally generated reply to the client. That's because when + // we generate local reply during response processing, Envoy does not send it through + // the filter chain and instead sends it directly. It's because Envoy does not want + // to corrupt the state of the plugins by calling them with the local reply after + // they already started processing the response. + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, false); + filter_manager_->destroyFilters(); } } // namespace Wasm diff --git a/test/test_common/wasm_base.h b/test/test_common/wasm_base.h index 3e2b34393c06a..d68b38f3c013d 100644 --- a/test/test_common/wasm_base.h +++ b/test/test_common/wasm_base.h @@ -142,12 +142,12 @@ template class WasmHttpFilterTestBase : public W auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = - std::make_unique(wasm, root_context_id, WasmTestBase::plugin_handle_); + std::make_shared(wasm, root_context_id, WasmTestBase::plugin_handle_); context_->setDecoderFilterCallbacks(decoder_callbacks_); context_->setEncoderFilterCallbacks(encoder_callbacks_); } - std::unique_ptr context_; + std::shared_ptr context_; NiceMock decoder_callbacks_; NiceMock encoder_callbacks_; NiceMock request_stream_info_; @@ -160,12 +160,12 @@ class WasmNetworkFilterTestBase : public WasmTestBase { auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = - std::make_unique(wasm, root_context_id, WasmTestBase::plugin_handle_); + std::make_shared(wasm, root_context_id, WasmTestBase::plugin_handle_); context_->initializeReadFilterCallbacks(read_filter_callbacks_); context_->initializeWriteFilterCallbacks(write_filter_callbacks_); } - std::unique_ptr context_; + std::shared_ptr context_; NiceMock read_filter_callbacks_; NiceMock write_filter_callbacks_; };