diff --git a/src/tateyama/endpoint/common/worker_common.h b/src/tateyama/endpoint/common/worker_common.h
index 932c6612..c3d47cde 100644
--- a/src/tateyama/endpoint/common/worker_common.h
+++ b/src/tateyama/endpoint/common/worker_common.h
@@ -43,7 +43,24 @@ namespace tateyama::endpoint::common {
 class worker_common {
-public:
+protected:
+    enum class shutdown_consequence : std::uint32_t {
+        /**
+         * @brief no shutdown, continue normal processing.
+         */
+        keep_working = 0U,
+
+        /**
+         * @brief postpone the shutdown.
+         */
+        postpone,
+
+        /**
+         * @brief shutdown immediately.
+         */
+        immediate,
+    };
+
     enum class connection_type : std::uint32_t {
         /**
          * @brief undefined type.
@@ -61,6 +78,7 @@ class worker_common {
         stream,
     };

+public:
     worker_common(connection_type con, std::size_t session_id, std::string_view conn_info, std::shared_ptr session) :
         connection_type_(con),
         session_id_(session_id),
@@ -117,7 +135,7 @@ class worker_common {
     // for session management
     const std::shared_ptr session_;  // NOLINT
-    bool cancel_requested_to_all_responses{};  // NOLINT
+    bool cancel_requested_to_all_responses_{};  // NOLINT

     bool handshake(tateyama::api::server::request* req, tateyama::api::server::response* res) {
         if (req->service_id() != tateyama::framework::service_id_endpoint_broker) {
@@ -251,30 +269,44 @@ class worker_common {
         responses_.erase(slot);
     }

-    bool handle_shutdown() {
-        if (auto type = session_context_->shutdown_request(); type != tateyama::session::shutdown_request_type::nothing) {
-            if (type == tateyama::session::shutdown_request_type::forceful && !cancel_requested_to_all_responses) {
-                foreach_response([](tateyama::endpoint::common::response& e){e.cancel();});
-                cancel_requested_to_all_responses = true;
+    shutdown_consequence handle_shutdown() {
+        switch (session_context_->shutdown_request()) {
+        case tateyama::session::shutdown_request_type::graceful:
+            if (has_incomplete_response()) {
+                return shutdown_consequence::postpone;
+            }
+            return shutdown_consequence::immediate;
+        case tateyama::session::shutdown_request_type::forceful:
+            if (!cancel_requested_to_all_responses_) {
+                foreach_response([this](tateyama::endpoint::common::response& e){
+                    e.cancel();
+                    notify_client(&e, tateyama::proto::diagnostics::Code::OPERATION_CANCELED, "");
+                });
+                cancel_requested_to_all_responses_ = true;
             }
-            return true;
+            return shutdown_consequence::immediate;
+        default:
+            return shutdown_consequence::keep_working;
         }
-        return false;
     }

     bool foreach_response(const std::function& func) {
-        bool rv{false};
-        std::lock_guard lock(mtx_responses_);
-        for (auto it{responses_.begin()}, end{responses_.end()}; it != end; ) {
-            if (auto r = it->second.lock(); r) {
-                func(*r);
-                rv = true;
-                ++it;
-            } else {
-                it = responses_.erase(it);
+        std::vector> targets{};
+        {
+            std::lock_guard lock(mtx_responses_);
+            for (auto it{responses_.begin()}, end{responses_.end()}; it != end; ) {
+                if (auto r = it->second.lock(); r) {
+                    targets.emplace_back(r);
+                    ++it;
+                } else {
+                    it = responses_.erase(it);
+                }
             }
         }
-        return rv;
+        for (auto &&e: targets) {
+            func(*e);
+        }
+        return !targets.empty();
     }

     bool shutdown_request(tateyama::session::shutdown_request_type type) noexcept {
diff --git a/src/tateyama/endpoint/ipc/bootstrap/server_wires_impl.h b/src/tateyama/endpoint/ipc/bootstrap/server_wires_impl.h
index 347e8651..4887fe8f 100644
--- a/src/tateyama/endpoint/ipc/bootstrap/server_wires_impl.h
+++ b/src/tateyama/endpoint/ipc/bootstrap/server_wires_impl.h
@@ -468,8 +468,8 @@ class server_wire_container_impl : public server_wire_container
             std::lock_guard lock(mtx_);
             wire_->write(bip_buffer_, from, header);
         }
-        void set_shutdown() {
-            wire_->set_shutdown();
+        void notify_shutdown() override {
+            wire_->notify_shutdown();
         }

         // for client
@@ -555,9 +555,9 @@
         return garbage_collector_impl_.get();
     }

-    void notify() {
+    void notify_shutdown() {
         request_wire_.notify();
-        response_wire_.set_shutdown();
+        response_wire_.notify_shutdown();
     }

     [[nodiscard]] bool terminate_requested() const {
diff --git a/src/tateyama/endpoint/ipc/bootstrap/worker.cpp b/src/tateyama/endpoint/ipc/bootstrap/worker.cpp
index d99bd0d8..0003dd9b 100644
--- a/src/tateyama/endpoint/ipc/bootstrap/worker.cpp
+++ b/src/tateyama/endpoint/ipc/bootstrap/worker.cpp
@@ -29,6 +29,10 @@ void Worker::run() {
     while(true) {
         auto hdr = request_wire_container_->peep();
         if (hdr.get_length() == 0 && hdr.get_idx() == tateyama::common::wire::message_header::null_request) {
+            if (handle_shutdown() == shutdown_consequence::immediate) {
+                VLOG_LP(log_trace) << "received and completed shutdown request: session_id = " << std::to_string(session_id_);
+                return;
+            }
             if (request_wire_container_->terminate_requested()) {
                 VLOG_LP(log_trace) << "received shutdown request: session_id = " << std::to_string(session_id_);
                 return;
@@ -52,7 +56,8 @@ void Worker::run() {
         try {
             auto h = request_wire_container_->peep();
             if (h.get_length() == 0 && h.get_idx() == tateyama::common::wire::message_header::null_request) {
-                if (handle_shutdown()) {
+                if (handle_shutdown() == shutdown_consequence::immediate) {
+                    wire_->get_response_wire().notify_shutdown();
                     VLOG_LP(log_trace) << "received and completed shutdown request: session_id = " << std::to_string(session_id_);
                     break;
                 }
@@ -65,19 +70,26 @@ void Worker::run() {
             auto request = std::make_shared(*wire_, h, database_info_, session_info_);
             std::size_t index = h.get_idx();
             auto response = std::make_shared(wire_, h.get_idx(), [this, index](){remove_response(index);});
+            bool break_while{false};
             if (request->service_id() != tateyama::framework::service_id_endpoint_broker) {
-                if (!handle_shutdown()) {
+                switch (handle_shutdown()) {
+                case shutdown_consequence::keep_working:
                     register_response(index, static_cast>(response));
                     if (!service_(static_cast>(request),
                                   static_cast>(std::move(response)))) {
                         VLOG_LP(log_info) << "terminate worker because service returns an error";
-                        break;
+                        break_while = true;
                     }
-                } else {
+                    break;
+                case shutdown_consequence::postpone:
                     notify_client(response.get(), tateyama::proto::diagnostics::SESSION_CLOSED, "");
-                    if (!has_incomplete_response()) {
-                        break;
-                    }
+                    break;
+                case shutdown_consequence::immediate:
+                    notify_client(response.get(), tateyama::proto::diagnostics::SESSION_CLOSED, "");
+                    wire_->get_response_wire().notify_shutdown();
+                    break_while = true;
+                    VLOG_LP(log_trace) << "received and completed shutdown request: session_id = " << std::to_string(session_id_);
+                    break;
                 }
             } else {
                 if (!endpoint_service(static_cast>(request),
@@ -87,6 +99,9 @@ void Worker::run() {
                     break;
                 }
             }
+            if (break_while) {
+                break;
+            }
             request->dispose();
             request = nullptr;
             wire_->get_garbage_collector()->dump();
@@ -106,7 +121,7 @@ bool Worker::terminate(tateyama::session::shutdown_request_type type) {
     VLOG_LP(log_trace) << "send terminate request: session_id = " << std::to_string(session_id_);

     auto rv = shutdown_request(type);
-    wire_->notify();
+    wire_->notify_shutdown();
     return rv;
 }

diff --git a/src/tateyama/endpoint/ipc/server_wires.h b/src/tateyama/endpoint/ipc/server_wires.h
index a5d9641c..22b8d376 100644
--- a/src/tateyama/endpoint/ipc/server_wires.h
+++ b/src/tateyama/endpoint/ipc/server_wires.h
@@ -47,6 +47,7 @@ class server_wire_container
         response_wire_container& operator = (response_wire_container&&) = default;

         virtual void write(const char*, tateyama::common::wire::response_header) = 0;
+        virtual void notify_shutdown() = 0;
     };
     class resultset_wire_container;
     using resultset_wire_deleter_type = void(*)(resultset_wire_container*);
diff --git a/src/tateyama/endpoint/ipc/wire.h b/src/tateyama/endpoint/ipc/wire.h
index 36782615..56c952d1 100644
--- a/src/tateyama/endpoint/ipc/wire.h
+++ b/src/tateyama/endpoint/ipc/wire.h
@@ -372,7 +372,7 @@ inline static std::int64_t n_cap(std::int64_t timeout) {

 // for request
 class unidirectional_message_wire : public simple_wire {
-    constexpr static std::size_t watch_interval = 5;
+    constexpr static std::size_t watch_interval = 2;
 public:
     unidirectional_message_wire(boost::interprocess::managed_shared_memory* managed_shm_ptr, std::size_t capacity) : simple_wire(managed_shm_ptr, capacity) {}

@@ -497,7 +497,7 @@ class unidirectional_response_wire : public simple_wire {
     [[nodiscard]] response_header::msg_type get_type() const {
         return header_received_.get_type();
     }

-    void set_shutdown() noexcept {
+    void notify_shutdown() noexcept {
         shutdown_.store(true);
     }

@@ -509,7 +509,7 @@ class unidirectional_response_wire : public simple_wire {
             c_empty_.notify_one();
         }
     }

-    [[nodiscard]] bool get_shutdown() const noexcept {
+    [[nodiscard]] bool check_shutdown() const noexcept {
         return shutdown_.load();
     }

diff --git a/src/tateyama/endpoint/stream/bootstrap/stream_worker.cpp b/src/tateyama/endpoint/stream/bootstrap/stream_worker.cpp
index 4caad148..c9af7f05 100644
--- a/src/tateyama/endpoint/stream/bootstrap/stream_worker.cpp
+++ b/src/tateyama/endpoint/stream/bootstrap/stream_worker.cpp
@@ -53,6 +53,7 @@ void stream_worker::run()
             session_stream_->close();
             return;
         }
+        session_stream_->change_slot_size(max_result_sets_);
         break;
     }

@@ -77,24 +78,28 @@ void stream_worker::run()
         std::string payload{};
         switch (session_stream_->await(slot, payload)) {
-
         case tateyama::endpoint::stream::stream_socket::await_result::payload:
         {
             auto request = std::make_shared(*session_stream_, payload, database_info_, session_info_);
             auto response = std::make_shared(session_stream_, slot, [this, slot](){remove_response(slot);});
+            bool break_while{false};
             if (request->service_id() != tateyama::framework::service_id_endpoint_broker) {
-                if (!handle_shutdown()) {
+                switch (handle_shutdown()) {
+                case shutdown_consequence::keep_working:
                     register_response(slot, static_cast>(response));
                     if(!service_(static_cast>(request),
                                  static_cast>(std::move(response)))) {
                         VLOG_LP(log_info) << "terminate worker because service returns an error";
-                        break;
+                        break_while = true;
                     }
-                } else {
+                    break;
+                case shutdown_consequence::postpone:
                     notify_client(response.get(), tateyama::proto::diagnostics::SESSION_CLOSED, "");
-                    if (!has_incomplete_response()) {
-                        break;
-                    }
+                    break;
+                case shutdown_consequence::immediate:
+                    notify_client(response.get(), tateyama::proto::diagnostics::SESSION_CLOSED, "");
+                    break_while = true;
+                    break;
                 }
             } else {
                 if (!endpoint_service(static_cast>(request),
@@ -105,25 +110,26 @@ void stream_worker::run()
                     break;
                 }
             }
             request = nullptr;
-            continue;
+            if (!break_while) {
+                continue;
+            }
+            break;
         }

         case tateyama::endpoint::stream::stream_socket::await_result::timeout:
-        {
-            if (handle_shutdown()) {
+            if (handle_shutdown() == shutdown_consequence::immediate) {
                 VLOG_LP(log_trace) << "received and completed shutdown request: session_id = " << std::to_string(session_id_);
                 break;
             }
             continue;
-        }

-        default:
-            session_stream_->close();
+        default:  // some error
             break;
         }
-        break;
     }
+    session_stream_->close();
+
 #ifdef ENABLE_ALTIMETER
     tateyama::endpoint::altimeter::session_end(database_info_, session_info_);
 #endif
diff --git a/src/tateyama/endpoint/stream/stream.h b/src/tateyama/endpoint/stream/stream.h
index 8ea29577..a7e3962a 100644
--- a/src/tateyama/endpoint/stream/stream.h
+++ b/src/tateyama/endpoint/stream/stream.h
@@ -193,7 +193,7 @@ class stream_socket
 private:
     int socket_;
     static constexpr std::size_t N_FDS = 1;
-    static constexpr int TIMEOUT_MS = 1000;  // 1000(mS)
+    static constexpr int TIMEOUT_MS = 2000;  // 2000(mS)
     struct pollfd fds_[N_FDS]{};  // NOLINT
     bool session_closed_{false};

diff --git a/src/tateyama/session/resource/bridge.cpp b/src/tateyama/session/resource/bridge.cpp
index c697ff54..a0484b4a 100644
--- a/src/tateyama/session/resource/bridge.cpp
+++ b/src/tateyama/session/resource/bridge.cpp
@@ -108,7 +108,7 @@ std::optional bridge::get(std::
     return tateyama::proto::session::diagnostic::ErrorCode::SESSION_NOT_FOUND;
 }

-std::optional bridge::shutdown(std::string_view session_specifier, [[maybe_unused]] shutdown_request_type type) {
+std::optional bridge::shutdown(std::string_view session_specifier, shutdown_request_type type) {
     session_context::numeric_id_type numeric_id{};
     try {
         auto opt = find_only_one_session(session_specifier, numeric_id);
diff --git a/test/tateyama/endpoint/ipc/ipc_client.cpp b/test/tateyama/endpoint/ipc/ipc_client.cpp
index 26c0cfc5..25a5560b 100644
--- a/test/tateyama/endpoint/ipc/ipc_client.cpp
+++ b/test/tateyama/endpoint/ipc/ipc_client.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright 2018-2023 Project Tsurugi.
+ * Copyright 2018-2024 Project Tsurugi.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -15,8 +15,6 @@
 */
 #include "ipc_client.h"

-#include
-
 namespace tateyama::endpoint::ipc {

 ipc_client::ipc_client(std::shared_ptr const &cfg, tateyama::proto::endpoint::request::Handshake& hs)
@@ -45,7 +43,7 @@ ipc_client::ipc_client(std::string_view name, std::size_t session_id, tateyama::
     handshake();
 }

-static constexpr tsubakuro::common::wire::message_header::index_type ipc_test_index = 1234;
+static constexpr tsubakuro::common::wire::message_header::index_type ipc_test_index = 135;

 void ipc_client::send(const std::size_t tag, const std::string &message) {
     request_header_content hdr { session_id_, tag };
@@ -78,13 +76,11 @@ bool parse_response_header(std::string_view input, parse_response_result &result
 }

 void ipc_client::receive(std::string &message) {
-    receive(message, static_cast(0), false);
-}
-void ipc_client::receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType type) {
-    receive(message, type, true);
+    tateyama::proto::framework::response::Header::PayloadType type{};
+    receive(message, type);
 }

-void ipc_client::receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType type, bool do_check) {
+void ipc_client::receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType &type) {
     tsubakuro::common::wire::response_header header;
     int ntry = 0;
     bool ok = false;
@@ -94,6 +90,9 @@ void ipc_client::receive(std::string &message, tateyama::proto::framework::respo
             header = response_wire_->await();
             ok = true;
         } catch (const std::runtime_error &ex) {
+            if (response_wire_->check_shutdown()) {
+                throw std::runtime_error("server shutdown");
+            }
             std::cout << ex.what() << std::endl;
             ntry++;
             if (ntry >= 100) {
@@ -109,12 +108,10 @@ void ipc_client::receive(std::string &message, tateyama::proto::framework::respo
     // parse_response_result result;
     parse_response_header(r_msg, result);
-    if (do_check) {
-        EXPECT_EQ(type, result.payload_type_);
-    }
     // ASSERT_TRUE(parse_response_header(r_msg, result));
     // EXPECT_EQ(session_id_, result.session_id_);
     message = result.payload_;
+    type = result.payload_type_;
 }

 resultset_wires_container* ipc_client::create_resultset_wires() {
diff --git a/test/tateyama/endpoint/ipc/ipc_client.h b/test/tateyama/endpoint/ipc/ipc_client.h
index 27287686..48cc94c2 100644
--- a/test/tateyama/endpoint/ipc/ipc_client.h
+++ b/test/tateyama/endpoint/ipc/ipc_client.h
@@ -40,7 +40,7 @@ class ipc_client {
     }
     void send(const std::size_t tag, const std::string &message);
     void receive(std::string &message);
-    void receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType type);
+    void receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType& type);
     resultset_wires_container* create_resultset_wires();
     void dispose_resultset_wires(resultset_wires_container *rwc);
diff --git a/test/tateyama/endpoint/ipc/ipc_session_test.cpp b/test/tateyama/endpoint/ipc/ipc_session_test.cpp
index b9653f1d..99ea59da 100644
--- a/test/tateyama/endpoint/ipc/ipc_session_test.cpp
+++ b/test/tateyama/endpoint/ipc/ipc_session_test.cpp
@@ -18,12 +18,13 @@
 #include "tateyama/endpoint/header_utils.h"
 #include
 #include "tateyama/status/resource/database_info_impl.h"
+#include
 #include "ipc_client.h"

 #include

 namespace tateyama::server {
-class ipc_listener_for_test {
+class ipc_listener_for_session_test {
 public:
     static void run(tateyama::endpoint::ipc::bootstrap::Worker& worker) {
         worker.invoke([&]{worker.run();});
@@ -38,7 +39,7 @@ namespace tateyama::endpoint::ipc {

 static constexpr std::size_t my_session_id = 123;

-static constexpr std::string_view database_name = "ipc_session_test";
+static constexpr std::string_view database_name = "ipc_sessionsession_test";
 static constexpr std::size_t datachannel_buffer_size = 64 * 1024;
 static constexpr std::string_view request_test_message = "abcdefgh";
 static constexpr std::string_view response_test_message = "opqrstuvwxyz";
@@ -55,12 +56,18 @@ class session_service : public tateyama::framework::routing_service {
         req_ = req;
         res_ = res;
         // do not respond to the request message in this test
+        requested_ = true;
         return true;
     }

+    bool requested() const {
+        return requested_;
+    }
+
 private:
     std::shared_ptr req_{};
     std::shared_ptr res_{};
+    bool requested_{};
 };

 class ipc_session_test : public ::testing::Test {
@@ -72,49 +79,91 @@ class ipc_session_test : public ::testing::Test {
         session_name += "-";
         session_name += std::to_string(my_session_id);
         auto wire = std::make_shared(session_name, "dummy_mutex_file_name", datachannel_buffer_size, 16);
-        worker_ = std::make_unique(service_, my_session_id, wire, database_info_, nullptr);
-        tateyama::server::ipc_listener_for_test::run(*worker_);
+        session_bridge_ = std::make_shared();
+        worker_ = std::make_unique(service_, my_session_id, wire, database_info_, session_bridge_);
+        tateyama::server::ipc_listener_for_session_test::run(*worker_);
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        client_ = std::make_unique(database_name, my_session_id);
     }
     virtual void TearDown() {
-        worker_->terminate();
-        tateyama::server::ipc_listener_for_test::wait(*worker_);
+        tateyama::server::ipc_listener_for_session_test::wait(*worker_);

         rv_ = system("if [ -f /dev/shm/ipc_session_test ]; then rm -f /dev/shm/ipc_session_test; fi");
     }

     int rv_;

-public:
+protected:
     tateyama::status_info::resource::database_info_impl database_info_{database_name};
     session_service service_{};
-
-private:
     std::unique_ptr worker_{};
+    std::shared_ptr session_bridge_{};
+    std::unique_ptr client_{};
 };

-TEST_F(ipc_session_test, cancel) {
-    std::this_thread::sleep_for(std::chrono::milliseconds(500));
-
+TEST_F(ipc_session_test, cancel_request) {
     // client part (send request)
-    auto client = std::make_unique(database_name, my_session_id);
-    client->send(0, std::string(request_test_message));  // we do not care service_id nor request message here
+    client_->send(0, std::string(request_test_message));  // we do not care service_id nor request message here

     // client part (send cancel)
     tateyama::proto::endpoint::request::Cancel cancel{};
     tateyama::proto::endpoint::request::Request endpoint_request{};
     endpoint_request.set_allocated_cancel(&cancel);
-    client->send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString());
+    client_->send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString());
     endpoint_request.release_cancel();

     // client part (receive)
     std::string res{};
-    client->receive(res, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
+    tateyama::proto::framework::response::Header::PayloadType type{};
+    client_->receive(res, type);
+    EXPECT_EQ(type, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
     tateyama::proto::diagnostics::Record response{};
     if(!response.ParseFromString(res)) {
         FAIL();
     }
     EXPECT_EQ(response.code(), tateyama::proto::diagnostics::Code::OPERATION_CANCELED);
+
+    worker_->terminate();
+}
+
+TEST_F(ipc_session_test, shutdown_after_request) {
+    // client part (send request)
+    client_->send(0, std::string(request_test_message));  // we do not care service_id nor request message here
+    while (!service_.requested()) {
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    }
+
+    // shutdown request
+    session_bridge_->shutdown(std::string(":") + std::to_string(my_session_id), session::shutdown_request_type::forceful);
+
+    // client part (receive)
+    std::string res{};
+    tateyama::proto::framework::response::Header::PayloadType type{};
+    client_->receive(res, type);
+    EXPECT_EQ(type, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
+    tateyama::proto::diagnostics::Record response{};
+    if(!response.ParseFromString(res)) {
+        FAIL();
+    }
+    EXPECT_EQ(response.code(), tateyama::proto::diagnostics::Code::OPERATION_CANCELED);
+}
+
+TEST_F(ipc_session_test, shutdown_before_request) {
+    // shutdown request
+    session_bridge_->shutdown(std::string(":") + std::to_string(my_session_id), session::shutdown_request_type::forceful);
+
+    // ensure shutdown request has been processed by the worker
+    std::this_thread::sleep_for(std::chrono::milliseconds(2500));
+
+    // client part (send request)
+    client_->send(0, std::string(request_test_message));  // we do not care service_id nor request message here
+
+    // client part (receive)
+    std::string res{};
+    tateyama::proto::framework::response::Header::PayloadType type{};
+
+    EXPECT_THROW(client_->receive(res, type), std::runtime_error);
 }

 } // namespace tateyama::endpoint::ipc
diff --git a/test/tateyama/endpoint/stream/stream_client.cpp b/test/tateyama/endpoint/stream/stream_client.cpp
index 4507a00c..dea874dc 100644
--- a/test/tateyama/endpoint/stream/stream_client.cpp
+++ b/test/tateyama/endpoint/stream/stream_client.cpp
@@ -13,13 +13,12 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
+#include
 #include    // for inet_addr()
 #include    // for bezero()
 #include "stream_client.h"

-#include
-
 namespace tateyama::api::endpoint::stream {

 stream_client::stream_client(tateyama::proto::endpoint::request::Handshake& hs) : endpoint_handshake_(hs)
@@ -43,7 +42,7 @@ stream_client::stream_client(tateyama::proto::endpoint::request::Handshake& hs)
     handshake();
 }

-void
+bool
 stream_client::send(const std::uint8_t type, const std::uint16_t slot, std::string_view message) {
     std::uint8_t header[7];  // NOLINT
@@ -56,13 +55,18 @@ stream_client::send(const std::uint8_t type, const std::uint16_t slot, std::stri
     header[4] = (length >> 8) & 0xff;
     header[5] = (length >> 16) & 0xff;
     header[6] = (length >> 24) & 0xff;
-    ::send(sockfd_, header, 7, 0);
+    if (::send(sockfd_, header, 7, MSG_NOSIGNAL) < 0) {
+        return false;
+    }
     if (length > 0) {
-        ::send(sockfd_, message.data(), length, 0);
+        if (::send(sockfd_, message.data(), length, MSG_NOSIGNAL) < 0) {
+            return false;
+        }
     }
+    return true;
 }

-void stream_client::send(const std::size_t tag, std::string_view message) {
+bool stream_client::send(const std::size_t tag, std::string_view message) {
     ::tateyama::proto::framework::request::Header hdr{};
     hdr.set_service_id(tag);
     std::stringstream ss{};
@@ -73,7 +77,7 @@ void stream_client::send(const std::size_t tag, std::string_view message) {
         throw std::runtime_error("payload serialize error");
     }
     auto request_message = ss.str();
-    send(REQUEST_SESSION_PAYLOAD, 1, request_message);
+    return send(REQUEST_SESSION_PAYLOAD, 1, request_message);
 }

 struct parse_response_result {
@@ -95,42 +99,44 @@ static bool parse_response_header(std::string_view input, parse_response_result
 }

 void
-stream_client::receive(std::string& message) {
-    receive(message, static_cast(0), false);
-}
-void
-stream_client::receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType type) {
-    receive(message, type, true);
+stream_client::receive(std::string &message) {
+    tateyama::proto::framework::response::Header::PayloadType type{};
+    receive(message, type);
 }
+
 void
-stream_client::receive(std::string& message, tateyama::proto::framework::response::Header::PayloadType type, bool do_check) {
+stream_client::receive(std::string& message, tateyama::proto::framework::response::Header::PayloadType& type) {
     std::uint8_t data[4];  // NOLINT

-    recv(sockfd_, &type_, 1, 0);
-
-    recv(sockfd_, data, 2, 0);
+    if (::recv(sockfd_, &type_, 1, 0) < 0) {
+        throw std::runtime_error("error in recv()");
+    }
+    if (::recv(sockfd_, data, 2, 0) < 0) {
+        throw std::runtime_error("error in recv()");
+    }
     slot_ = data[0] | (data[1] << 8);

     if (type_ == RESPONSE_RESULT_SET_PAYLOAD) {
-        ::recv(sockfd_, &writer_, 1, 0);
+        if (::recv(sockfd_, &writer_, 1, 0) < 0) {
+            throw std::runtime_error("error in recv()");
+        }
     }
-    ::recv(sockfd_, data, 4, 0);
+    if (::recv(sockfd_, data, 4, 0) < 0) {
+        throw std::runtime_error("error in recv()");
+    }
     std::size_t length = data[0] | (data[1] << 8) | (data[2] << 16) | (data[3] << 24);

     std::string r_msg;
     if (length > 0) {
         r_msg.resize(length);
-        ::recv(sockfd_, r_msg.data(), length, 0);
-
+        if (::recv(sockfd_, r_msg.data(), length, 0) < 0) {
+            throw std::runtime_error("error in recv()");
+        }
         parse_response_result result;
         if (parse_response_header(r_msg, result)) {
-            if (do_check) {
-                EXPECT_EQ(type, result.payload_type_);
-            }
             message = result.payload_;
-        } else {
-            FAIL();
+            type = result.payload_type_;
         }
     } else {
         r_msg.clear();
@@ -144,11 +150,19 @@ void stream_client::handshake() {
     endpoint_handshake_.set_allocated_wire_information(&wire_information);
     tateyama::proto::endpoint::request::Request endpoint_request{};
     endpoint_request.set_allocated_handshake(&endpoint_handshake_);
-    send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString());
+    try {
+        send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString());
+    } catch (std::exception &ex) {
+        std::cout << ex.what() << std::endl;
+    }
     endpoint_request.release_handshake();
     endpoint_handshake_.release_wire_information();

-    receive(handshake_response_);
+    try {
+        receive(handshake_response_);
+    } catch (std::exception &ex) {
+        std::cout << ex.what() << std::endl;
+    }
 }

diff --git a/test/tateyama/endpoint/stream/stream_client.h b/test/tateyama/endpoint/stream/stream_client.h
index 7575a600..3ec6de81 100644
--- a/test/tateyama/endpoint/stream/stream_client.h
+++ b/test/tateyama/endpoint/stream/stream_client.h
@@ -56,11 +56,11 @@ class stream_client {
         ::close(sockfd_);
     }

-    void send(const std::uint8_t type, const std::uint16_t slot, std::string_view message);
-    void send(const std::size_t tag, std::string_view message);
+    bool send(const std::uint8_t type, const std::uint16_t slot, std::string_view message);
+    bool send(const std::size_t tag, std::string_view message);
     void receive(std::string &message);
     void receive() { receive(response_); }
-    void receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType type);
+    void receive(std::string &message, tateyama::proto::framework::response::Header::PayloadType& type);
     void close() {
         ::close(sockfd_);
     }
diff --git a/test/tateyama/endpoint/stream/stream_session_test.cpp b/test/tateyama/endpoint/stream/stream_session_test.cpp
index f1fd2ab3..7c3881df 100644
--- a/test/tateyama/endpoint/stream/stream_session_test.cpp
+++ b/test/tateyama/endpoint/stream/stream_session_test.cpp
@@ -23,17 +23,18 @@
 #include

-static constexpr std::size_t my_session_id_ = 123;
+static constexpr std::size_t my_session_id = 123;

 static constexpr std::string_view request_test_message = "abcdefgh";
 static constexpr std::string_view response_test_message = "opqrstuvwxyz";

 namespace tateyama::endpoint::stream {

-class service_for_test : public tateyama::framework::routing_service {
+class service_for_session_test : public tateyama::framework::routing_service {
 public:
     bool setup(tateyama::framework::environment&) { return true; }
     bool start(tateyama::framework::environment&) { return true; }
     bool shutdown(tateyama::framework::environment&) { return true; }
+    std::string_view label() const noexcept { return __func__; }

     id_type id() const noexcept { return 100; } // dummy
     bool operator ()(std::shared_ptr req,
@@ -41,17 +42,28 @@ class service_for_test : public tateyama::framework::routing_service {
         req_ = req;
         res_ = res;
         // do not respond to the request message in this test
+        requested_ = true;
         return true;
     }

+    bool requested() const {
+        return requested_;
+    }
+
 private:
     std::shared_ptr req_{};
     std::shared_ptr res_{};
+    bool requested_{};
 };

-class stream_listener_for_test {
+class stream_listener_for_session_test {
 public:
-    stream_listener_for_test(service_for_test& service) : service_(service) {
+    stream_listener_for_session_test(service_for_session_test& service, std::shared_ptr session_bridge) :
+        service_(service),
+        session_bridge_(session_bridge) {
+    }
+    ~stream_listener_for_session_test() {
+        connection_socket_.close();
     }
     void operator()() {
         while (true) {
             stream = connection_socket_.accept();
             if (stream != nullptr) {
-                worker_ = std::make_unique(service_, my_session_id_, std::move(stream), database_info_, false);
+                worker_ = std::make_unique(service_, my_session_id, std::move(stream), database_info_, false, session_bridge_);
                 worker_->invoke([&]{worker_->run();});
             } else {  // connect via pipe (request_terminate)
                 break;
@@ -75,10 +87,11 @@ class stream_listener_for_test {
     }

 private:
-    service_for_test& service_;
+    service_for_session_test& service_;
+    std::shared_ptr session_bridge_;
     connection_socket connection_socket_{tateyama::api::endpoint::stream::stream_client::PORT_FOR_TEST};
     std::unique_ptr worker_{};
-    tateyama::status_info::resource::database_info_impl database_info_{"stream_session_test"};
+    tateyama::status_info::resource::database_info_impl database_info_{"stream_info_test"};
 };

 }
@@ -87,47 +100,102 @@
 namespace tateyama::api::endpoint::stream {

 class stream_session_test : public ::testing::Test {
     virtual void SetUp() {
-        thread_ = std::thread(std::ref(listener_));
+        session_bridge_ = std::make_shared();
+        listener_ = std::make_unique(service_, session_bridge_);
+        thread_ = std::thread(std::ref(*listener_));
+        client_ = std::make_unique();
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(500));
     }
     virtual void TearDown() {
+        // terminate session
+        client_->close();
+
+        // terminate listener
+        listener_->wait_worker_termination();
+        listener_->terminate();
+
         thread_.join();
     }

 public:
-    tateyama::endpoint::stream::service_for_test service_{};
-    tateyama::endpoint::stream::stream_listener_for_test listener_{service_};
+    std::shared_ptr session_bridge_{};
+    std::unique_ptr listener_{};
+    tateyama::endpoint::stream::service_for_session_test service_{};
     std::thread thread_{};
+    std::unique_ptr client_{};
 };

 TEST_F(stream_session_test, cancel) {
-    std::this_thread::sleep_for(std::chrono::milliseconds(500));
-
     try {
         // client part (send request)
-        auto client = std::make_unique();
-        client->send(0, request_test_message); // we do not care service_id nor request message here
+        EXPECT_TRUE(client_->send(0, request_test_message)); // we do not care service_id nor request message here

         // client part (send cancel)
         tateyama::proto::endpoint::request::Cancel cancel{};
         tateyama::proto::endpoint::request::Request endpoint_request{};
         endpoint_request.set_allocated_cancel(&cancel);
-        client->send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString());
+        EXPECT_TRUE(client_->send(tateyama::framework::service_id_endpoint_broker, endpoint_request.SerializeAsString()));
         endpoint_request.release_cancel();

         // client part (receive)
         std::string res{};
-        client->receive(res, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
+        tateyama::proto::framework::response::Header::PayloadType type{};
+        client_->receive(res, type);
+        EXPECT_EQ(type, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
+        tateyama::proto::diagnostics::Record response{};
+        if(!response.ParseFromString(res)) {
+            FAIL();
+        }
+        EXPECT_EQ(response.code(), tateyama::proto::diagnostics::Code::OPERATION_CANCELED);
+    } catch (std::runtime_error &ex) {
+        std::cout << ex.what() << std::endl;
+        FAIL();
+    }
+}
+
+TEST_F(stream_session_test, shutdown_after_request) {
+    try {
+        // client part (send request)
+        EXPECT_TRUE(client_->send(0, std::string(request_test_message))); // we do not care service_id nor request message here
+        while (!service_.requested()) {
+            std::this_thread::sleep_for(std::chrono::milliseconds(100));
+        }
+
+        // shutdown request
+        if (auto rv = session_bridge_->shutdown(std::string(":") + std::to_string(my_session_id), session::shutdown_request_type::forceful); rv) {
+            FAIL();
+        }
+
+        // client part (receive)
+        std::string res{};
+        tateyama::proto::framework::response::Header::PayloadType type{};
+        client_->receive(res, type);
+        EXPECT_EQ(type, tateyama::proto::framework::response::Header::SERVER_DIAGNOSTICS);
         tateyama::proto::diagnostics::Record response{};
         if(!response.ParseFromString(res)) {
             FAIL();
         }
         EXPECT_EQ(response.code(), tateyama::proto::diagnostics::Code::OPERATION_CANCELED);
+    } catch (std::runtime_error &ex) {
+        std::cout << ex.what() << std::endl;
+        FAIL();
+    }
+}

-        // terminate session and listener
-        client->close();
-        listener_.wait_worker_termination();
-        listener_.terminate();
+TEST_F(stream_session_test, shutdown_before_request) {
+    try {
+        // shutdown request
+        if (auto rv = session_bridge_->shutdown(std::string(":") + std::to_string(my_session_id), session::shutdown_request_type::forceful); rv) {
+            FAIL();
+        }
+
+        // ensure shutdown request has been processed by the worker
+        std::this_thread::sleep_for(std::chrono::milliseconds(2500));
+
+        // client part (send request)
+        EXPECT_FALSE(client_->send(0, std::string(request_test_message)));
     } catch (std::runtime_error &ex) {
         std::cout << ex.what() << std::endl;
         FAIL();
diff --git a/test/tsubakuro/common/wire/udf_wires.h b/test/tsubakuro/common/wire/udf_wires.h
index 6d963412..ff06c9a7 100644
--- a/test/tsubakuro/common/wire/udf_wires.h
+++ b/test/tsubakuro/common/wire/udf_wires.h
@@ -88,15 +88,15 @@ class session_wire_container
 public:
     wire_container() = default;
     wire_container(unidirectional_message_wire* wire, char* bip_buffer) : wire_(wire), bip_buffer_(bip_buffer) {};
-    message_header peep(bool wait = false) {
-        return wire_->peep(bip_buffer_, wait);
+    message_header peep() {
+        return wire_->peep(bip_buffer_);
     }
     void write(const signed char* from, std::size_t length, message_header::index_type index) {
         const char *ptr = reinterpret_cast(from);
         wire_->write(bip_buffer_, ptr, message_header(index, length));
     }
     void disconnect() {
-        wire_->write(bip_buffer_, nullptr, message_header(message_header::not_use, 0));
+        wire_->write(bip_buffer_, nullptr, message_header(message_header::null_request, 0));
     }

 private:
@@ -123,6 +123,9 @@ class session_wire_container
     void read(signed char* to) {
         wire_->read(reinterpret_cast(to), bip_buffer_);
     }
+    bool check_shutdown() {
+        return wire_->check_shutdown();
+    }
     void close() {
         wire_->close();
     }
diff --git a/test/tsubakuro/common/wire/wire.h b/test/tsubakuro/common/wire/wire.h
index 02bef9f3..22cc22d8 100644
--- a/test/tsubakuro/common/wire/wire.h
+++ b/test/tsubakuro/common/wire/wire.h
@@ -42,7 +42,7 @@ class message_header {
 public:
     using length_type = std::uint32_t;
     using index_type = std::uint16_t;
-    static constexpr index_type not_use = 0xffff;
+    static constexpr index_type null_request = 0xffff;

     static constexpr std::size_t size = sizeof(length_type) + sizeof(index_type);

@@ -372,55 +372,79 @@ inline static std::int64_t n_cap(std::int64_t timeout) {

 // for request
 class unidirectional_message_wire : public simple_wire {
+    constexpr static std::size_t watch_interval = 2;
 public:
     unidirectional_message_wire(boost::interprocess::managed_shared_memory* managed_shm_ptr, std::size_t capacity) : simple_wire(managed_shm_ptr, capacity) {}

     /**
-     * @brief peep the current header.
+     * @brief wait until a request message arrives and peep the current header.
+     * @return the message_header if a request message has been received;
+     *         otherwise (e.g. on timeout or termination request), a dummy message_header whose length is 0 and index is message_header::null_request.
      */
-    message_header peep(const char* base, bool wait_flag = false) {
+    message_header peep(const char* base) {
         while (true) {
-            if(stored() >= message_header::size || shutdown_requested_.load()) {
-                break;
+            if(stored() >= message_header::size) {
+                copy_header(base);
+                return header_received_;
             }
-            if (wait_flag) {
-                boost::interprocess::scoped_lock lock(m_mutex_);
-                wait_for_read_ = true;
-                std::atomic_thread_fence(std::memory_order_acq_rel);
-                c_empty_.wait(lock, [this](){ return (stored() >= message_header::size) || shutdown_requested_.load(); });
+            if (termination_requested_.load() || onetime_notification_.load()) {
+                onetime_notification_.store(false);
+                return {message_header::null_request, 0};
+            }
+            boost::interprocess::scoped_lock lock(m_mutex_);
+            wait_for_read_ = true;
+            std::atomic_thread_fence(std::memory_order_acq_rel);
+            if (!c_empty_.timed_wait(lock,
+                                     boost::get_system_time() + boost::posix_time::microseconds(u_cap(u_round(watch_interval * 1000 * 1000))),
+                                     [this](){ return (stored() >= message_header::size) || termination_requested_.load() || onetime_notification_.load(); })) {
                 wait_for_read_ = false;
-            } else {
-                if (stored() < message_header::size) { return {}; }
+                header_received_ = message_header(message_header::null_request, 0);
+                return header_received_;
             }
+            wait_for_read_ = false;
         }
-        if (!shutdown_requested_.load()) {
-            copy_header(base);
-        } else {
-            header_received_ = message_header(message_header::not_use, 0);
+    }
+
+    /**
+     * @brief wake up the worker immediately.
+     */
+    void notify() {
+        onetime_notification_.store(true);
+        std::atomic_thread_fence(std::memory_order_acq_rel);
+        if (wait_for_read_) {
+            boost::interprocess::scoped_lock lock(m_mutex_);
+            c_empty_.notify_one();
         }
-        return header_received_;
     }

     /**
-     * @brief wake up the worker thread waiting for request arrival, supposed to be used in server shutdown.
+     * @brief wake up the worker thread waiting for request arrival, supposed to be used in server termination.
      */
     void terminate() {
-        shutdown_requested_.store(true);
+        termination_requested_.store(true);
         std::atomic_thread_fence(std::memory_order_acq_rel);
         if (wait_for_read_) {
            boost::interprocess::scoped_lock lock(m_mutex_);
            c_empty_.notify_one();
         }
     }
+    /**
+     * @brief check if a termination request has been made
+     * @return true if a termination request has been made
+     */
+    [[nodiscard]] bool terminate_requested() {
+        return termination_requested_.load();
+    }

 private:
-    std::atomic_bool shutdown_requested_{};
+    std::atomic_bool termination_requested_{};
+    std::atomic_bool onetime_notification_{};
 };

 // for response
 class unidirectional_response_wire : public simple_wire {
-    constexpr static std::size_t watch_interval = 5;
+    constexpr static std::size_t watch_interval = 2;
 public:
     unidirectional_response_wire(boost::interprocess::managed_shared_memory* managed_shm_ptr, std::size_t capacity) : simple_wire(managed_shm_ptr, capacity) {}

@@ -473,6 +497,9 @@ class unidirectional_response_wire : public simple_wire {
     [[nodiscard]] response_header::msg_type get_type() const {
         return header_received_.get_type();
     }
+    void notify_shutdown() noexcept {
+        shutdown_.store(true);
+    }

     void close() {
         closed_.store(true);
@@ -482,15 +509,19 @@ class unidirectional_response_wire : public simple_wire {
             c_empty_.notify_one();
         }
     }
+    [[nodiscard]] bool check_shutdown() const noexcept {
+        return shutdown_.load();
+    }

 private:
     std::atomic_bool closed_{};
+    std::atomic_bool shutdown_{};
 };

 // for resultset
 class unidirectional_simple_wires {
-    constexpr static std::size_t watch_interval = 5;
+    constexpr static std::size_t watch_interval = 2;
 public:

     class unidirectional_simple_wire : public simple_wire {
@@ -890,18 +921,22 @@ class status_provider {
     status_provider(boost::interprocess::managed_shared_memory* managed_shm_ptr, std::string_view file) : mutex_file_(file, managed_shm_ptr->get_segment_manager()) {
     }

-    [[nodiscard]] bool is_alive() {
-        int fd = open(mutex_file_.c_str(), O_WRONLY);  // NOLINT
+    [[nodiscard]] std::string is_alive() {
+        int fd = open(mutex_file_.c_str(), O_RDONLY);  // NOLINT
         if (fd < 0) {
-            return false;
+            std::stringstream ss{};
+            ss << "cannot open the lock file (" << mutex_file_.c_str() << ")";
+            return ss.str();
         }
         if (flock(fd, LOCK_EX | LOCK_NB) == 0) {  // NOLINT
             flock(fd, LOCK_UN);
             close(fd);
-            return false;
+            std::stringstream ss{};
+            ss << "the lock file (" << mutex_file_.c_str() << ") is not locked, possibly due to server crash";
+            return ss.str();
         }
         close(fd);
-        return true;
+        return {};
     }

 private:
@@ -916,6 +951,7 @@ class connection_queue
     constexpr static const char* name = "connection_queue";

     class index_queue {
+        constexpr static std::size_t watch_interval = 2;
         using long_allocator = boost::interprocess::allocator;

     public:
@@ -946,10 +982,12 @@ class connection_queue
                 }
             }
         }
-        void wait(std::atomic_bool& terminate) {
+        [[nodiscard]] bool wait(std::atomic_bool& terminate) {
            boost::interprocess::scoped_lock lock(mutex_);
            std::atomic_thread_fence(std::memory_order_acq_rel);
-            condition_.wait(lock, [this, &terminate](){ return (pushed_.load() > poped_.load()) || terminate.load(); });
+            return condition_.timed_wait(lock,
+                                         boost::get_system_time() + boost::posix_time::microseconds(u_cap(u_round(watch_interval * 1000 * 1000))),
+                                         [this, &terminate](){ return (pushed_.load() > poped_.load()) || terminate.load(); });
         }
         [[nodiscard]] std::size_t pop() {
             return queue_.at(index(poped_.fetch_add(1)));
@@ -1058,10 +1096,11 @@ class connection_queue
     bool check(std::size_t rid) {
         return v_requested_.at(rid).check();
     }
-
     std::size_t listen() {
-        q_requested_.wait(terminate_);
-        return ++session_id_;
+        if (q_requested_.wait(terminate_)) {
+            return ++session_id_;
+        }
+        return 0;
     }
     std::size_t accept(std::size_t session_id) {
         std::size_t sid = q_requested_.pop();
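For reference, a minimal standalone sketch (not part of the patch) of how a worker loop is expected to act on the new worker_common::shutdown_consequence values. It condenses the switch added to Worker::run() and stream_worker::run(); the lambdas are stand-ins for handle_shutdown(), the service_ dispatch and notify_client(), and all request/response bookkeeping is omitted.

// sketch.cpp -- illustrative only, assumes the semantics described above
#include <iostream>

// stand-in for worker_common::shutdown_consequence (same enumerators as the patch)
enum class shutdown_consequence { keep_working, postpone, immediate };

int main() {
    int calls = 0;
    // stand-in for handle_shutdown(): pretend a forceful shutdown arrives before the third request
    auto handle_shutdown = [&calls]() {
        return ++calls < 3 ? shutdown_consequence::keep_working
                           : shutdown_consequence::immediate;
    };

    while (true) {                       // stands in for the per-request loop
        bool break_while{false};
        switch (handle_shutdown()) {
        case shutdown_consequence::keep_working:
            std::cout << "dispatch request to service\n";        // register response and call the service
            break;
        case shutdown_consequence::postpone:
            std::cout << "reply SESSION_CLOSED, keep session\n"; // refuse new work until responses drain
            break;
        case shutdown_consequence::immediate:
            std::cout << "reply SESSION_CLOSED, close session\n";
            break_while = true;          // record the decision; break only leaves the switch
            break;
        }
        if (break_while) {
            break;                       // leave the loop; the worker then tears the session down
        }
    }
    return 0;
}

The break_while flag exists because a break inside the switch only exits the switch, so the immediate case records the decision and the enclosing loop is exited afterwards.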
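The reworked foreach_response() snapshots the live responses while holding mtx_responses_ and runs the callback only after the lock is released, so that cancel()/notify_client() issued from handle_shutdown() never execute under the map lock. Below is a compilable sketch of that pattern with placeholder types filled in where the diff text above lost the template arguments; response, std::weak_ptr<response>, std::lock_guard<std::mutex> and the free-standing globals are assumptions made for the example, not the real declarations.

// foreach_sketch.cpp -- illustrative only
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <vector>

struct response { void cancel() {} };                 // placeholder for the endpoint response type

std::mutex mtx_responses_;
std::map<std::size_t, std::weak_ptr<response>> responses_;

bool foreach_response(const std::function<void(response&)>& func) {
    std::vector<std::shared_ptr<response>> targets{};
    {
        std::lock_guard<std::mutex> lock(mtx_responses_);
        for (auto it = responses_.begin(); it != responses_.end(); ) {
            if (auto r = it->second.lock(); r) {
                targets.emplace_back(std::move(r));   // keep the response alive for the callback
                ++it;
            } else {
                it = responses_.erase(it);            // drop entries whose response is already gone
            }
        }
    }
    for (auto&& e : targets) {                        // callbacks run without holding the lock
        func(*e);
    }
    return !targets.empty();
}

int main() {
    auto r = std::make_shared<response>();
    responses_.emplace(0, r);
    return foreach_response([](response& e) { e.cancel(); }) ? 0 : 1;
}

Erasing expired entries during the same pass keeps the map from accumulating dead weak_ptrs, which is the same housekeeping the original loop performed.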