Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreaded server #71

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions protocol/protocol.json
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@
}
],
"notifications": [
{
"method": "$/cancelRequest",
"messageDirection": "clientToServer",
"params": {
"kind": "reference",
"name": "CancelRequestParams"
}
},
{
"method": "closeDocument",
"messageDirection": "clientToServer",
Expand Down Expand Up @@ -688,6 +696,18 @@
}
]
},
{
"name": "CancelRequestParams",
"properties": [
{
"name": "id",
"type": {
"kind": "base",
"name": "string"
}
}
]
},
{
"name": "InitializeParams",
"properties": []
Expand Down
2 changes: 2 additions & 0 deletions snail/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ target_sources(common
path.cpp
string_compare.cpp
system.cpp
thread.cpp
thread_pool.cpp
trim.cpp
wildcard.cpp
ms_xca_decompression.cpp
Expand Down
24 changes: 24 additions & 0 deletions snail/common/thread.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

#include <snail/common/thread.hpp>

#if defined(_WIN32)
# define NOMINMAX
# define WIN32_LEAN_AND_MEAN
# include <Windows.h>
# include <processthreadsapi.h>
# include <utf8/cpp17.h>
#else
# include <pthread.h>
#endif

void snail::common::set_thread_name(const std::string& name)
{
#if defined(_WIN32)
const auto name_u16 = utf8::utf8to16(name);
SetThreadDescription(GetCurrentThread(), reinterpret_cast<const wchar_t*>(name_u16.data()));
#elif defined(__linux__)
::pthread_setname_np(::pthread_self(), name.c_str());
#elif defined(__APPLE__)
::pthread_setname_np(name.c_str());
#endif
}
9 changes: 9 additions & 0 deletions snail/common/thread.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#pragma once

#include <string>

namespace snail::common {

void set_thread_name(const std::string& name);

} // namespace snail::common
83 changes: 83 additions & 0 deletions snail/common/thread_pool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#include <snail/common/thread_pool.hpp>

#include <format>

#include <snail/common/thread.hpp>

using namespace snail::common;

thread_pool::thread_pool(std::size_t max_size) :
max_size_(max_size),
stopped_(false)
{}

thread_pool::~thread_pool()
{
stop();
}

void thread_pool::submit(task_type task)
{
{
std::unique_lock<std::mutex> lock(mutex_);

if(stopped_) return; // do not submit jobs to thread pool that was stopped

task_queue_.push(std::move(task));
}

condition_variable_.notify_one();

Check warning on line 29 in snail/common/thread_pool.cpp

View check run for this annotation

Codecov / codecov/patch

snail/common/thread_pool.cpp#L29

Added line #L29 was not covered by tests

if(workers_.size() < max_size_)
{
std::unique_lock<std::mutex> lock(mutex_);

// If the task has not been picked up, try to start a new thread.
if(!task_queue_.empty()) spawn_thread();
}
}

void thread_pool::stop()
{
{
std::unique_lock<std::mutex> lock(mutex_);
stopped_ = true;
max_size_ = 0;
}
condition_variable_.notify_all();
for(auto& worker : workers_)
{
worker.join();
}
workers_.clear();
}

void thread_pool::spawn_thread()
{
const auto id = workers_.size() + 1;

workers_.emplace_back(
[this, id]
{
set_thread_name(std::format("Worker {}", id));
while(true)
{
std::function<void()> task;

{
std::unique_lock<std::mutex> lock(mutex_);

// TODO: only wait for a given time and shut down this
// thread if we have not received any work.
condition_variable_.wait(lock, [this]
{ return stopped_ || !task_queue_.empty(); });
if(stopped_ && task_queue_.empty()) return;

task = std::move(task_queue_.front());
task_queue_.pop();
}

task();

Check warning on line 80 in snail/common/thread_pool.cpp

View check run for this annotation

Codecov / codecov/patch

snail/common/thread_pool.cpp#L80

Added line #L80 was not covered by tests
}
});
}
39 changes: 39 additions & 0 deletions snail/common/thread_pool.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

namespace snail::common {

class thread_pool
{
public:
using task_type = std::function<void()>;

explicit thread_pool(std::size_t max_size);
~thread_pool();

void submit(task_type task);

void stop();

private:
std::size_t max_size_;
std::vector<std::thread> workers_;

std::queue<task_type> task_queue_;

std::mutex mutex_;
std::condition_variable condition_variable_;

bool stopped_;

void spawn_thread();
};

} // namespace snail::common
6 changes: 3 additions & 3 deletions snail/jsonrpc/jsonrpc_v2_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

using namespace snail::jsonrpc;

request v2_protocol::load_request(std::string_view content)
request v2_protocol::load_request(std::string_view content) const
{
nlohmann::json data;
try
Expand Down Expand Up @@ -60,7 +60,7 @@ request v2_protocol::load_request(std::string_view content)
.id = std::move(id_data)};
}

std::string v2_protocol::dump_response(const jsonrpc::response& response)
std::string v2_protocol::dump_response(const jsonrpc::response& response) const
{
if(response.id)
{
Expand All @@ -70,7 +70,7 @@ std::string v2_protocol::dump_response(const jsonrpc::response& response)
return std::format(R"({{"jsonrpc":"2.0","result":{}}})", response.result.dump());
}

std::string v2_protocol::dump_error(const rpc_error& error, const nlohmann::json* id)
std::string v2_protocol::dump_error(const rpc_error& error, const nlohmann::json* id) const
{
const auto json_error = nlohmann::json{
{"code", error.code()},
Expand Down
6 changes: 3 additions & 3 deletions snail/jsonrpc/jsonrpc_v2_protocol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ namespace snail::jsonrpc {
class v2_protocol : public protocol
{
public:
[[nodiscard]] virtual request load_request(std::string_view content) override;
[[nodiscard]] virtual request load_request(std::string_view content) const override;

[[nodiscard]] virtual std::string dump_response(const jsonrpc::response& response) override;
[[nodiscard]] virtual std::string dump_response(const jsonrpc::response& response) const override;

[[nodiscard]] virtual std::string dump_error(const rpc_error& error, const nlohmann::json* id) override;
[[nodiscard]] virtual std::string dump_error(const rpc_error& error, const nlohmann::json* id) const override;
};

} // namespace snail::jsonrpc
5 changes: 4 additions & 1 deletion snail/jsonrpc/message_handler.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <functional>
#include <optional>
#include <string>

Expand All @@ -8,9 +9,11 @@ namespace snail::jsonrpc {
class message_handler
{
public:
using respond_callback = std::function<void(std::string)>;

virtual ~message_handler() = default;

virtual std::optional<std::string> handle(std::string_view data) = 0;
virtual void handle(std::string data, respond_callback respond) = 0;
};

} // namespace snail::jsonrpc
6 changes: 3 additions & 3 deletions snail/jsonrpc/protocol.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ class protocol
public:
virtual ~protocol() = default;

[[nodiscard]] virtual request load_request(std::string_view content) = 0;
[[nodiscard]] virtual request load_request(std::string_view content) const = 0;

[[nodiscard]] virtual std::string dump_response(const jsonrpc::response& response) = 0;
[[nodiscard]] virtual std::string dump_response(const jsonrpc::response& response) const = 0;

[[nodiscard]] virtual std::string dump_error(const rpc_error& error, const nlohmann::json* id) = 0;
[[nodiscard]] virtual std::string dump_error(const rpc_error& error, const nlohmann::json* id) const = 0;
};

} // namespace snail::jsonrpc
96 changes: 60 additions & 36 deletions snail/jsonrpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,59 +16,83 @@

server::~server() = default;

[[noreturn]] void server::serve_forever()
{
connection_->serve_forever(*this);
}

void server::serve_next()
{
connection_->serve_next(*this);
}

void server::register_method(std::string name, std::function<std::optional<nlohmann::json>(const nlohmann::json&)> handler)
{
methods_.emplace(std::move(name), std::move(handler));
}

std::optional<std::string> server::handle(std::string_view data)
void server::handle(std::string data, respond_callback respond)
{
const nlohmann::json* error_id = nullptr;
jsonrpc::request request;
try
{
auto request = protocol_->load_request(data);

if(request.id.has_value())
{
error_id = &request.id.value();
}

auto response = handle_request(request);

if(!response) return std::nullopt;

return protocol_->dump_response(*response);
request = protocol_->load_request(data);
}
catch(rpc_error& e)
{
return protocol_->dump_error(e, error_id);
respond(protocol_->dump_error(e, nullptr));
return;

Check warning on line 34 in snail/jsonrpc/server.cpp

View check run for this annotation

Codecov / codecov/patch

snail/jsonrpc/server.cpp#L33-L34

Added lines #L33 - L34 were not covered by tests
}
catch(std::exception& e)
{
return protocol_->dump_error(internal_error(e.what()), error_id);
respond(protocol_->dump_error(internal_error(e.what()), nullptr));
return;

Check warning on line 39 in snail/jsonrpc/server.cpp

View check run for this annotation

Codecov / codecov/patch

snail/jsonrpc/server.cpp#L38-L39

Added lines #L38 - L39 were not covered by tests
}
}

std::optional<response> server::handle_request(const jsonrpc::request& request)
{
const auto& method = methods_.find(request.method);
if(method == methods_.end()) throw unknown_method_error(std::format("Unknown method: '{}'", request.method.c_str()).c_str());
if(request.id)
{
const auto& handler = request_handlers_.find(request.method);
if(handler == request_handlers_.end())
{
respond(protocol_->dump_error(unknown_method_error(std::format("Unknown request method: '{}'", request.method.c_str()).c_str()), &(*request.id)));
return;
}

auto respond_wrapper = [this, respond = respond, id = *request.id](nlohmann::json result)
{
auto response = jsonrpc::response{
.result = std::move(result),
.id = id};

auto response_str = protocol_->dump_response(response);

auto result = method->second(request.params);
{
std::lock_guard<std::mutex> response_guard(response_mutex_);
respond(std::move(response_str));
}
};

if(!result) return std::nullopt;
auto error_wrapper = [this, respond = respond, id = *request.id](std::string message)
{
auto error_str = protocol_->dump_error(internal_error(message.c_str()), &id);

{
std::lock_guard<std::mutex> response_guard(response_mutex_);
respond(std::move(error_str));
}
};

return response{
.result = std::move(*result),
.id = request.id};
handler->second(request.params, respond_wrapper, error_wrapper);
}
else
{
const auto& handler = notification_handlers_.find(request.method);
if(handler == notification_handlers_.end())
{
respond(protocol_->dump_error(unknown_method_error(std::format("Unknown notification method: '{}'", request.method.c_str()).c_str()), nullptr));
return;
}

auto error_wrapper = [this, respond = std::move(respond)](std::string message)
{
auto error_str = protocol_->dump_error(internal_error(message.c_str()), nullptr);

Check warning on line 88 in snail/jsonrpc/server.cpp

View check run for this annotation

Codecov / codecov/patch

snail/jsonrpc/server.cpp#L88

Added line #L88 was not covered by tests

{
std::lock_guard<std::mutex> response_guard(response_mutex_);
respond(std::move(error_str));

Check warning on line 92 in snail/jsonrpc/server.cpp

View check run for this annotation

Codecov / codecov/patch

snail/jsonrpc/server.cpp#L91-L92

Added lines #L91 - L92 were not covered by tests
}
};

Check warning on line 94 in snail/jsonrpc/server.cpp

View check run for this annotation

Codecov / codecov/patch

snail/jsonrpc/server.cpp#L94

Added line #L94 was not covered by tests

handler->second(request.params, error_wrapper);
}
}
Loading
Loading