Skip to content

Commit

Permalink
implement tateyama::api::server::response::check_cancel()
Browse files Browse the repository at this point in the history
  • Loading branch information
t-horikawa committed Feb 20, 2024
1 parent e118bd2 commit 8ae7f10
Show file tree
Hide file tree
Showing 14 changed files with 134 additions and 22 deletions.
19 changes: 18 additions & 1 deletion include/tateyama/api/server/response.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,17 @@ class response {
virtual void session_id(std::size_t id) = 0;

/**
* @brief report error with diagnostics information
* @brief report error with diagnostics information.
* @param record the diagnostic record to report
* @details report an error with diagnostics information for client. When this function is called, no more
* body_head() or body() is expected to be called.
* @attention this function is not thread-safe and should be called from single thread at a time.
* @attention After calling this for cancelling the current job, the job must not use the related resources.
* This includes the below:
*
* - request object
* - response object
* - resources underlying session context
*/
virtual void error(proto::diagnostics::Record const& record) = 0;

Expand Down Expand Up @@ -105,6 +111,17 @@ class response {
*/
virtual status release_channel(data_channel& ch) = 0;

/**
* @brief returns whether or not cancellation was requested to the corresponding job.
* @details To cancel the job, first you must shutdown the operation of this job, and then call error().
* At this time, `OPERATION_CANCELED` is recommended as the diagnostic code for cancelling the job.
* Or, to avoid inappropriate conditions, you can omit the cancel request and complete the job.
* @return true if the job is calling for a cancel
* @return false otherwise
* @see error()
*/
[[nodiscard]] virtual bool check_cancel() const = 0;

};

}
40 changes: 40 additions & 0 deletions src/tateyama/endpoint/common/response.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2018-2023 Project Tsurugi.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <memory>

#include <tateyama/api/server/response.h>

namespace tateyama::endpoint::common {
/**
* @brief response object for common_endpoint
*/
class response : public tateyama::api::server::response {
public:
[[nodiscard]] bool check_cancel() const override {
return cancel_response_ != nullptr;
}

void set_cancel(const std::shared_ptr<tateyama::api::server::response>& cancel_response) noexcept {
cancel_response_ = cancel_response;
}

private:
std::shared_ptr<tateyama::api::server::response> cancel_response_{};
};

} // tateyama::common::wire
31 changes: 29 additions & 2 deletions src/tateyama/endpoint/common/worker_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <tateyama/proto/endpoint/response.pb.h>
#include <tateyama/proto/diagnostics.pb.h>

#include "response.h"
#include "tateyama/endpoint/common/session_info_impl.h"

