Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement S/G IO for batched sends and eliminate another frame copy #2874

Merged
merged 4 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/crypto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
* The resulting ciphertext and the GCM tag are written into the tagged_cipher buffer.
*/
int
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv) {
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tag, std::uint8_t *ciphertext, aes_t *iv) {

Check warning on line 188 in src/crypto.cpp

View check run for this annotation

Codecov / codecov/patch

src/crypto.cpp#L188

Added line #L188 was not covered by tests
if (!encrypt_ctx && init_encrypt_gcm(encrypt_ctx, &key, iv, padding)) {
return -1;
}
Expand All @@ -196,18 +196,15 @@
return -1;
}

auto tag = tagged_cipher;
auto cipher = tag + tag_size;

int update_outlen, final_outlen;

// Encrypt into the caller's buffer
if (EVP_EncryptUpdate(encrypt_ctx.get(), cipher, &update_outlen, (const std::uint8_t *) plaintext.data(), plaintext.size()) != 1) {
if (EVP_EncryptUpdate(encrypt_ctx.get(), ciphertext, &update_outlen, (const std::uint8_t *) plaintext.data(), plaintext.size()) != 1) {
return -1;
}

// GCM encryption won't ever fill ciphertext here but we have to call it anyway
if (EVP_EncryptFinal_ex(encrypt_ctx.get(), cipher + update_outlen, &final_outlen) != 1) {
if (EVP_EncryptFinal_ex(encrypt_ctx.get(), ciphertext + update_outlen, &final_outlen) != 1) {
return -1;
}

Expand All @@ -218,6 +215,12 @@
return update_outlen + final_outlen;
}

int
gcm_t::encrypt(const std::string_view &plaintext, std::uint8_t *tagged_cipher, aes_t *iv) {

Check warning on line 219 in src/crypto.cpp

View check run for this annotation

Codecov / codecov/patch

src/crypto.cpp#L219

Added line #L219 was not covered by tests
// This overload handles the common case of [GCM tag][cipher text] buffer layout
return encrypt(plaintext, tagged_cipher, tagged_cipher + tag_size, iv);

Check warning on line 221 in src/crypto.cpp

View check run for this annotation

Codecov / codecov/patch

src/crypto.cpp#L221

Added line #L221 was not covered by tests
}

int
ecb_t::decrypt(const std::string_view &cipher, std::vector<std::uint8_t> &plaintext) {
auto fg = util::fail_guard([this]() {
Expand Down
11 changes: 11 additions & 0 deletions src/crypto.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,17 @@ namespace crypto {

gcm_t(const crypto::aes_t &key, bool padding = true);

/**
* @brief Encrypts the plaintext using AES GCM mode.
* @param plaintext The plaintext data to be encrypted.
* @param tag The buffer where the GCM tag will be written.
* @param ciphertext The buffer where the resulting ciphertext will be written.
* @param iv The initialization vector to be used for the encryption.
* @return The total length of the ciphertext and GCM tag. Returns -1 in case of an error.
*/
int
encrypt(const std::string_view &plaintext, std::uint8_t *tag, std::uint8_t *ciphertext, aes_t *iv);

/**
* @brief Encrypts the plaintext using AES GCM mode.
* length of cipher must be at least: round_to_pkcs7_padded(plaintext.size()) + crypto::cipher::tag_size
Expand Down
49 changes: 47 additions & 2 deletions src/platform/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,15 +606,60 @@
void
restart();

struct batched_send_info_t {
struct buffer_descriptor_t {
const char *buffer;
size_t block_size;
size_t size;

// Constructors required for emplace_back() prior to C++20
buffer_descriptor_t(const char *buffer, size_t size):
buffer(buffer), size(size) {}
buffer_descriptor_t():
buffer(nullptr), size(0) {}

Check warning on line 617 in src/platform/common.h

View check run for this annotation

Codecov / codecov/patch

src/platform/common.h#L614-L617

Added lines #L614 - L617 were not covered by tests
};

struct batched_send_info_t {
// Optional headers to be prepended to each packet
const char *headers;
size_t header_size;

// One or more data buffers to use for the payloads
//
// NB: Data buffers must be aligned to payload size!
std::vector<buffer_descriptor_t> &payload_buffers;
size_t payload_size;

// The offset (in header+payload message blocks) in the header and payload
// buffers to begin sending messages from
size_t block_offset;

// The number of header+payload message blocks to send
size_t block_count;

std::uintptr_t native_socket;
boost::asio::ip::address &target_address;
uint16_t target_port;
boost::asio::ip::address &source_address;

/**
* @brief Returns a payload buffer descriptor for the given payload offset.
* @param offset The offset in the total payload data (bytes).
* @return Buffer descriptor describing the region at the given offset.
*/
buffer_descriptor_t
buffer_for_payload_offset(ptrdiff_t offset) {

Check warning on line 649 in src/platform/common.h

View check run for this annotation

Codecov / codecov/patch

src/platform/common.h#L649

Added line #L649 was not covered by tests
for (const auto &desc : payload_buffers) {
if (offset < desc.size) {
return {
desc.buffer + offset,

Check warning on line 653 in src/platform/common.h

View check run for this annotation

Codecov / codecov/patch

src/platform/common.h#L652-L653

Added lines #L652 - L653 were not covered by tests
desc.size - offset,
};

Check warning on line 655 in src/platform/common.h

View check run for this annotation

Codecov / codecov/patch

src/platform/common.h#L655

Added line #L655 was not covered by tests
}
else {
offset -= desc.size;

Check warning on line 658 in src/platform/common.h

View check run for this annotation

Codecov / codecov/patch

src/platform/common.h#L658

Added line #L658 was not covered by tests
}
}
return {};
}
};
bool
send_batch(batched_send_info_t &send_info);
Expand Down
70 changes: 52 additions & 18 deletions src/platform/linux/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,30 +433,56 @@
memcpy(CMSG_DATA(pktinfo_cm), &pktInfo, sizeof(pktInfo));
}

auto const max_iovs_per_msg = send_info.payload_buffers.size() + (send_info.headers ? 1 : 0);

#ifdef UDP_SEGMENT
{
struct iovec iov = {};

msg.msg_iov = &iov;
msg.msg_iovlen = 1;

// UDP GSO on Linux currently only supports sending 64K or 64 segments at a time
size_t seg_index = 0;
const size_t seg_max = 65536 / 1500;
struct iovec iovs[(send_info.headers ? std::min(seg_max, send_info.block_count) : 1) * max_iovs_per_msg] = {};
auto msg_size = send_info.header_size + send_info.payload_size;

Check warning on line 444 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L444

Added line #L444 was not covered by tests
while (seg_index < send_info.block_count) {
iov.iov_base = (void *) &send_info.buffer[seg_index * send_info.block_size];
iov.iov_len = send_info.block_size * std::min(send_info.block_count - seg_index, seg_max);
int iovlen = 0;

Check warning on line 446 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L446

Added line #L446 was not covered by tests
auto segs_in_batch = std::min(send_info.block_count - seg_index, seg_max);
if (send_info.headers) {
// Interleave iovs for headers and payloads
for (auto i = 0; i < segs_in_batch; i++) {
iovs[iovlen].iov_base = (void *) &send_info.headers[(send_info.block_offset + seg_index + i) * send_info.header_size];
iovs[iovlen].iov_len = send_info.header_size;
iovlen++;
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + seg_index + i) * send_info.payload_size);
iovs[iovlen].iov_base = (void *) payload_desc.buffer;
iovs[iovlen].iov_len = send_info.payload_size;
iovlen++;

Check warning on line 457 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L451-L457

Added lines #L451 - L457 were not covered by tests
}
}
else {
// Translate buffer descriptors into iovs
auto payload_offset = (send_info.block_offset + seg_index) * send_info.payload_size;
auto payload_length = payload_offset + (segs_in_batch * send_info.payload_size);

Check warning on line 463 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L462-L463

Added lines #L462 - L463 were not covered by tests
while (payload_offset < payload_length) {
auto payload_desc = send_info.buffer_for_payload_offset(payload_offset);
iovs[iovlen].iov_base = (void *) payload_desc.buffer;

Check warning on line 466 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L465-L466

Added lines #L465 - L466 were not covered by tests
iovs[iovlen].iov_len = std::min(payload_desc.size, payload_length - payload_offset);
payload_offset += iovs[iovlen].iov_len;
iovlen++;

Check warning on line 469 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L468-L469

Added lines #L468 - L469 were not covered by tests
}
}

msg.msg_iov = iovs;
msg.msg_iovlen = iovlen;

Check warning on line 474 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L473-L474

Added lines #L473 - L474 were not covered by tests

// We should not use GSO if the data is <= one full block size
if (iov.iov_len > send_info.block_size) {
if (segs_in_batch > 1) {
msg.msg_controllen = cmbuflen + CMSG_SPACE(sizeof(uint16_t));

// Enable GSO to perform segmentation of our buffer for us
auto cm = CMSG_NXTHDR(&msg, pktinfo_cm);
cm->cmsg_level = SOL_UDP;
cm->cmsg_type = UDP_SEGMENT;
cm->cmsg_len = CMSG_LEN(sizeof(uint16_t));
*((uint16_t *) CMSG_DATA(cm)) = send_info.block_size;
*((uint16_t *) CMSG_DATA(cm)) = msg_size;

Check warning on line 485 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L485

Added line #L485 was not covered by tests
}
else {
msg.msg_controllen = cmbuflen;
Expand All @@ -483,10 +509,11 @@
continue;
}

BOOST_LOG(verbose) << "sendmsg() failed: "sv << errno;
break;
}

seg_index += bytes_sent / send_info.block_size;
seg_index += bytes_sent / msg_size;

Check warning on line 516 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L516

Added line #L516 was not covered by tests
}

// If we sent something, return the status and don't fall back to the non-GSO path.
Expand All @@ -498,18 +525,25 @@

{
// If GSO is not supported, use sendmmsg() instead.
struct mmsghdr msgs[send_info.block_count];
struct iovec iovs[send_info.block_count];
struct mmsghdr msgs[send_info.block_count] = {};
struct iovec iovs[send_info.block_count * (send_info.headers ? 2 : 1)] = {};
int iov_idx = 0;
for (size_t i = 0; i < send_info.block_count; i++) {
iovs[i] = {};
iovs[i].iov_base = (void *) &send_info.buffer[i * send_info.block_size];
iovs[i].iov_len = send_info.block_size;
msgs[i].msg_hdr.msg_iov = &iovs[iov_idx];
msgs[i].msg_hdr.msg_iovlen = send_info.headers ? 2 : 1;

Check warning on line 533 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L532-L533

Added lines #L532 - L533 were not covered by tests

if (send_info.headers) {
iovs[iov_idx].iov_base = (void *) &send_info.headers[(send_info.block_offset + i) * send_info.header_size];
iovs[iov_idx].iov_len = send_info.header_size;
iov_idx++;

Check warning on line 538 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L536-L538

Added lines #L536 - L538 were not covered by tests
}
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + i) * send_info.payload_size);
iovs[iov_idx].iov_base = (void *) payload_desc.buffer;
iovs[iov_idx].iov_len = send_info.payload_size;
iov_idx++;

Check warning on line 543 in src/platform/linux/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/linux/misc.cpp#L540-L543

Added lines #L540 - L543 were not covered by tests

msgs[i] = {};
msgs[i].msg_hdr.msg_name = msg.msg_name;
msgs[i].msg_hdr.msg_namelen = msg.msg_namelen;
msgs[i].msg_hdr.msg_iov = &iovs[i];
msgs[i].msg_hdr.msg_iovlen = 1;
msgs[i].msg_hdr.msg_control = cmbuf.buf;
msgs[i].msg_hdr.msg_controllen = cmbuflen;
}
Expand Down
37 changes: 31 additions & 6 deletions src/platform/windows/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,12 +1452,37 @@
msg.namelen = sizeof(taddr_v4);
}

