diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 27389da0394d..e617007300a9 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -1598,12 +1598,6 @@ void Context::failStream(WasmStreamType stream_type) { WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view body_text, Pairs additional_headers, uint32_t grpc_status, std::string_view details) { - // This flag is used to avoid calling sendLocalReply() twice, even if wasm code has this - // logic. We can't reuse "local_reply_sent_" here because it can't avoid calling nested - // sendLocalReply() during encodeHeaders(). - if (local_reply_hold_) { - return WasmResult::BadArgument; - } // "additional_headers" is a collection of string_views. These will no longer // be valid when "modify_headers" is finally called below, so we must // make copies of all the headers. @@ -1628,11 +1622,6 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b modify_headers = std::move(modify_headers), grpc_status, details = StringUtil::replaceAllEmptySpace( absl::string_view(details.data(), details.size()))] { - // When the wasm vm fails, failStream() is called if the plugin is fail-closed, we need - // this flag to avoid calling sendLocalReply() twice. - if (local_reply_sent_) { - return; - } // C++, Rust and other SDKs use -1 (InvalidCode) as the default value if gRPC code is not set, // which should be mapped to nullopt in Envoy to prevent it from sending a grpc-status trailer // at all. @@ -1643,10 +1632,8 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, std::string_view b } decoder_callbacks_->sendLocalReply(static_cast(response_code), body_text, modify_headers, grpc_status_code, details); - local_reply_sent_ = true; }); } - local_reply_hold_ = true; return WasmResult::Ok; } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index fc9aaa25a61a..9c0923f3e5ef 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -439,7 +439,6 @@ class Context : public proxy_wasm::ContextBase, bool buffering_response_body_ = false; bool end_of_stream_ = false; bool local_reply_sent_ = false; - bool local_reply_hold_ = false; ProtobufWkt::Struct temporary_metadata_; // MB: must be a node-type map as we take persistent references to the entries. diff --git a/test/extensions/common/wasm/BUILD b/test/extensions/common/wasm/BUILD index af4f0146e916..89e632d33558 100644 --- a/test/extensions/common/wasm/BUILD +++ b/test/extensions/common/wasm/BUILD @@ -56,6 +56,7 @@ envoy_cc_test( "//test/extensions/common/wasm/test_data:test_context_cpp_plugin", "//test/extensions/common/wasm/test_data:test_cpp_plugin", "//test/extensions/common/wasm/test_data:test_restriction_cpp_plugin", + "//test/mocks/local_reply:local_reply_mocks", "//test/mocks/server:server_mocks", "//test/test_common:environment_lib", "//test/test_common:registry_lib", diff --git a/test/extensions/common/wasm/test_data/test_context_cpp.cc b/test/extensions/common/wasm/test_data/test_context_cpp.cc index 78bacb86d6be..7f53b8592b5c 100644 --- a/test/extensions/common/wasm/test_data/test_context_cpp.cc +++ b/test/extensions/common/wasm/test_data/test_context_cpp.cc @@ -94,22 +94,59 @@ FilterDataStatus DupReplyContext::onRequestBody(size_t, bool) { return FilterDataStatus::Continue; } -class PanicReplyContext : public Context { +class LocalReplyInRequestAndResponseContext : public Context { public: - explicit PanicReplyContext(uint32_t id, RootContext* root) : Context(id, root) {} + explicit LocalReplyInRequestAndResponseContext(uint32_t id, RootContext* root) : Context(id, root) {} + FilterHeadersStatus onRequestHeaders(uint32_t, bool) override; + FilterHeadersStatus onResponseHeaders(uint32_t, bool) override; +private: + EnvoyRootContext* root() { return static_cast(Context::root()); } +}; + +FilterHeadersStatus LocalReplyInRequestAndResponseContext::onRequestHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +FilterHeadersStatus LocalReplyInRequestAndResponseContext::onResponseHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +class PanicInRequestContext : public Context { +public: + explicit PanicInRequestContext(uint32_t id, RootContext* root) : Context(id, root) {} FilterDataStatus onRequestBody(size_t body_buffer_length, bool end_of_stream) override; private: EnvoyRootContext* root() { return static_cast(Context::root()); } }; -FilterDataStatus PanicReplyContext::onRequestBody(size_t, bool) { - sendLocalResponse(200, "not send", "body", {}); - int* badptr = nullptr; - *badptr = 0; // NOLINT(clang-analyzer-core.NullDereference) +FilterDataStatus PanicInRequestContext::onRequestBody(size_t, bool) { + abort(); return FilterDataStatus::Continue; } +class PanicInResponseContext : public Context { +public: + explicit PanicInResponseContext(uint32_t id, RootContext* root) : Context(id, root) {} + FilterHeadersStatus onResponseHeaders(uint32_t, bool) override; + FilterHeadersStatus onRequestHeaders(uint32_t, bool) override; + +private: + EnvoyRootContext* root() { return static_cast(Context::root()); } +}; + +FilterHeadersStatus PanicInResponseContext::onRequestHeaders(uint32_t, bool) { + sendLocalResponse(200, "ok", "body", {}); + return FilterHeadersStatus::Continue; +} + +FilterHeadersStatus PanicInResponseContext::onResponseHeaders(uint32_t, bool) { + abort(); + return FilterHeadersStatus::Continue; +} + class InvalidGrpcStatusReplyContext : public Context { public: explicit InvalidGrpcStatusReplyContext(uint32_t id, RootContext* root) : Context(id, root) {} @@ -128,9 +165,15 @@ FilterDataStatus InvalidGrpcStatusReplyContext::onRequestBody(size_t size, bool) static RegisterContextFactory register_DupReplyContext(CONTEXT_FACTORY(DupReplyContext), ROOT_FACTORY(EnvoyRootContext), "send local reply twice"); -static RegisterContextFactory register_PanicReplyContext(CONTEXT_FACTORY(PanicReplyContext), +static RegisterContextFactory register_LocalReplyInRequestAndResponseContext(CONTEXT_FACTORY(LocalReplyInRequestAndResponseContext), + ROOT_FACTORY(EnvoyRootContext), + "local reply in request and response"); +static RegisterContextFactory register_PanicInRequestContext(CONTEXT_FACTORY(PanicInRequestContext), + ROOT_FACTORY(EnvoyRootContext), + "panic during request processing"); +static RegisterContextFactory register_PanicInResponseContext(CONTEXT_FACTORY(PanicInResponseContext), ROOT_FACTORY(EnvoyRootContext), - "panic after sending local reply"); + "panic during response processing"); static RegisterContextFactory register_InvalidGrpcStatusReplyContext(CONTEXT_FACTORY(InvalidGrpcStatusReplyContext), diff --git a/test/extensions/common/wasm/wasm_test.cc b/test/extensions/common/wasm/wasm_test.cc index 6ed8bafb0ad8..b52f09993a3f 100644 --- a/test/extensions/common/wasm/wasm_test.cc +++ b/test/extensions/common/wasm/wasm_test.cc @@ -1,12 +1,16 @@ +#include "envoy/http/filter.h" +#include "envoy/http/filter_factory.h" #include "envoy/server/lifecycle_notifier.h" #include "source/common/common/base64.h" #include "source/common/common/hex.h" #include "source/common/event/dispatcher_impl.h" +#include "source/common/http/filter_manager.h" #include "source/common/stats/isolated_store_impl.h" #include "source/extensions/common/wasm/wasm.h" #include "test/extensions/common/wasm/wasm_runtime.h" +#include "test/mocks/local_reply/mocks.h" #include "test/mocks/server/mocks.h" #include "test/mocks/stats/mocks.h" #include "test/mocks/upstream/mocks.h" @@ -1313,7 +1317,6 @@ class WasmCommonContextTest : public Common::Wasm::WasmHttpFilterTestBase< return new TestContext(wasm, plugin); }); } - void setupContext() { setupFilterBase(); } TestContext& rootContext() { return *static_cast(root_context_); } @@ -1395,30 +1398,6 @@ TEST_P(WasmCommonContextTest, EmptyContext) { root_context_->validateConfiguration("", plugin_); } -// test that we don't send the local reply twice, even though it's specified in the wasm code -TEST_P(WasmCommonContextTest, DuplicateLocalReply) { - std::string code; - if (std::get<0>(GetParam()) != "null") { - code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( - "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); - } else { - // The name of the Null VM plugin. - code = "CommonWasmTestContextCpp"; - } - EXPECT_FALSE(code.empty()); - - setup(code, "context", "send local reply twice"); - setupContext(); - EXPECT_CALL(decoder_callbacks_, encodeHeaders_(_, _)) - .WillOnce([this](Http::ResponseHeaderMap&, bool) { context().onResponseHeaders(0, false); }); - EXPECT_CALL(decoder_callbacks_, - sendLocalReply(Envoy::Http::Code::OK, testing::Eq("body"), _, _, testing::Eq("ok"))); - - // Create in-VM context. - context().onCreate(); - EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(0, false)); -} - // test that we don't send the local reply twice when the wasm code panics TEST_P(WasmCommonContextTest, LocalReplyWhenPanic) { std::string code; @@ -1426,12 +1405,12 @@ TEST_P(WasmCommonContextTest, LocalReplyWhenPanic) { code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); } else { - // no need test the Null VM plugin. + // Let's not cause crashes in Null VM return; } EXPECT_FALSE(code.empty()); - setup(code, "context", "panic after sending local reply"); + setup(code, "context", "panic during request processing"); setupContext(); // In the case of VM failure, failStream is called, so we need to make sure that we don't send the // local reply twice. @@ -1495,6 +1474,125 @@ TEST_P(WasmCommonContextTest, ProcessValidGRPCStatusCodeAsEmptyInLocalReply) { EXPECT_EQ(proxy_wasm::FilterDataStatus::StopIterationNoBuffer, context().onRequestBody(1, false)); } +class WasmLocalReplyTest : public WasmCommonContextTest { +public: + WasmLocalReplyTest() = default; + + void setup(const std::string& code, std::string vm_configuration, std::string root_id = "") { + WasmCommonContextTest::setup(code, vm_configuration, root_id); + filter_manager_ = std::make_unique( + filter_manager_callbacks_, dispatcher_, connection_, 0, nullptr, true, 10000, + filter_factory_, local_reply_, protocol_, time_source_, filter_state_, overload_manager_); + request_headers_ = Http::RequestHeaderMapPtr{ + new Http::TestRequestHeaderMapImpl{{":path", "/"}, {":method", "GET"}}}; + request_data_ = Envoy::Buffer::OwnedImpl("body"); + } + + Http::StreamFilterSharedPtr filter() { return context_; } + + Http::FilterFactoryCb createWasmFilter() { + return [this](Http::FilterChainFactoryCallbacks& callbacks) { + callbacks.addStreamFilter(filter()); + }; + } + + void setupContext() { + WasmCommonContextTest::setupContext(); + ON_CALL(filter_factory_, createFilterChain(_)) + .WillByDefault(Invoke([this](Http::FilterChainManager& manager) -> bool { + auto factory = createWasmFilter(); + manager.applyFilterFactoryCb({}, factory); + return true; + })); + ON_CALL(filter_manager_callbacks_, requestHeaders()) + .WillByDefault(Return(makeOptRef(*request_headers_))); + filter_manager_->createFilterChain(); + filter_manager_->requestHeadersInitialized(); + } + + std::unique_ptr filter_manager_; + NiceMock filter_manager_callbacks_; + NiceMock dispatcher_; + NiceMock connection_; + NiceMock filter_factory_; + NiceMock local_reply_; + Http::Protocol protocol_{Http::Protocol::Http2}; + NiceMock time_source_; + StreamInfo::FilterStateSharedPtr filter_state_ = + std::make_shared(StreamInfo::FilterState::LifeSpan::Connection); + NiceMock overload_manager_; + Http::RequestHeaderMapPtr request_headers_; + Envoy::Buffer::OwnedImpl request_data_; +}; + +INSTANTIATE_TEST_SUITE_P(Runtimes, WasmLocalReplyTest, + Envoy::Extensions::Common::Wasm::runtime_and_cpp_values); + +TEST_P(WasmLocalReplyTest, DuplicateLocalReply) { + std::string code; + if (std::get<0>(GetParam()) != "null") { + code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( + "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); + } else { + // The name of the Null VM plugin. + code = "CommonWasmTestContextCpp"; + } + EXPECT_FALSE(code.empty()); + + setup(code, "context", "send local reply twice"); + setupContext(); + + // Even if sendLocalReply is called multiple times it should only generate a single + // response to the client, so encodeHeaders should only be called once + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, true); + filter_manager_->destroyFilters(); +} + +TEST_P(WasmLocalReplyTest, LocalReplyInRequestAndResponse) { + std::string code; + if (std::get<0>(GetParam()) != "null") { + code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( + "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); + } else { + code = "CommonWasmTestContextCpp"; + } + EXPECT_FALSE(code.empty()); + + setup(code, "context", "local reply in request and response"); + setupContext(); + + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, true); + filter_manager_->destroyFilters(); +} + +TEST_P(WasmLocalReplyTest, PanicDuringResponse) { + std::string code; + if (std::get<0>(GetParam()) != "null") { + code = TestEnvironment::readFileToStringForTest(TestEnvironment::substitute(absl::StrCat( + "{{ test_rundir }}/test/extensions/common/wasm/test_data/test_context_cpp.wasm"))); + } else { + // Let's not cause crashes in Null VM + return; + } + EXPECT_FALSE(code.empty()); + + setup(code, "context", "panic during response processing"); + setupContext(); + + EXPECT_CALL(filter_manager_callbacks_, encodeHeaders(_, _)); + EXPECT_CALL(filter_manager_callbacks_, endStream()); + + filter_manager_->decodeHeaders(*request_headers_, false); + filter_manager_->decodeData(request_data_, true); + filter_manager_->destroyFilters(); +} + class PluginConfigTest : public testing::TestWithParam> { public: PluginConfigTest() = default; @@ -1548,7 +1646,7 @@ TEST_P(PluginConfigTest, FailReloadPolicy) { const std::string plugin_config_yaml = fmt::format( R"EOF( name: "{}_test_wasm_reload" -root_id: "panic after sending local reply" +root_id: "panic during request processing" failure_policy: FAIL_RELOAD vm_config: runtime: "envoy.wasm.runtime.{}" @@ -1667,7 +1765,7 @@ TEST_P(PluginConfigTest, FailClosedPolicy) { const std::string plugin_config_yaml = fmt::format( R"EOF( name: "{}_test_wasm_reload" -root_id: "panic after sending local reply" +root_id: "panic during request processing" failure_policy: FAIL_CLOSED vm_config: runtime: "envoy.wasm.runtime.{}" @@ -1746,7 +1844,7 @@ TEST_P(PluginConfigTest, FailUnspecifiedPolicy) { const std::string plugin_config_yaml = fmt::format( R"EOF( name: "{}_test_wasm_reload" -root_id: "panic after sending local reply" +root_id: "panic during request processing" vm_config: runtime: "envoy.wasm.runtime.{}" configuration: @@ -1824,7 +1922,7 @@ TEST_P(PluginConfigTest, FailOpenPolicy) { const std::string plugin_config_yaml = fmt::format( R"EOF( name: "{}_test_wasm_reload" -root_id: "panic after sending local reply" +root_id: "panic during request processing" failure_policy: FAIL_OPEN vm_config: runtime: "envoy.wasm.runtime.{}" diff --git a/test/test_common/wasm_base.h b/test/test_common/wasm_base.h index 4d211dd9c532..a79ae2155a4a 100644 --- a/test/test_common/wasm_base.h +++ b/test/test_common/wasm_base.h @@ -142,12 +142,12 @@ template class WasmHttpFilterTestBase : public W auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = - std::make_unique(wasm, root_context_id, WasmTestBase::plugin_handle_); + std::make_shared(wasm, root_context_id, WasmTestBase::plugin_handle_); context_->setDecoderFilterCallbacks(decoder_callbacks_); context_->setEncoderFilterCallbacks(encoder_callbacks_); } - std::unique_ptr context_; + std::shared_ptr context_; NiceMock decoder_callbacks_; NiceMock encoder_callbacks_; NiceMock request_stream_info_; @@ -160,12 +160,12 @@ class WasmNetworkFilterTestBase : public WasmTestBase { auto wasm = WasmTestBase::wasm_ ? WasmTestBase::wasm_->wasm().get() : nullptr; int root_context_id = wasm ? wasm->getRootContext(WasmTestBase::plugin_, false)->id() : 0; context_ = - std::make_unique(wasm, root_context_id, WasmTestBase::plugin_handle_); + std::make_shared(wasm, root_context_id, WasmTestBase::plugin_handle_); context_->initializeReadFilterCallbacks(read_filter_callbacks_); context_->initializeWriteFilterCallbacks(write_filter_callbacks_); } - std::unique_ptr context_; + std::shared_ptr context_; NiceMock read_filter_callbacks_; NiceMock write_filter_callbacks_; };