namespace tateyama::endpoint::common {
Expand Down Expand Up @@ -179,7 +180,33 @@ class worker_common {
record.release_message();
}

void register_response(std::size_t slot, const std::shared_ptr<tateyama::api::server::response>& response) noexcept {
bool endpoint_service([[maybe_unused]] const std::shared_ptr<tateyama::api::server::request>& req,
const std::shared_ptr<tateyama::api::server::response>& res,
std::size_t slot) {
auto data = req->payload();
tateyama::proto::endpoint::request::Request rq{};
if(! rq.ParseFromArray(data.data(), static_cast<int>(data.size()))) {
std::string error_message{"request parse error"};
LOG(INFO) << error_message;
notify_client(res.get(), tateyama::proto::diagnostics::Code::INVALID_REQUEST, error_message);
return false;
}
if(rq.command_case() != tateyama::proto::endpoint::request::Request::kCancel) {
std::stringstream ss;
ss << "bad request (cancel in endpoint): " << rq.command_case();
LOG(INFO) << ss.str();
notify_client(res.get(), tateyama::proto::diagnostics::Code::INVALID_REQUEST, ss.str());
return false;
}
if (auto itr = responses_.find(slot); itr != responses_.end()) {
if (auto ptr = itr->second.lock(); ptr) {
ptr->set_cancel(res);
}
}
return true;
}

void register_response(std::size_t slot, const std::shared_ptr<tateyama::endpoint::common::response>& response) noexcept {
responses_.emplace(slot, response);
}
void remove_response(std::size_t slot) noexcept {
Expand All @@ -189,7 +216,7 @@ class worker_common {
private:
tateyama::session::session_variable_set session_variable_set_;
const std::shared_ptr<tateyama::session::resource::session_context_impl> session_context_;
std::map<std::size_t, std::weak_ptr<tateyama::api::server::response>> responses_{};
std::map<std::size_t, std::weak_ptr<tateyama::endpoint::common::response>> responses_{};

std::string_view connection_label(connection_type con) {
switch (con) {
Expand Down
19 changes: 14 additions & 5 deletions src/tateyama/endpoint/ipc/bootstrap/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ void Worker::run()
auto request = std::make_shared<ipc_request>(*wire_, h, database_info_, session_info_);
std::size_t index = h.get_idx();
auto response = std::make_shared<ipc_response>(wire_, h.get_idx(), [this, index](){remove_response(index);});
register_response(index, response);
if (!service_(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)))) {
VLOG_LP(log_info) << "terminate worker because service returns an error";
break;
if (request->service_id() != tateyama::framework::service_id_endpoint_broker) {
register_response(index, static_cast<std::shared_ptr<tateyama::endpoint::common::response>>(response));
if (!service_(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)))) {
VLOG_LP(log_info) << "terminate worker because service returns an error";
break;
}
} else {
if (!endpoint_service(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)),
index)) {
VLOG_LP(log_info) << "terminate worker because endpoint service returns an error";
break;
}
}
request->dispose();
request = nullptr;
Expand Down
5 changes: 2 additions & 3 deletions src/tateyama/endpoint/ipc/ipc_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
#include <atomic>
#include <functional>

#include <tateyama/api/server/response.h>

#include "tateyama/endpoint/common/response.h"
#include "tateyama/endpoint/common/pointer_comp.h"
#include "server_wires.h"
#include "ipc_request.h"
Expand Down Expand Up @@ -78,7 +77,7 @@ class ipc_data_channel : public tateyama::api::server::data_channel {
/**
* @brief response object for ipc_endpoint
*/
class ipc_response : public tateyama::api::server::response {
class ipc_response : public tateyama::endpoint::common::response {
friend ipc_data_channel;

public:
Expand Down
4 changes: 2 additions & 2 deletions src/tateyama/endpoint/loopback/loopback_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
#include <map>
#include <mutex>

#include <tateyama/api/server/response.h>
#include <tateyama/endpoint/common/response.h>

#include "loopback_data_writer.h"
#include "loopback_data_channel.h"

namespace tateyama::endpoint::loopback {

class loopback_response: public tateyama::api::server::response {
class loopback_response: public tateyama::endpoint::common::response {
public:
/**
* @see tateyama::server::response::session_id()
Expand Down
18 changes: 14 additions & 4 deletions src/tateyama/endpoint/stream/bootstrap/stream_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,20 @@ void stream_worker::run()

auto request = std::make_shared<stream_request>(*session_stream_, payload, database_info_, session_info_);
auto response = std::make_shared<stream_response>(session_stream_, slot);
if(!service_(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)))) {
VLOG_LP(log_info) << "terminate worker because service returns an error";
break;
if (request->service_id() != tateyama::framework::service_id_endpoint_broker) {
register_response(slot, static_cast<std::shared_ptr<tateyama::endpoint::common::response>>(response));
if(!service_(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)))) {
VLOG_LP(log_info) << "terminate worker because service returns an error";
break;
}
} else {
if (!endpoint_service(static_cast<std::shared_ptr<tateyama::api::server::request>>(request),
static_cast<std::shared_ptr<tateyama::api::server::response>>(std::move(response)),
slot)) {
VLOG_LP(log_info) << "terminate worker because endpoint service returns an error";
break;
}
}
request = nullptr;
}
Expand Down
1 change: 0 additions & 1 deletion src/tateyama/endpoint/stream/stream_response.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ tateyama::status stream_response::release_channel(tateyama::api::server::data_ch
return tateyama::status::unknown;
}


// class stream_data_channel
tateyama::status stream_data_channel::acquire(std::shared_ptr<tateyama::api::server::writer>& wrt) {
try {
Expand Down
5 changes: 2 additions & 3 deletions src/tateyama/endpoint/stream/stream_response.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
#include <condition_variable>
#include <atomic>

#include <tateyama/api/server/response.h>

#include <tateyama/endpoint/common/response.h>
#include "tateyama/endpoint/common/pointer_comp.h"
#include "stream.h"

Expand Down Expand Up @@ -74,7 +73,7 @@ class stream_data_channel : public tateyama::api::server::data_channel {
/**
* @brief response object for stream_endpoint
*/
class stream_response : public tateyama::api::server::response {
class stream_response : public tateyama::endpoint::common::response {
friend stream_data_channel;

public:
Expand Down
10 changes: 9 additions & 1 deletion src/tateyama/proto/endpoint/request.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ message Request {
oneof command {
// handshake operation.
Handshake handshake = 11;

// cancel operation.
Cancel cancel = 12;
}
reserved 12 to 99;
reserved 13 to 99;
}

// handshake operation.
Expand Down Expand Up @@ -73,3 +76,8 @@ message WireInformation {
uint64 maximum_concurrent_result_sets = 1;
}
}

// cancel operation.
message Cancel {
// no special properties.
}
1 change: 1 addition & 0 deletions test/tateyama/datastore/datastore_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ class datastore_test :
void error(proto::diagnostics::Record const& record) override {}
status acquire_channel(std::string_view name, std::shared_ptr<api::server::data_channel>& ch) override { return status::ok; }
status release_channel(api::server::data_channel& ch) override { return status::ok; }
bool check_cancel() const override { return false; }

std::size_t session_id_{};
std::string body_{};
Expand Down
1 change: 1 addition & 0 deletions test/tateyama/framework/router_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class router_test :
}
status acquire_channel(std::string_view name, std::shared_ptr<api::server::data_channel>& ch) override { return status::ok; }
status release_channel(api::server::data_channel& ch) override { return status::ok; }
bool check_cancel() const override { return false; }

std::size_t session_id_{};
std::string body_{};
Expand Down
1 change: 1 addition & 0 deletions test/tateyama/session/session_gc_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class session_gc_test :
void error(proto::diagnostics::Record const& record) override {}
status acquire_channel(std::string_view name, std::shared_ptr<api::server::data_channel>& ch) override { return status::ok; }
status release_channel(api::server::data_channel& ch) override { return status::ok; }
bool check_cancel() const override { return false; }

std::size_t session_id_{};
std::string body_{};
Expand Down
1 change: 1 addition & 0 deletions test/tateyama/session/session_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class session_test :
void error(proto::diagnostics::Record const& record) override {}
status acquire_channel(std::string_view name, std::shared_ptr<api::server::data_channel>& ch) override { return status::ok; }
status release_channel(api::server::data_channel& ch) override { return status::ok; }
bool check_cancel() const override { return false; }

std::size_t session_id_{};
std::string body_{};
Expand Down

0 comments on commit 8ae7f10

Please sign in to comment.