Skip to content

Commit

Permalink
datagram tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swhitty committed Nov 23, 2024
1 parent 5653d61 commit dc1cc41
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 29 deletions.
54 changes: 29 additions & 25 deletions FlyingSocks/Sources/AsyncSocket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
// SOFTWARE.
//

#if canImport(FoundationEssentials)
import FoundationEssentials
#else
import Foundation
#endif

public protocol AsyncSocketPool: Sendable {

Expand Down Expand Up @@ -63,19 +67,35 @@ public extension AsyncSocketPool where Self == SocketPool<Poll> {
public struct AsyncSocket: Sendable {

public struct Message: Sendable {
public let peerAddress: any SocketAddress
public let bytes: [UInt8]
public let interfaceIndex: UInt32?
public let localAddress: (any SocketAddress)?
public var peerAddress: any SocketAddress
public var payload: Data
public var interfaceIndex: UInt32?
public var localAddress: (any SocketAddress)?

public init(
peerAddress: any SocketAddress,
payload: Data,
interfaceIndex: UInt32? = nil,
localAddress: (any SocketAddress)? = nil
) {
self.peerAddress = peerAddress
self.payload = payload
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}

@available(*, deprecated, renamed: "payload")
public var bytes: [UInt8] { Array(payload) }

@available(*, deprecated, renamed: "init(peerAddress:payload:)")
public init(
peerAddress: any SocketAddress,
bytes: [UInt8],
interfaceIndex: UInt32? = nil,
localAddress: (any SocketAddress)? = nil
) {
self.peerAddress = peerAddress
self.bytes = bytes
self.payload = Data(bytes)
self.interfaceIndex = interfaceIndex
self.localAddress = localAddress
}
Expand Down Expand Up @@ -228,11 +248,12 @@ public struct AsyncSocket: Sendable {

#if !canImport(WinSDK)
public func send(
message: [UInt8],
message: some Sequence<UInt8>,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
from localAddress: (any SocketAddress)? = nil
) async throws {
let message = Array(message)
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}
Expand All @@ -241,29 +262,12 @@ public struct AsyncSocket: Sendable {
}
}

public func send(
message: Data,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
) async throws {
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
}

public func send(message: Message) async throws {
let localAddress: AnySocketAddress?

if let unwrappedLocalAddress = message.localAddress {
localAddress = AnySocketAddress(unwrappedLocalAddress)
} else {
localAddress = nil
}

try await send(
message: message.bytes,
to: AnySocketAddress(message.peerAddress),
interfaceIndex: message.interfaceIndex,
from: localAddress
from: message.localAddress.map { AnySocketAddress($0) }
)
}
#endif
Expand Down
6 changes: 3 additions & 3 deletions FlyingSocks/Sources/Socket.swift
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ public struct Socket: Sendable, Hashable {
message: [UInt8],
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
from localAddress: (any SocketAddress)? = nil
) throws -> Int {
try message.withUnsafeBytes { buffer in
try send(
Expand All @@ -417,7 +417,7 @@ public struct Socket: Sendable, Hashable {
flags: Int32,
to peerAddress: some SocketAddress,
interfaceIndex: UInt32? = nil,
from localAddress: (some SocketAddress)? = nil
from localAddress: (any SocketAddress)? = nil
) throws -> Int {
var iov = iovec()
var msg = msghdr()
Expand Down Expand Up @@ -685,7 +685,7 @@ fileprivate extension Socket {
static func withPacketInfoControl<T>(
family: sa_family_t,
interfaceIndex: UInt32?,
address: (some SocketAddress)?,
address: (any SocketAddress)?,
_ body: (UnsafePointer<cmsghdr>?, ControlMessageHeaderLengthType) -> T
) -> T {
switch Int32(family) {
Expand Down
25 changes: 24 additions & 1 deletion FlyingSocks/Tests/AsyncSocketTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ struct AsyncSocketTests {
#endif

@Test
func messageSequence_receives_messages() async throws {
func messageSequence_sendsData_receivesMessage() async throws {
let (socket, port) = try await AsyncSocket.makeLoopbackDatagram()
var messages = socket.messages

Expand All @@ -267,6 +267,22 @@ struct AsyncSocketTests {
try await received?.payloadString == "Fish 🐡"
)
}

@Test
func messageSequence_sendsMessage_receivesMessage() async throws {
let (socket, port) = try await AsyncSocket.makeLoopbackDatagram()
var messages = socket.messages

async let received = messages.next()

let client = try await AsyncSocket.makeLoopbackDatagram().0
let message = AsyncSocket.Message(peerAddress: .loopback(port: port), payload: "Chips 🍟")
try await client.send(message: message)

#expect(
try await received?.payloadString == "Chips 🍟"
)
}
}

extension AsyncSocket {
Expand Down Expand Up @@ -347,6 +363,13 @@ private extension AsyncSocket.Message {
return text
}
}

init(peerAddress: some SocketAddress, payload: String) {
self.init(
peerAddress: peerAddress,
bytes: Array(payload.data(using: .utf8)!)
)
}
}

struct DisconnectedPool: AsyncSocketPool {
Expand Down

0 comments on commit dc1cc41

Please sign in to comment.