Skip to content

Commit

Permalink
Add support for calling server-advertised services
Browse files Browse the repository at this point in the history
  • Loading branch information
achim-k committed Jan 24, 2023
1 parent 52767df commit fc7123f
Show file tree
Hide file tree
Showing 20 changed files with 839 additions and 37 deletions.
11 changes: 9 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ if("$ENV{ROS_VERSION}" STREQUAL "1")
add_library(foxglove_bridge_nodelet
ros1_foxglove_bridge/src/ros1_foxglove_bridge_nodelet.cpp
ros1_foxglove_bridge/src/param_utils.cpp
ros1_foxglove_bridge/src/service_utils.cpp
)
target_include_directories(foxglove_bridge_nodelet
SYSTEM PRIVATE
Expand Down Expand Up @@ -167,7 +168,7 @@ if(ROS_BUILD_TYPE STREQUAL "catkin")
if (CATKIN_ENABLE_TESTING)
message(STATUS "Building tests with catkin")

find_package(catkin REQUIRED COMPONENTS roscpp std_msgs)
find_package(catkin REQUIRED COMPONENTS roscpp std_msgs std_srvs)
if(NOT "$ENV{ROS_DISTRO}" STREQUAL "melodic")
find_package(GTest REQUIRED)
endif()
Expand All @@ -176,6 +177,9 @@ if(ROS_BUILD_TYPE STREQUAL "catkin")
catkin_add_gtest(version_test foxglove_bridge_base/tests/version_test.cpp)
target_link_libraries(version_test foxglove_bridge_base)

catkin_add_gtest(serialization_test foxglove_bridge_base/tests/serialization_test.cpp)
target_link_libraries(serialization_test foxglove_bridge_base)

add_rostest_gtest(smoke_test ros1_foxglove_bridge/tests/smoke.test ros1_foxglove_bridge/tests/smoke_test.cpp)
target_include_directories(smoke_test SYSTEM PRIVATE
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/foxglove_bridge_base/include>
Expand All @@ -195,8 +199,11 @@ elseif(ROS_BUILD_TYPE STREQUAL "ament_cmake")
ament_add_gtest(version_test foxglove_bridge_base/tests/version_test.cpp)
target_link_libraries(version_test foxglove_bridge_base)

ament_add_gtest(serialization_test foxglove_bridge_base/tests/serialization_test.cpp)
target_link_libraries(serialization_test foxglove_bridge_base)

ament_add_gtest(smoke_test ros2_foxglove_bridge/tests/smoke_test.cpp)
ament_target_dependencies(smoke_test rclcpp rclcpp_components std_msgs)
ament_target_dependencies(smoke_test rclcpp rclcpp_components std_msgs std_srvs)
target_link_libraries(smoke_test foxglove_bridge_base)
endif()
endif()
Expand Down
40 changes: 40 additions & 0 deletions foxglove_bridge_base/include/foxglove_bridge/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ constexpr char CAPABILITY_CLIENT_PUBLISH[] = "clientPublish";
constexpr char CAPABILITY_TIME[] = "time";
constexpr char CAPABILITY_PARAMETERS[] = "parameters";
constexpr char CAPABILITY_PARAMETERS_SUBSCRIBE[] = "parametersSubscribe";
constexpr char CAPABILITY_SERVICES[] = "services";

using ChannelId = uint32_t;
using ClientChannelId = uint32_t;
using SubscriptionId = uint32_t;
using ServiceId = uint32_t;

enum class BinaryOpcode : uint8_t {
MESSAGE_DATA = 1,
TIME_DATA = 2,
SERVICE_CALL_RESPONSE = 3,
};

enum class ClientBinaryOpcode : uint8_t {
MESSAGE_DATA = 1,
SERVICE_CALL_REQUEST = 2,
};

struct ClientAdvertisement {
Expand All @@ -33,4 +37,40 @@ struct ClientAdvertisement {
std::vector<uint8_t> schema;
};

struct ServiceWithoutId {
std::string name;
std::string type;
std::string requestSchema;
std::string responseSchema;
};

struct Service : ServiceWithoutId {
ServiceId id;

Service() = default;
Service(const ServiceWithoutId& s, const ServiceId& id)
: ServiceWithoutId(s)
, id(id) {}
};

struct ServiceResponse {
ServiceId serviceId;
uint32_t callId;
std::string encoding;
std::vector<uint8_t> data;

size_t size() const {
return 4 + 4 + 4 + encoding.size() + data.size();
}
void read(const uint8_t* data, size_t size);
void write(uint8_t* data) const;

bool operator==(const ServiceResponse& other) const {
return serviceId == other.serviceId && callId == other.callId && encoding == other.encoding &&
data == other.data;
}
};

using ServiceRequest = ServiceResponse;

} // namespace foxglove
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <nlohmann/json.hpp>

#include "common.hpp"
#include "parameter.hpp"

namespace foxglove {
Expand Down Expand Up @@ -36,5 +37,7 @@ inline void WriteUint32LE(uint8_t* buf, uint32_t val) {

void to_json(nlohmann::json& j, const Parameter& p);
void from_json(const nlohmann::json& j, Parameter& p);
void to_json(nlohmann::json& j, const Service& p);
void from_json(const nlohmann::json& j, Service& p);

} // namespace foxglove
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ std::vector<uint8_t> connectClientAndReceiveMsg(const std::string& uri,
std::future<std::vector<Parameter>> waitForParameters(std::shared_ptr<ClientInterface> client,
const std::string& requestId = std::string());

std::future<ServiceResponse> waitForServiceResponse(std::shared_ptr<ClientInterface> client);

std::future<Service> waitForService(std::shared_ptr<ClientInterface> client,
const std::string& serviceName);

extern template class Client<websocketpp::config::asio_client>;

} // namespace foxglove
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class ClientInterface {
virtual void advertise(const std::vector<ClientAdvertisement>& channels) = 0;
virtual void unadvertise(const std::vector<ClientChannelId>& channelIds) = 0;
virtual void publish(ClientChannelId channelId, const uint8_t* buffer, size_t size) = 0;
virtual void sendServiceRequest(const ServiceRequest& request) = 0;
virtual void getParameters(const std::vector<std::string>& parameterNames,
const std::optional<std::string>& requestId) = 0;
virtual void setParameters(const std::vector<Parameter>& parameters,
Expand Down Expand Up @@ -183,6 +184,13 @@ class Client : public ClientInterface {
sendBinary(payload.data(), payload.size());
}

void sendServiceRequest(const ServiceRequest& request) override {
std::vector<uint8_t> payload(1 + request.size());
payload[0] = uint8_t(ClientBinaryOpcode::SERVICE_CALL_REQUEST);
request.write(payload.data() + 1);
sendBinary(payload.data(), payload.size());
}

void getParameters(const std::vector<std::string>& parameterNames,
const std::optional<std::string>& requestId = std::nullopt) override {
nlohmann::json jsonPayload{{"op", "getParameters"}, {"parameterNames", parameterNames}};
Expand Down
120 changes: 109 additions & 11 deletions foxglove_bridge_base/include/foxglove_bridge/websocket_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ using ParameterChangeHandler =
std::function<void(const std::vector<Parameter>&, const std::optional<std::string>&, ConnHandle)>;
using ParameterSubscriptionHandler =
std::function<void(const std::vector<std::string>&, ParameterSubscriptionOperation, ConnHandle)>;
using ServiceRequestHandler = std::function<void(const ServiceRequest&, ConnHandle)>;

class ServerInterface {
public:
Expand All @@ -161,6 +162,8 @@ class ServerInterface {
const std::vector<Parameter>& parameters,
const std::optional<std::string>& requestId) = 0;
virtual void updateParameterValues(const std::vector<Parameter>& parameters) = 0;
virtual std::vector<ServiceId> addServices(const std::vector<ServiceWithoutId>& services) = 0;
virtual void removeServices(const std::vector<ServiceId>& serviceIds) = 0;

virtual void setSubscribeHandler(SubscribeUnsubscribeHandler handler) = 0;
virtual void setUnsubscribeHandler(SubscribeUnsubscribeHandler handler) = 0;
Expand All @@ -170,10 +173,12 @@ class ServerInterface {
virtual void setParameterRequestHandler(ParameterRequestHandler handler) = 0;
virtual void setParameterChangeHandler(ParameterChangeHandler handler) = 0;
virtual void setParameterSubscriptionHandler(ParameterSubscriptionHandler handler) = 0;
virtual void setServiceRequestHandler(ServiceRequestHandler handler) = 0;

virtual void sendMessage(ConnHandle clientHandle, ChannelId chanId, uint64_t timestamp,
std::string_view data) = 0;
virtual void broadcastTime(uint64_t timestamp) = 0;
virtual void sendServiceResponse(ConnHandle clientHandle, const ServiceResponse& response) = 0;

virtual std::optional<Tcp::endpoint> localEndpoint() = 0;
virtual std::string remoteEndpointString(ConnHandle clientHandle) = 0;
Expand Down Expand Up @@ -213,6 +218,8 @@ class Server final : public ServerInterface {
void publishParameterValues(ConnHandle clientHandle, const std::vector<Parameter>& parameters,
const std::optional<std::string>& requestId = std::nullopt) override;
void updateParameterValues(const std::vector<Parameter>& parameters) override;
std::vector<ServiceId> addServices(const std::vector<ServiceWithoutId>& services) override;
void removeServices(const std::vector<ServiceId>& serviceIds) override;

void setSubscribeHandler(SubscribeUnsubscribeHandler handler) override;
void setUnsubscribeHandler(SubscribeUnsubscribeHandler handler) override;
Expand All @@ -222,10 +229,12 @@ class Server final : public ServerInterface {
void setParameterRequestHandler(ParameterRequestHandler handler) override;
void setParameterChangeHandler(ParameterChangeHandler handler) override;
void setParameterSubscriptionHandler(ParameterSubscriptionHandler handler) override;
void setServiceRequestHandler(ServiceRequestHandler handler) override;

void sendMessage(ConnHandle clientHandle, ChannelId chanId, uint64_t timestamp,
std::string_view data) override;
void broadcastTime(uint64_t timestamp) override;
void sendServiceResponse(ConnHandle clientHandle, const ServiceResponse& response) override;

std::optional<Tcp::endpoint> localEndpoint() override;
std::string remoteEndpointString(ConnHandle clientHandle) override;
Expand Down Expand Up @@ -262,6 +271,8 @@ class Server final : public ServerInterface {
_clientChannels;
std::map<ConnHandle, std::unordered_set<std::string>, std::owner_less<>>
_clientParamSubscriptions;
ServiceId _nextServiceId = 0;
std::unordered_map<ServiceId, ServiceWithoutId> _services;
SubscribeUnsubscribeHandler _subscribeHandler;
SubscribeUnsubscribeHandler _unsubscribeHandler;
ClientAdvertiseHandler _clientAdvertiseHandler;
Expand All @@ -270,6 +281,7 @@ class Server final : public ServerInterface {
ParameterRequestHandler _parameterRequestHandler;
ParameterChangeHandler _parameterChangeHandler;
ParameterSubscriptionHandler _parameterSubscriptionHandler;
ServiceRequestHandler _serviceRequestHandler;
std::shared_mutex _clientsChannelMutex;
std::mutex _clientParamSubscriptionsMutex;

Expand Down Expand Up @@ -375,7 +387,7 @@ inline void Server<ServerConfiguration>::handleConnectionOpened(ConnHandle hdl)
})
.dump());

json channels;
std::vector<Channel> channels;
for (const auto& [id, channel] : _channels) {
(void)id;
channels.push_back(channel);
Expand All @@ -384,6 +396,15 @@ inline void Server<ServerConfiguration>::handleConnectionOpened(ConnHandle hdl)
{"op", "advertise"},
{"channels", std::move(channels)},
});

std::vector<Service> services;
for (const auto& [id, service] : _services) {
services.push_back(Service(service, id));
}
sendJson(hdl, {
{"op", "advertiseServices"},
{"services", std::move(services)},
});
}

template <typename ServerConfiguration>
Expand Down Expand Up @@ -483,6 +504,11 @@ inline void Server<ServerConfiguration>::setParameterSubscriptionHandler(
_parameterSubscriptionHandler = std::move(handler);
}

template <typename ServerConfiguration>
inline void Server<ServerConfiguration>::setServiceRequestHandler(ServiceRequestHandler handler) {
_serviceRequestHandler = std::move(handler);
}

template <typename ServerConfiguration>
inline void Server<ServerConfiguration>::stop() {
if (_server.stopped()) {
Expand Down Expand Up @@ -883,16 +909,6 @@ inline void Server<ServerConfiguration>::handleBinaryMessage(ConnHandle hdl, con
return;
}

std::unique_lock<std::shared_mutex> lock(_clientsChannelMutex);

auto clientPublicationsIt = _clientChannels.find(hdl);
if (clientPublicationsIt == _clientChannels.end()) {
sendStatus(hdl, StatusLevel::Error, "Client has no advertised channels");
return;
}

auto& clientPublications = clientPublicationsIt->second;

const auto op = static_cast<ClientBinaryOpcode>(msg[0]);
switch (op) {
case ClientBinaryOpcode::MESSAGE_DATA: {
Expand All @@ -901,6 +917,15 @@ inline void Server<ServerConfiguration>::handleBinaryMessage(ConnHandle hdl, con
return;
}
const ClientChannelId channelId = *reinterpret_cast<const ClientChannelId*>(msg + 1);
std::unique_lock<std::shared_mutex> lock(_clientsChannelMutex);

auto clientPublicationsIt = _clientChannels.find(hdl);
if (clientPublicationsIt == _clientChannels.end()) {
sendStatus(hdl, StatusLevel::Error, "Client has no advertised channels");
return;
}

auto& clientPublications = clientPublicationsIt->second;
const auto& channelIt = clientPublications.find(channelId);
if (channelIt == clientPublications.end()) {
sendStatus(hdl, StatusLevel::Error,
Expand All @@ -920,6 +945,25 @@ inline void Server<ServerConfiguration>::handleBinaryMessage(ConnHandle hdl, con
_clientMessageHandler(clientMessage, hdl);
}
} break;
case ClientBinaryOpcode::SERVICE_CALL_REQUEST: {
ServiceRequest request;
if (length < request.size()) {
sendStatus(hdl, StatusLevel::Error,
"Invalid service call request length " + std::to_string(length));
return;
}

request.read(msg + 1, length - 1);
if (_services.find(request.serviceId) == _services.end()) {
sendStatus(hdl, StatusLevel::Error,
"Service " + std::to_string(request.serviceId) + " is not advertised");
return;
}

if (_serviceRequestHandler) {
_serviceRequestHandler(request, hdl);
}
} break;
default: {
sendStatus(hdl, StatusLevel::Error,
"Unrecognized client opcode " + std::to_string(uint8_t(op)));
Expand Down Expand Up @@ -1001,6 +1045,51 @@ inline void Server<ServerConfiguration>::updateParameterValues(
}
}

template <typename ServerConfiguration>
inline std::vector<ServiceId> Server<ServerConfiguration>::addServices(
const std::vector<ServiceWithoutId>& services) {
if (services.empty()) {
return {};
}

std::vector<ServiceId> serviceIds;
json newServices;
for (const auto& service : services) {
const ServiceId serviceId = ++_nextServiceId;
_services.emplace(serviceId, service);
serviceIds.push_back(serviceId);
newServices.push_back(Service(service, serviceId));
}

const auto msg = json{{"op", "advertiseServices"}, {"services", std::move(newServices)}}.dump();
for (const auto& [hdl, clientInfo] : _clients) {
(void)clientInfo;
sendJsonRaw(hdl, msg);
}

return serviceIds;
}

template <typename ServerConfiguration>
inline void Server<ServerConfiguration>::removeServices(const std::vector<ServiceId>& serviceIds) {
std::vector<ServiceId> removedServices;
for (const auto& serviceId : serviceIds) {
if (const auto it = _services.find(serviceId); it != _services.end()) {
_services.erase(it);
removedServices.push_back(serviceId);
}
}

if (!removedServices.empty()) {
const auto msg =
json{{"op", "unadvertiseServices"}, {"serviceIds", std::move(removedServices)}}.dump();
for (const auto& [hdl, clientInfo] : _clients) {
(void)clientInfo;
sendJsonRaw(hdl, msg);
}
}
}

template <typename ServerConfiguration>
inline void Server<ServerConfiguration>::sendMessage(ConnHandle clientHandle, ChannelId chanId,
uint64_t timestamp, std::string_view data) {
Expand Down Expand Up @@ -1059,6 +1148,15 @@ inline void Server<ServerConfiguration>::broadcastTime(uint64_t timestamp) {
}
}

template <typename ServerConfiguration>
inline void Server<ServerConfiguration>::sendServiceResponse(ConnHandle clientHandle,
const ServiceResponse& response) {
std::vector<uint8_t> payload(1 + response.size());
payload[0] = uint8_t(BinaryOpcode::SERVICE_CALL_RESPONSE);
response.write(payload.data() + 1);
sendBinary(clientHandle, payload);
}

template <typename ServerConfiguration>
inline std::optional<asio::ip::tcp::endpoint> Server<ServerConfiguration>::localEndpoint() {
std::error_code ec;
Expand Down
Loading

0 comments on commit fc7123f

Please sign in to comment.