Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add "borrowed-fd" feature for sending BorrowedFd on a socket #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ repository = "https://github.com/standard-ai/sendfd"
documentation = "https://docs.rs/sendfd"
readme = "README.mkd"

[features]
borrowed-fd = []

[dependencies]
libc = "0.2"
tokio = { version = "1.0.0", features = [ "net" ], optional = true }
Expand Down
101 changes: 99 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ extern crate libc;
#[cfg(feature = "tokio")]
extern crate tokio;

#[cfg(feature = "borrowed-fd")]
use std::os::fd::BorrowedFd;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net;
use std::{alloc, io, mem, ptr};
Expand All @@ -16,6 +18,9 @@ pub mod changelog;
pub trait SendWithFd {
/// Send the bytes and the file descriptors.
fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize>;
/// Send the bytes and the file descriptors.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize>;
}

/// An extension trait that enables receiving associated file descriptors along with the data.
Expand Down Expand Up @@ -77,7 +82,7 @@ unsafe fn construct_msghdr_for(

/// A common implementation of `sendmsg` that sends provided bytes with ancillary file descriptors
/// over either a datagram or stream unix socket.
fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {
fn send_with_fd<F: AsRawFd>(socket: RawFd, bs: &[u8], fds: &[F]) -> io::Result<usize> {
unsafe {
let mut iov = libc::iovec {
// NB: this casts *const to *mut, and in doing so we trust the OS to be a good citizen
Expand All @@ -99,7 +104,7 @@ fn send_with_fd(socket: RawFd, bs: &[u8], fds: &[RawFd]) -> io::Result<usize> {

let cmsg_data = libc::CMSG_DATA(cmsg_header) as *mut RawFd;
for (i, fd) in fds.iter().enumerate() {
ptr::write_unaligned(cmsg_data.add(i), *fd);
ptr::write_unaligned(cmsg_data.add(i), fd.as_raw_fd());
}
let count = libc::sendmsg(socket, &msghdr as *const _, 0);
if count < 0 {
Expand Down Expand Up @@ -181,6 +186,15 @@ impl SendWithFd for net::UnixStream {
fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
send_with_fd(self.as_raw_fd(), bytes, fds)
}

/// Send the bytes and the file descriptors as a stream.
///
/// Neither is guaranteed to be received by the other end in a single chunk and
/// may arrive entirely independently.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
send_with_fd(self.as_raw_fd(), bytes, fds)
}
}

#[cfg(feature = "tokio")]
Expand All @@ -193,6 +207,17 @@ impl SendWithFd for tokio::net::UnixStream {
fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
self.try_io(Interest::WRITABLE, || send_with_fd(self.as_raw_fd(), bytes, fds))
}

/// Send the bytes and the file descriptors as a stream.
///
/// Neither is guaranteed to be received by the other end in a single chunk and
/// may arrive entirely independently.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
self.try_io(Interest::WRITABLE, || {
send_with_fd(self.as_raw_fd(), bytes, fds)
})
}
}

#[cfg(feature = "tokio")]
Expand All @@ -206,6 +231,16 @@ impl SendWithFd for tokio::net::unix::WriteHalf<'_> {
let unix_stream: &tokio::net::UnixStream = self.as_ref();
unix_stream.send_with_fd(bytes, fds)
}

/// Send the bytes and the file descriptors as a stream.
///
/// Neither is guaranteed to be received by the other end in a single chunk and
/// may arrive entirely independently.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
let unix_stream: &tokio::net::UnixStream = self.as_ref();
unix_stream.send_with_borrowed_fd(bytes, fds)
}
}

impl SendWithFd for net::UnixDatagram {
Expand All @@ -217,6 +252,16 @@ impl SendWithFd for net::UnixDatagram {
fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
send_with_fd(self.as_raw_fd(), bytes, fds)
}

/// Send the bytes and the file descriptors as a single packet.
///
/// It is guaranteed that the bytes and the associated file descriptors will arrive at the same
/// time, however the receiver end may not receive the full message if its buffers are too
/// small.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
send_with_fd(self.as_raw_fd(), bytes, fds)
}
}

#[cfg(feature = "tokio")]
Expand All @@ -230,6 +275,18 @@ impl SendWithFd for tokio::net::UnixDatagram {
fn send_with_fd(&self, bytes: &[u8], fds: &[RawFd]) -> io::Result<usize> {
self.try_io(Interest::WRITABLE, || send_with_fd(self.as_raw_fd(), bytes, fds))
}

/// Send the bytes and the file descriptors as a single packet.
///
/// It is guaranteed that the bytes and the associated file descriptors will arrive at the same
/// time, however the receiver end may not receive the full message if its buffers are too
/// small.
#[cfg(feature = "borrowed-fd")]
fn send_with_borrowed_fd(&self, bytes: &[u8], fds: &[BorrowedFd<'_>]) -> io::Result<usize> {
self.try_io(Interest::WRITABLE, || {
send_with_fd(self.as_raw_fd(), bytes, fds)
})
}
}

impl RecvWithFd for net::UnixStream {
Expand Down Expand Up @@ -441,4 +498,44 @@ mod tests {
panic!("expected an error when sending a junk file descriptor");
}
}

#[cfg(feature = "borrowed-fd")]
#[test]
fn borrowed_fd() {
use std::os::fd::AsFd;

let (l, r) = net::UnixStream::pair().expect("create UnixStream pair");
let sent_bytes = b"hello world!";
let sent_fds = [l.as_fd(), r.as_fd()];
assert_eq!(
l.send_with_borrowed_fd(&sent_bytes[..], &sent_fds[..])
.expect("send should be successful"),
sent_bytes.len()
);
let mut recv_bytes = [0; 128];
let mut recv_fds = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
assert_eq!(
r.recv_with_fd(&mut recv_bytes, &mut recv_fds)
.expect("recv should be successful"),
(sent_bytes.len(), sent_fds.len())
);
assert_eq!(recv_bytes[..sent_bytes.len()], sent_bytes[..]);
for (&sent, &recvd) in sent_fds.iter().zip(&recv_fds[..]) {
// Modify the sent resource and check if the received resource has been modified the
// same way.
let expected_value = Some(std::time::Duration::from_secs(42));
unsafe {
let s = net::UnixStream::from(sent.try_clone_to_owned().unwrap());
s.set_read_timeout(expected_value)
.expect("set read timeout");
std::mem::forget(s);
assert_eq!(
net::UnixStream::from_raw_fd(recvd)
.read_timeout()
.expect("get read timeout"),
expected_value
);
}
}
}
}
Loading