Skip to content

Commit

Permalink
UDP: support for multihoming with unbound sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoward committed Nov 14, 2024
1 parent 0df6cf7 commit 6fb4b23
Show file tree
Hide file tree
Showing 10 changed files with 481 additions and 26 deletions.
74 changes: 72 additions & 2 deletions FlyingSocks/Sources/AsyncSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool<Poll> {

public struct AsyncSocket: Sendable {

public struct Message: Sendable {
public let peerAddress: sockaddr_storage
public let bytes: [UInt8]
public let interfaceIndex: UInt32?
public let localAddress: sockaddr_storage?

public init(
peerAddress: sockaddr_storage,
bytes: [UInt8],
interfaceIndex: UInt32? = nil,
localAddress: sockaddr_storage? = nil
) {
self.peerAddress = peerAddress
self.bytes = bytes
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}
}

public let socket: Socket
let pool: any AsyncSocketPool

Expand Down Expand Up @@ -143,6 +162,23 @@ public struct AsyncSocket: Sendable {
} while true
}

Check warning on line 163 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L158-L163

Added lines #L158 - L163 were not covered by tests

#if !canImport(WinSDK)
public func receive(atMost length: Int) async throws -> Message {
try Task.checkCancellation()

repeat {
do {
let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length)
return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress)
} catch SocketError.blocked {
try await pool.suspendSocket(socket, untilReadyFor: .read)
} catch {
throw error
}
} while true
}

Check warning on line 179 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L174-L179

Added lines #L174 - L179 were not covered by tests
#endif

/// Reads bytes from the socket up to by not over/
/// - Parameter bytes: The max number of bytes to read
/// - Returns: an array of the read bytes capped to the number of bytes provided.
Expand Down Expand Up @@ -190,6 +226,31 @@ public struct AsyncSocket: Sendable {
try await send(Array(data), to: address)
}

Check warning on line 227 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L225-L227

Added lines #L225 - L227 were not covered by tests

#if !canImport(WinSDK)
public func send(
message: [UInt8],
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}
guard sent == message.count else {
throw SocketError.disconnected
}
}

Check warning on line 242 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L235-L242

Added lines #L235 - L242 were not covered by tests

public func send(
message: Data,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}

Check warning on line 251 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L249-L251

Added lines #L249 - L251 were not covered by tests
#endif

public func close() throws {
try socket.close()
}
Expand Down Expand Up @@ -275,7 +336,8 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl
public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable {
public static let DefaultMaxMessageLength: Int = 1500

public typealias Element = (sockaddr_storage, [UInt8])
// Windows has a different recvmsg() API signature which is presently unsupported
public typealias Element = AsyncSocket.Message

private let socket: AsyncSocket
private let maxMessageLength: Int
Expand All @@ -288,7 +350,15 @@ public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol,
}

Check warning on line 350 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L347-L350

Added lines #L347 - L350 were not covered by tests

public mutating func next() async throws -> Element? {

Check warning on line 352 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L352

Added line #L352 was not covered by tests
return try await socket.receive(atMost: maxMessageLength)
#if !canImport(WinSDK)
try await socket.receive(atMost: maxMessageLength)

Check warning on line 354 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L354

Added line #L354 was not covered by tests
#else
let peerAddress: sockaddr_storage
let bytes: [UInt8]

(peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength)
return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes)
#endif
}

Check warning on line 362 in FlyingSocks/Sources/AsyncSocket.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/AsyncSocket.swift#L362

Added line #L362 was not covered by tests
}

Expand Down
16 changes: 16 additions & 0 deletions FlyingSocks/Sources/Socket+Android.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31;

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = UInt
typealias ControlMessageHeaderLengthType = Int
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = Int32
}

extension Socket.FileDescriptor {
Expand All @@ -47,6 +51,10 @@ extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Android.sockaddr_in {
Android.sockaddr_in(
Expand Down Expand Up @@ -184,6 +192,14 @@ extension Socket {
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Android.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Android.sendmsg(fd, message, flags)
}
}

#endif
16 changes: 16 additions & 0 deletions FlyingSocks/Sources/Socket+Darwin.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import Darwin

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = UInt32
typealias IPv4InterfaceIndexType = UInt32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -44,6 +48,10 @@ extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292

static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in {
Darwin.sockaddr_in(
Expand Down Expand Up @@ -185,6 +193,14 @@ extension Socket {
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

Check warning on line 195 in FlyingSocks/Sources/Socket+Darwin.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/Socket+Darwin.swift#L193-L195

Added lines #L193 - L195 were not covered by tests

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Darwin.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Darwin.sendmsg(fd, message, flags)
}

Check warning on line 203 in FlyingSocks/Sources/Socket+Darwin.swift

View check run for this annotation

Codecov / codecov/patch

FlyingSocks/Sources/Socket+Darwin.swift#L201-L203

Added lines #L201 - L203 were not covered by tests
}

#endif
21 changes: 21 additions & 0 deletions FlyingSocks/Sources/Socket+Glibc.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import Glibc

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = Int
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -44,6 +48,10 @@ extension Socket {
static let stream = Int32(SOCK_STREAM.rawValue)
static let datagram = Int32(SOCK_DGRAM.rawValue)
static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in {
Glibc.sockaddr_in(
Expand Down Expand Up @@ -181,6 +189,19 @@ extension Socket {
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Glibc.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Glibc.sendmsg(fd, message, flags)
}
}

struct in6_pktinfo {
var ipi6_addr: in6_addr
var ipi6_ifindex: CUnsignedInt
}

#endif
16 changes: 16 additions & 0 deletions FlyingSocks/Sources/Socket+Musl.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ import Musl

public extension Socket {
typealias FileDescriptorType = Int32
typealias IovLengthType = Int
typealias ControlMessageHeaderLengthType = UInt32
typealias IPv4InterfaceIndexType = Int32
typealias IPv6InterfaceIndexType = UInt32
}

extension Socket.FileDescriptor {
Expand All @@ -44,6 +48,10 @@ extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0))
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in {
Musl.sockaddr_in(
Expand Down Expand Up @@ -181,6 +189,14 @@ extension Socket {
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
Musl.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
Musl.sendmsg(fd, message, flags)
}
}

#endif
16 changes: 16 additions & 0 deletions FlyingSocks/Sources/Socket+WinSock2.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ public typealias sa_family_t = UInt8

public extension Socket {
typealias FileDescriptorType = UInt64
typealias IovLengthType = UInt
typealias ControlMessageHeaderLengthType = DWORD
typealias IPv4InterfaceIndexType = ULONG
typealias IPv6InterfaceIndexType = ULONG
}

extension Socket.FileDescriptor {
Expand All @@ -54,6 +58,10 @@ extension Socket {
static let stream = Int32(SOCK_STREAM)
static let datagram = Int32(SOCK_DGRAM)
static let in_addr_any = WinSDK.in_addr()
static let ipproto_ip = Int32(IPPROTO_IP)
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
static let ip_pktinfo = Int32(IP_PKTINFO)
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)

static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in {
WinSDK.sockaddr_in(
Expand Down Expand Up @@ -193,6 +201,14 @@ extension Socket {
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
}

static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
WinSDK.recvmsg(fd, message, flags)
}

static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
WinSDK.sendmsg(fd, message, flags)
}
}

#endif
Loading

0 comments on commit 6fb4b23

Please sign in to comment.