diff --git a/FlyingSocks/Sources/AsyncSocket.swift b/FlyingSocks/Sources/AsyncSocket.swift index 4afe2209..339a3625 100644 --- a/FlyingSocks/Sources/AsyncSocket.swift +++ b/FlyingSocks/Sources/AsyncSocket.swift @@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool { public struct AsyncSocket: Sendable { + public struct Message: Sendable { + public let peerAddress: sockaddr_storage + public let bytes: [UInt8] + public let interfaceIndex: UInt32? + public let localAddress: sockaddr_storage? + + public init( + peerAddress: sockaddr_storage, + bytes: [UInt8], + interfaceIndex: UInt32? = nil, + localAddress: sockaddr_storage? = nil + ) { + self.peerAddress = peerAddress + self.bytes = bytes + self.interfaceIndex = interfaceIndex + self.localAddress = localAddress + } + } + public let socket: Socket let pool: any AsyncSocketPool @@ -143,6 +162,23 @@ public struct AsyncSocket: Sendable { } while true } +#if !canImport(WinSDK) + public func receive(atMost length: Int) async throws -> Message { + try Task.checkCancellation() + + repeat { + do { + let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length) + return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress) + } catch SocketError.blocked { + try await pool.suspendSocket(socket, untilReadyFor: .read) + } catch { + throw error + } + } while true + } +#endif + /// Reads bytes from the socket up to by not over/ /// - Parameter bytes: The max number of bytes to read /// - Returns: an array of the read bytes capped to the number of bytes provided. @@ -190,6 +226,31 @@ public struct AsyncSocket: Sendable { try await send(Array(data), to: address) } +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + let sent = try await pool.loopUntilReady(for: .write, on: socket) { + try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } + guard sent == message.count else { + throw SocketError.disconnected + } + } + + public func send( + message: Data, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) async throws { + try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress) + } +#endif + public func close() throws { try socket.close() } @@ -275,7 +336,8 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable { public static let DefaultMaxMessageLength: Int = 1500 - public typealias Element = (sockaddr_storage, [UInt8]) + // Windows has a different recvmsg() API signature which is presently unsupported + public typealias Element = AsyncSocket.Message private let socket: AsyncSocket private let maxMessageLength: Int @@ -288,7 +350,15 @@ public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, } public mutating func next() async throws -> Element? { - return try await socket.receive(atMost: maxMessageLength) +#if !canImport(WinSDK) + try await socket.receive(atMost: maxMessageLength) +#else + let peerAddress: sockaddr_storage + let bytes: [UInt8] + + (peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength) + return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes) +#endif } } diff --git a/FlyingSocks/Sources/Socket+Android.swift b/FlyingSocks/Sources/Socket+Android.swift index 4574267d..9dab334b 100644 --- a/FlyingSocks/Sources/Socket+Android.swift +++ b/FlyingSocks/Sources/Socket+Android.swift @@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31; public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = Int32 } extension Socket.FileDescriptor { @@ -47,6 +51,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Android.sockaddr_in { Android.sockaddr_in( @@ -184,6 +192,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Android.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Android.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Darwin.swift b/FlyingSocks/Sources/Socket+Darwin.swift index f4829f4b..d0b26e52 100644 --- a/FlyingSocks/Sources/Socket+Darwin.swift +++ b/FlyingSocks/Sources/Socket+Darwin.swift @@ -34,6 +34,10 @@ import Darwin public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = UInt32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292 static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in { Darwin.sockaddr_in( @@ -185,6 +193,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Darwin.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Darwin.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+Glibc.swift b/FlyingSocks/Sources/Socket+Glibc.swift index cc4aeafc..2cec8ca1 100644 --- a/FlyingSocks/Sources/Socket+Glibc.swift +++ b/FlyingSocks/Sources/Socket+Glibc.swift @@ -34,6 +34,10 @@ import Glibc public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = Int + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM.rawValue) static let datagram = Int32(SOCK_DGRAM.rawValue) static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in { Glibc.sockaddr_in( @@ -181,6 +189,19 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Glibc.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Glibc.sendmsg(fd, message, flags) + } +} + +struct in6_pktinfo { + var ipi6_addr: in6_addr + var ipi6_ifindex: CUnsignedInt } #endif diff --git a/FlyingSocks/Sources/Socket+Musl.swift b/FlyingSocks/Sources/Socket+Musl.swift index 2de01f32..4d823717 100644 --- a/FlyingSocks/Sources/Socket+Musl.swift +++ b/FlyingSocks/Sources/Socket+Musl.swift @@ -34,6 +34,10 @@ import Musl public extension Socket { typealias FileDescriptorType = Int32 + typealias IovLengthType = Int + typealias ControlMessageHeaderLengthType = UInt32 + typealias IPv4InterfaceIndexType = Int32 + typealias IPv6InterfaceIndexType = UInt32 } extension Socket.FileDescriptor { @@ -44,6 +48,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0)) + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in { Musl.sockaddr_in( @@ -181,6 +189,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + Musl.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + Musl.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket+WinSock2.swift b/FlyingSocks/Sources/Socket+WinSock2.swift index 177670d3..d2b87397 100755 --- a/FlyingSocks/Sources/Socket+WinSock2.swift +++ b/FlyingSocks/Sources/Socket+WinSock2.swift @@ -44,6 +44,10 @@ public typealias sa_family_t = UInt8 public extension Socket { typealias FileDescriptorType = UInt64 + typealias IovLengthType = UInt + typealias ControlMessageHeaderLengthType = DWORD + typealias IPv4InterfaceIndexType = ULONG + typealias IPv6InterfaceIndexType = ULONG } extension Socket.FileDescriptor { @@ -54,6 +58,10 @@ extension Socket { static let stream = Int32(SOCK_STREAM) static let datagram = Int32(SOCK_DGRAM) static let in_addr_any = WinSDK.in_addr() + static let ipproto_ip = Int32(IPPROTO_IP) + static let ipproto_ipv6 = Int32(IPPROTO_IPV6) + static let ip_pktinfo = Int32(IP_PKTINFO) + static let ipv6_pktinfo = Int32(IPV6_PKTINFO) static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in { WinSDK.sockaddr_in( @@ -193,6 +201,14 @@ extension Socket { static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer!, _ destlen: socklen_t) -> Int { WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen) } + + static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer, _ flags: Int32) -> Int { + WinSDK.recvmsg(fd, message, flags) + } + + static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer, _ flags: Int32) -> Int { + WinSDK.sendmsg(fd, message, flags) + } } #endif diff --git a/FlyingSocks/Sources/Socket.swift b/FlyingSocks/Sources/Socket.swift index e75b2a23..72e0f785 100644 --- a/FlyingSocks/Sources/Socket.swift +++ b/FlyingSocks/Sources/Socket.swift @@ -79,6 +79,9 @@ public struct Socket: Sendable, Hashable { throw SocketError.makeFailed("CreateSocket") } self.file = descriptor + if type == SocketType.datagram.rawValue { + try setPktInfo(domain: domain) + } } public init(domain: Int32, type: SocketType) throws { @@ -101,6 +104,29 @@ public struct Socket: Sendable, Hashable { } } + // enable return of ip_pktinfo/ipv6_pktinfo on recvmsg() + private func setPktInfo(domain: Int32) throws { + var enable = Int32(1) + let level: Int32 + let name: Int32 + + switch domain { + case AF_INET: + level = Socket.ipproto_ip + name = Self.ip_pktinfo + case AF_INET6: + level = Socket.ipproto_ipv6 + name = Self.ipv6_pktinfo + default: + return + } + + let result = Socket.setsockopt(file.rawValue, level, name, &enable, socklen_t(MemoryLayout.size)) + guard result >= 0 else { + throw SocketError.makeFailed("SetPktInfoOption") + } + } + public func setValue(_ value: O.Value, for option: O) throws { var value = option.makeSocketValue(from: value) let result = withUnsafeBytes(of: &value) { @@ -253,7 +279,7 @@ public struct Socket: Sendable, Hashable { } } guard count > 0 else { - if errno == EWOULDBLOCK || errno == EAGAIN { + if errno == EWOULDBLOCK { throw SocketError.blocked } else if errno == EBADF || count == 0 { throw SocketError.disconnected @@ -264,6 +290,73 @@ public struct Socket: Sendable, Hashable { return (addr, count) } +#if !canImport(WinSDK) + public func receive(length: Int) throws -> (sockaddr_storage, [UInt8], UInt32?, sockaddr_storage?) { + var peerAddress: sockaddr_storage? + var interfaceIndex: UInt32? + var localAddress: sockaddr_storage? + + let bytes = try [UInt8](unsafeUninitializedCapacity: length) { buffer, count in + (peerAddress, count, interfaceIndex, localAddress) = try receive(into: buffer.baseAddress!, length: length, flags: 0) + } + + return (peerAddress!, bytes, interfaceIndex, localAddress) + } + + private static let ControlMsgBufferSize = MemoryLayout.size + max(MemoryLayout.size, MemoryLayout.size) + + private func receive( + into buffer: UnsafeMutablePointer, + length: Int, + flags: Int32 + ) throws -> (sockaddr_storage, Int, UInt32?, sockaddr_storage?) { + var iov = iovec() + var msg = msghdr() + var peerAddress = sockaddr_storage() + var localAddress: sockaddr_storage? + var interfaceIndex: UInt32? + var controlMsgBuffer = [UInt8](repeating: 0, count: Socket.ControlMsgBufferSize) + + iov.iov_base = UnsafeMutableRawPointer(buffer) + iov.iov_len = IovLengthType(length) + + let count = withUnsafeMutablePointer(to: &iov) { iov in + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = socklen_t(MemoryLayout.size) + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return controlMsgBuffer.withUnsafeMutableBytes { controlMsgBuffer in + msg.msg_control = UnsafeMutableRawPointer(controlMsgBuffer.baseAddress) + msg.msg_controllen = ControlMessageHeaderLengthType(controlMsgBuffer.count) + + let count = Socket.recvmsg(file.rawValue, &msg, flags) + + if count > 0, msg.msg_controllen != 0 { + (interfaceIndex, localAddress) = Socket.getPacketInfoControl(msghdr: msg) + } + + return count + } + } + } + + guard count > 0 else { + if errno == EWOULDBLOCK { + throw SocketError.blocked + } else if errno == EBADF || count == 0 { + throw SocketError.disconnected + } else { + throw SocketError.makeFailed("RecvMsg") + } + } + + return (peerAddress, count, interfaceIndex, localAddress) + } +#endif + public func write(_ data: Data, from index: Data.Index = 0) throws -> Data.Index { precondition(index >= 0) guard index < data.endIndex else { return data.endIndex } @@ -310,6 +403,75 @@ public struct Socket: Sendable, Hashable { return sent } +#if !canImport(WinSDK) + public func send( + message: [UInt8], + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + try message.withUnsafeBytes { buffer in + try send( + buffer.baseAddress!, + length: buffer.count, + flags: 0, + to: peerAddress, + interfaceIndex: interfaceIndex, + from: localAddress + ) + } + } + + private func send( + _ pointer: UnsafeRawPointer, + length: Int, + flags: Int32, + to peerAddress: some SocketAddress, + interfaceIndex: UInt32? = nil, + from localAddress: (some SocketAddress)? = nil + ) throws -> Int { + var iov = iovec() + var msg = msghdr() + let family = peerAddress.family + + iov.iov_base = UnsafeMutableRawPointer(mutating: pointer) + iov.iov_len = IovLengthType(length) + + let sent = withUnsafeMutablePointer(to: &iov) { iov in + var peerAddress = peerAddress + + msg.msg_iov = iov + msg.msg_iovlen = 1 + msg.msg_namelen = peerAddress.size + + return withUnsafeMutablePointer(to: &peerAddress) { peerAddress in + msg.msg_name = UnsafeMutableRawPointer(peerAddress) + + return Socket.withPacketInfoControl( + family: family, + interfaceIndex: interfaceIndex, + address: localAddress) { control, controllen in + if let control { + msg.msg_control = UnsafeMutableRawPointer(mutating: control) + msg.msg_controllen = controllen + } + return Socket.sendmsg(file.rawValue, &msg, flags) + } + } + } + + guard sent >= 0 else { + if errno == EWOULDBLOCK { + throw SocketError.blocked + } else { + throw SocketError.makeFailed("SendMsg") + } + } + + return sent + } +#endif + public func close() throws { if Socket.close(file.rawValue) == -1 { throw SocketError.makeFailed("Close") @@ -422,8 +584,8 @@ public extension SocketOption where Self == Int32SocketOption { package extension Socket { - static func makePair(flags: Flags? = nil) throws -> (Socket, Socket) { - let (file1, file2) = Socket.socketpair(AF_UNIX, Socket.stream, 0) + static func makePair(flags: Flags? = nil, type: SocketType = .stream) throws -> (Socket, Socket) { + let (file1, file2) = Socket.socketpair(AF_UNIX, type.rawValue, 0) guard file1 > -1, file2 > -1 else { throw SocketError.makeFailed("SocketPair") } @@ -441,3 +603,130 @@ package extension Socket { try Socket.makePair(flags: .nonBlocking, type: type) } } + +#if !canImport(WinSDK) +fileprivate extension Socket { + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0138-unsaferawbufferpointer.md + private static func withControlMessage( + control: UnsafeRawPointer, + controllen: ControlMessageHeaderLengthType, + _ body: (cmsghdr, UnsafeRawBufferPointer) -> () + ) { + let controlBuffer = UnsafeRawBufferPointer(start: control, count: Int(controllen)) + var cmsgHeaderIndex = 0 + + while true { + let cmsgDataIndex = cmsgHeaderIndex + MemoryLayout.stride + + if cmsgDataIndex > controllen { + break + } + + let header = controlBuffer.load(fromByteOffset: cmsgHeaderIndex, as: cmsghdr.self) + if Int(header.cmsg_len) < MemoryLayout.stride { + break + } + + cmsgHeaderIndex = cmsgDataIndex + cmsgHeaderIndex += Int(header.cmsg_len) - MemoryLayout.stride + if cmsgHeaderIndex > controlBuffer.count { + break + } + body(header, UnsafeRawBufferPointer(rebasing: controlBuffer[cmsgDataIndex...alignment - 1 + cmsgHeaderIndex &= ~(MemoryLayout.alignment - 1) + } + } + + static func getPacketInfoControl( + msghdr: msghdr + ) -> (UInt32?, sockaddr_storage?) { + var interfaceIndex: UInt32? + var localAddress = sockaddr_storage() + + withControlMessage(control: msghdr.msg_control, controllen: msghdr.msg_controllen) { cmsghdr, cmsgdata in + switch cmsghdr.cmsg_level { + case Socket.ipproto_ip: + guard cmsghdr.cmsg_type == Socket.ip_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in_pktinfo.self, capacity: 1) { pktinfo in + var sin = sockaddr_in() + sin.sin_addr = pktinfo.pointee.ipi_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi_ifindex) + localAddress = sin.makeStorage() + } + case Socket.ipproto_ipv6: + guard cmsghdr.cmsg_type == Socket.ipv6_pktinfo else { break } + cmsgdata.baseAddress!.withMemoryRebound(to: in6_pktinfo.self, capacity: 1) { pktinfo in + var sin6 = sockaddr_in6() + sin6.sin6_addr = pktinfo.pointee.ipi6_addr + interfaceIndex = UInt32(pktinfo.pointee.ipi6_ifindex) + localAddress = sin6.makeStorage() + } + default: + break + } + } + + return (interfaceIndex, interfaceIndex != nil ? localAddress : nil) + } + + static func withPacketInfoControl( + family: sa_family_t, + interfaceIndex: UInt32?, + address: (some SocketAddress)?, + _ body: (UnsafePointer?, ControlMessageHeaderLengthType) -> T + ) -> T { + switch Int32(family) { + case AF_INET: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ip + element.pointee.ipi_ifindex = IPv4InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in.self, capacity: 1) { + element.pointee.ipi_addr = $0.pointee.sin_addr + } + } + } else { + element.pointee.ipi_addr.s_addr = 0 + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + case AF_INET6: + let buffer = ManagedBuffer.create(minimumCapacity: 1) { buffer in + buffer.withUnsafeMutablePointers { header, element in + header.pointee.cmsg_len = ControlMessageHeaderLengthType(MemoryLayout.size + MemoryLayout.size) + header.pointee.cmsg_level = SOL_SOCKET + header.pointee.cmsg_type = Socket.ipproto_ipv6 + element.pointee.ipi6_ifindex = IPv6InterfaceIndexType(interfaceIndex ?? 0) + if let address { + var address = address + withUnsafePointer(to: &address) { + $0.withMemoryRebound(to: sockaddr_in6.self, capacity: 1) { + element.pointee.ipi6_addr = $0.pointee.sin6_addr + } + } + } else { + element.pointee.ipi6_addr = in6_addr() + } + + return header.pointee + } + } + + return buffer.withUnsafeMutablePointerToHeader { body($0, ControlMessageHeaderLengthType($0.pointee.cmsg_len)) } + default: + return body(nil, 0) + } + } +} +#endif diff --git a/FlyingSocks/Sources/SocketAddress.swift b/FlyingSocks/Sources/SocketAddress.swift index e43ae962..2dafdf04 100644 --- a/FlyingSocks/Sources/SocketAddress.swift +++ b/FlyingSocks/Sources/SocketAddress.swift @@ -67,6 +67,18 @@ extension SocketAddress { 0 } } + + public func makeStorage() -> sockaddr_storage { + var storage = sockaddr_storage() + + withUnsafeMutablePointer(to: &storage) { + $0.withMemoryRebound(to: Self.self, capacity: 1) { + $0.pointee = self + } + } + + return storage + } } public extension SocketAddress where Self == sockaddr_in { diff --git a/FlyingSocks/Tests/AsyncSocketTests.swift b/FlyingSocks/Tests/AsyncSocketTests.swift index 0d5e0ea7..969e1b18 100644 --- a/FlyingSocks/Tests/AsyncSocketTests.swift +++ b/FlyingSocks/Tests/AsyncSocketTests.swift @@ -207,6 +207,26 @@ struct AsyncSocketTests { try s2.close() try? Socket.unlink(addr) } + +#if !canImport(WinSDK) + @Test + func datagramSocketReceivesMessage_WhenAvailable() async throws { + let (s1, s2, addr) = try await AsyncSocket.makeDatagramPair() + + async let d2: AsyncSocket.Message = s2.receive(atMost: 100) +#if canImport(Darwin) + try await s1.write("Swift".data(using: .utf8)!) +#else + try await s1.send(message: "Swift".data(using: .utf8)!, to: addr, from: addr) +#endif + let v2 = try await d2 + #expect(String(data: Data(v2.bytes), encoding: .utf8) == "Swift") + + try s1.close() + try s2.close() + try? Socket.unlink(addr) + } +#endif } extension AsyncSocket { diff --git a/FlyingSocks/Tests/SocketAddressTests.swift b/FlyingSocks/Tests/SocketAddressTests.swift index 43711be5..e45a1c6a 100644 --- a/FlyingSocks/Tests/SocketAddressTests.swift +++ b/FlyingSocks/Tests/SocketAddressTests.swift @@ -323,24 +323,3 @@ struct SocketAddressTests { } } } - - -private extension SocketAddress { - - func makeStorage() -> sockaddr_storage { - var storage = sockaddr_storage() - var addr = self - let addrSize = MemoryLayout.size - let storageSize = MemoryLayout.size - - withUnsafePointer(to: &addr) { addrPtr in - let addrRawPtr = UnsafeRawPointer(addrPtr) - withUnsafeMutablePointer(to: &storage) { storagePtr in - let storageRawPtr = UnsafeMutableRawPointer(storagePtr) - let copySize = min(addrSize, storageSize) - storageRawPtr.copyMemory(from: addrRawPtr, byteCount: copySize) - } - } - return storage - } -}