WSABUF buf;
buf.buf = (char *) send_info.buffer;
buf.len = send_info.block_size * send_info.block_count;
auto const max_bufs_per_msg = send_info.payload_buffers.size() + (send_info.headers ? 1 : 0);

msg.lpBuffers = &buf;
msg.dwBufferCount = 1;
WSABUF bufs[(send_info.headers ? send_info.block_count : 1) * max_bufs_per_msg];
DWORD bufcount = 0;

Check warning on line 1458 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1458

Added line #L1458 was not covered by tests
if (send_info.headers) {
// Interleave buffers for headers and payloads
for (auto i = 0; i < send_info.block_count; i++) {
bufs[bufcount].buf = (char *) &send_info.headers[(send_info.block_offset + i) * send_info.header_size];
bufs[bufcount].len = send_info.header_size;
bufcount++;
auto payload_desc = send_info.buffer_for_payload_offset((send_info.block_offset + i) * send_info.payload_size);
bufs[bufcount].buf = (char *) payload_desc.buffer;
bufs[bufcount].len = send_info.payload_size;
bufcount++;

Check warning on line 1468 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1462-L1468

Added lines #L1462 - L1468 were not covered by tests
}
}
else {
// Translate buffer descriptors into WSABUFs
auto payload_offset = send_info.block_offset * send_info.payload_size;
auto payload_length = payload_offset + (send_info.block_count * send_info.payload_size);

Check warning on line 1474 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1473-L1474

Added lines #L1473 - L1474 were not covered by tests
while (payload_offset < payload_length) {
auto payload_desc = send_info.buffer_for_payload_offset(payload_offset);
bufs[bufcount].buf = (char *) payload_desc.buffer;
bufs[bufcount].len = std::min(payload_desc.size, payload_length - payload_offset);
payload_offset += bufs[bufcount].len;
bufcount++;

Check warning on line 1480 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1476-L1480

Added lines #L1476 - L1480 were not covered by tests
}
}

msg.lpBuffers = bufs;
msg.dwBufferCount = bufcount;

Check warning on line 1485 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1484-L1485

Added lines #L1484 - L1485 were not covered by tests
msg.dwFlags = 0;

// At most, one DWORD option and one PKTINFO option
Expand Down Expand Up @@ -1505,7 +1530,7 @@
cm->cmsg_level = IPPROTO_UDP;
cm->cmsg_type = UDP_SEND_MSG_SIZE;
cm->cmsg_len = WSA_CMSG_LEN(sizeof(DWORD));
*((DWORD *) WSA_CMSG_DATA(cm)) = send_info.block_size;
*((DWORD *) WSA_CMSG_DATA(cm)) = send_info.header_size + send_info.payload_size;

Check warning on line 1533 in src/platform/windows/misc.cpp

View check run for this annotation

Codecov / codecov/patch

src/platform/windows/misc.cpp#L1533

Added line #L1533 was not covered by tests
}

msg.Control.len = cmbuflen;
Expand Down
Loading
Loading