From 5835ceae2bb36772ac9120b81aba1c171f702e3e Mon Sep 17 00:00:00 2001 From: Luke Howard Date: Mon, 11 Nov 2024 12:20:14 +1100 Subject: [PATCH] Support for datagram (UDP, message) sockets Fixes: #128 --- FlyingSocks/Sources/AsyncSocket.swift | 47 ++++++++++++++++++++ FlyingSocks/Sources/Socket+Android.swift | 9 ++++ FlyingSocks/Sources/Socket+Darwin.swift | 9 ++++ FlyingSocks/Sources/Socket+Glibc.swift | 9 ++++ FlyingSocks/Sources/Socket+Musl.swift | 9 ++++ FlyingSocks/Sources/Socket+WinSock2.swift | 9 ++++ FlyingSocks/Sources/Socket.swift | 52 +++++++++++++++++++++++ 7 files changed, 144 insertions(+) diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 045f070b..3e6a86d9 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -129,6 +129,20 @@ public struct AsyncSocket: Sendable { return buffer } + public func receive(atMost length: Int = 4096) async throws -> (sockaddr_storage, [UInt8]) { + try Task.checkCancellation() + + repeat { + do { + return try socket.receive(length: length) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } + /// Reads bytes from the socket up to by not over/ /// - Parameter bytes: The max number of bytes to read /// - Returns: an array of the read bytes capped to the number of bytes provided. @@ -163,6 +177,19 @@ public struct AsyncSocket: Sendable { } } + public func send(_ data: [UInt8], to address: some SocketAddress) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(data, to: address) + } + guard sent == data.count else { + throw SocketError.disconnected + } + } + + public func send(_ data: Data, to address: some SocketAddress) async throws { + try await send(Array(data), to: address) + } + public func close() throws { try socket.close() } @@ -174,6 +201,10 @@ public struct AsyncSocket: Sendable { public var sockets: AsyncSocketSequence { AsyncSocketSequence(socket: self) } + + public var messages: AsyncSocketMessageSequence { + AsyncSocketMessageSequence(socket: self) + } } package extension AsyncSocket { @@ -237,6 +268,22 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl } } +public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable { + public typealias Element = (sockaddr_storage, [UInt8]) + + let socket: AsyncSocket + + public func makeAsyncIterator() -> AsyncSocketMessageSequence { self } + + public mutating func next() async throws -> Element? { + return try await socket.receive() + } + + public func nextBuffer(suggested count: Int) async throws -> Element? { + try await socket.receive(atMost: count) + } +} + private actor ClientPoolLoader { static let shared = ClientPoolLoader() diff --git a/FlyingSocks/Sources/Socket+Android.swift b/FlyingSocks/Sources/Socket+Android.swift index fb154a8e..4574267d 100644 --- a/FlyingSocks/Sources/Socket+Android.swift +++ b/FlyingSocks/Sources/Socket+Android.swift @@ -45,6 +45,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Android.sockaddr_in { @@ -175,6 +176,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Android.pollfd { Android.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Android.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Darwin.swift b/FlyingSocks/Sources/Socket+Darwin.swift index efa3a643..f4829f4b 100644 --- a/FlyingSocks/Sources/Socket+Darwin.swift +++ b/FlyingSocks/Sources/Socket+Darwin.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in { @@ -176,6 +177,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Darwin.pollfd { Darwin.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Darwin.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Glibc.swift b/FlyingSocks/Sources/Socket+Glibc.swift index 9880ef8c..cc4aeafc 100644 --- a/FlyingSocks/Sources/Socket+Glibc.swift +++ b/FlyingSocks/Sources/Socket+Glibc.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM.rawValue) + static let datagram = Int32(SOCK_DGRAM.rawValue) static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in { @@ -172,6 +173,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Glibc.pollfd { Glibc.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Glibc.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Musl.swift b/FlyingSocks/Sources/Socket+Musl.swift index 5f285fd4..2de01f32 100644 --- a/FlyingSocks/Sources/Socket+Musl.swift +++ b/FlyingSocks/Sources/Socket+Musl.swift @@ -42,6 +42,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0)) static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in { @@ -172,6 +173,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> Musl.pollfd { Musl.pollfd(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + Musl.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket+WinSock2.swift b/FlyingSocks/Sources/Socket+WinSock2.swift index 0c2c086f..177670d3 100755 --- a/FlyingSocks/Sources/Socket+WinSock2.swift +++ b/FlyingSocks/Sources/Socket+WinSock2.swift @@ -52,6 +52,7 @@ extension Socket.FileDescriptor { extension Socket { static let stream = Int32(SOCK_STREAM) + static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = WinSDK.in_addr() static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in { @@ -184,6 +185,14 @@ extension Socket { static func pollfd(fd: FileDescriptorType, events: Int16, revents: Int16) -> WinSDK.WSAPOLLFD { WinSDK.WSAPOLLFD(fd: fd, events: events, revents: revents) } + + static func recvfrom(_ fd: FileDescriptorType, _ buffer: UnsafeMutableRawPointer!, _ nbyte: Int, _ flags: Int32, _ addr: UnsafeMutablePointer!, _ len: UnsafeMutablePointer!) -> Int { + WinSDK.recvfrom(fd, buffer, nbyte, flags, addr, len) + } + + static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { + WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen) + } } #endif diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index 7de00526..cf029c0d 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -214,6 +214,35 @@ public struct Socket: Sendable, Hashable { return count } + public func receive(length: Int) throws -> (sockaddr_storage, [UInt8]) { + var address: sockaddr_storage? + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (address, count) = try receive(into: buffer.baseAddress!, length: length) + } + + return (address!, bytes) + } + + private func receive(into buffer: UnsafeMutablePointer, length: Int) throws -> (sockaddr_storage, Int) { + var addr = sockaddr_storage() + var size = socklen_t(MemoryLayout.size) + let count = withUnsafeMutablePointer(to: &addr) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + Socket.recvfrom(file.rawValue, buffer, length, 0, $0, &size) + } + } + guard count > 0 else { + if errno == EWOULDBLOCK { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvFrom") + } + } + return (addr, count) + } + public func write(_ data: Data, from index: Data.Index = 0) throws -> Data.Index { precondition(index >= 0) guard index < data.endIndex else { return data.endIndex } @@ -237,6 +266,29 @@ public struct Socket: Sendable, Hashable { return sent } + public func send(_ bytes: [UInt8], to address: some SocketAddress) throws -> Int { + try bytes.withUnsafeBytes { buffer in + try send(buffer.baseAddress!, length: bytes.count, to: address) + } + } + + private func send(_ pointer: UnsafeRawPointer, length: Int, to address: A) throws -> Int { + var addr = address + let sent = withUnsafePointer(to: &addr) { + $0.withMemoryRebound(to: sockaddr.self, capacity: 1) { + Socket.sendto(file.rawValue, pointer, length, 0, $0, socklen_t(MemoryLayout.size)) + } + } + guard sent >= 0 || errno == EISCONN else { + if errno == EINPROGRESS { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendTo") + } + } + return sent + } + public func close() throws { if Socket.close(file.rawValue) == -1 { throw SocketError.makeFailed("Close")