From e0b311e806cd4c65515aa6f48d91c83f76508d1f Mon Sep 17 00:00:00 2001 From: David Nadoba Date: Sun, 10 Jul 2022 18:31:47 +0200 Subject: [PATCH] Use `swift-atomics` instead of `NIOAtomics` `NIOAtomics` was deprecated in https://github.com/apple/swift-nio/pull/2204 in favor of `swift-atomics` https://github.com/apple/swift-atomics --- Package.swift | 3 ++ .../HTTPConnectionPool+Manager.swift | 7 +++-- .../HTTPClientTestUtils.swift | 25 ++++++++------- .../HTTPClientTests.swift | 31 ++++++++++--------- .../HTTPConnectionPool+StateTestUtils.swift | 5 +-- 5 files changed, 39 insertions(+), 32 deletions(-) diff --git a/Package.swift b/Package.swift index 20484832d..a9ea2ba7f 100644 --- a/Package.swift +++ b/Package.swift @@ -27,6 +27,7 @@ let package = Package( .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.10.0"), .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.11.4"), .package(url: "https://github.com/apple/swift-log.git", from: "1.4.0"), + .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), ], targets: [ .target(name: "CAsyncHTTPClient"), @@ -46,6 +47,7 @@ let package = Package( .product(name: "NIOSOCKS", package: "swift-nio-extras"), .product(name: "NIOTransportServices", package: "swift-nio-transport-services"), .product(name: "Logging", package: "swift-log"), + .product(name: "Atomics", package: "swift-atomics"), ] ), .testTarget( @@ -61,6 +63,7 @@ let package = Package( .product(name: "NIOHTTP2", package: "swift-nio-http2"), .product(name: "NIOSOCKS", package: "swift-nio-extras"), .product(name: "Logging", package: "swift-log"), + .product(name: "Atomics", package: "swift-atomics"), ], resources: [ .copy("Resources/self_signed_cert.pem"), diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift index 1a1760908..8500c59da 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics import Logging import NIOConcurrencyHelpers import NIOCore @@ -165,14 +166,14 @@ extension HTTPConnectionPool.Connection.ID { static var globalGenerator = Generator() struct Generator { - private let atomic: NIOAtomic + private let atomic: ManagedAtomic init() { - self.atomic = .makeAtomic(value: 0) + self.atomic = .init(0) } func next() -> Int { - return self.atomic.add(1) + return self.atomic.loadThenWrappingIncrement(ordering: .relaxed) } } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index c99facc3f..3e9cb8ccc 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics import AsyncHTTPClient import Foundation import Logging @@ -351,7 +352,7 @@ internal final class HTTPBin where private let mode: Mode private let sslContext: NIOSSLContext? private var serverChannel: Channel! - private let isShutdown: NIOAtomic = .makeAtomic(value: false) + private let isShutdown = ManagedAtomic(false) private let handlerFactory: (Int) -> (RequestHandler) init( @@ -376,7 +377,7 @@ internal final class HTTPBin where self.activeConnCounterHandler = ConnectionsCountHandler() - let connectionIDAtomic = NIOAtomic.makeAtomic(value: 0) + let connectionIDAtomic = ManagedAtomic(0) self.serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) @@ -384,7 +385,7 @@ internal final class HTTPBin where channel.pipeline.addHandler(self.activeConnCounterHandler) }.childChannelInitializer { channel in do { - let connectionID = connectionIDAtomic.add(1) + let connectionID = connectionIDAtomic.loadThenWrappingIncrement(ordering: .relaxed) if case .refuse = mode { throw HTTPBinError.refusedConnection @@ -572,12 +573,12 @@ internal final class HTTPBin where } func shutdown() throws { - self.isShutdown.store(true) + self.isShutdown.store(true, ordering: .relaxed) try self.group.syncShutdownGracefully() } deinit { - assert(self.isShutdown.load(), "HTTPBin not shutdown before deinit") + assert(self.isShutdown.load(ordering: .relaxed), "HTTPBin not shutdown before deinit") } } @@ -946,24 +947,24 @@ internal final class HTTPBinHandler: ChannelInboundHandler { final class ConnectionsCountHandler: ChannelInboundHandler { typealias InboundIn = Channel - private let activeConns = NIOAtomic.makeAtomic(value: 0) - private let createdConns = NIOAtomic.makeAtomic(value: 0) + private let activeConns = ManagedAtomic(0) + private let createdConns = ManagedAtomic(0) var createdConnections: Int { - self.createdConns.load() + self.createdConns.load(ordering: .relaxed) } var currentlyActiveConnections: Int { - self.activeConns.load() + self.activeConns.load(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { let channel = self.unwrapInboundIn(data) - _ = self.activeConns.add(1) - _ = self.createdConns.add(1) + _ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed) + _ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed) channel.closeFuture.whenComplete { _ in - _ = self.activeConns.sub(1) + _ = self.activeConns.loadThenWrappingDecrement(ordering: .relaxed) } context.fireChannelRead(data) diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 02c60d177..29ccf3453 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics /* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift #if canImport(Network) import Network @@ -1790,16 +1791,16 @@ class HTTPClientTests: XCTestCase { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart - let requestNumber: NIOAtomic - let connectionNumber: NIOAtomic + let requestNumber: ManagedAtomic + let connectionNumber: ManagedAtomic - init(requestNumber: NIOAtomic, connectionNumber: NIOAtomic) { + init(requestNumber: ManagedAtomic, connectionNumber: ManagedAtomic) { self.requestNumber = requestNumber self.connectionNumber = connectionNumber } func channelActive(context: ChannelHandlerContext) { - _ = self.connectionNumber.add(1) + _ = self.connectionNumber.loadThenWrappingIncrement(ordering: .relaxed) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -1809,7 +1810,7 @@ class HTTPClientTests: XCTestCase { case .head, .body: () case .end: - let last = self.requestNumber.add(1) + let last = self.requestNumber.loadThenWrappingIncrement(ordering: .relaxed) switch last { case 0, 2: context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), @@ -1824,8 +1825,8 @@ class HTTPClientTests: XCTestCase { } } - let requestNumber = NIOAtomic.makeAtomic(value: 0) - let connectionNumber = NIOAtomic.makeAtomic(value: 0) + let requestNumber = ManagedAtomic(0) + let connectionNumber = ManagedAtomic(0) let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber, connectionNumber: connectionNumber) var maybeServer: Channel? @@ -1854,19 +1855,19 @@ class HTTPClientTests: XCTestCase { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(0, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(0, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(1, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertThrowsError(try client.get(url: url).wait().status) { error in XCTAssertEqual(.remoteConnectionClosed, error as? HTTPClientError) } - XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(1, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(2, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) XCTAssertEqual(.ok, try client.get(url: url).wait().status) - XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load()) - XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load()) + XCTAssertEqual(2, sharedStateServerHandler.connectionNumber.load(ordering: .relaxed)) + XCTAssertEqual(3, sharedStateServerHandler.requestNumber.load(ordering: .relaxed)) } func testPoolClosesIdleConnections() { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift index 0ffdeebd8..5f6208a81 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift @@ -12,6 +12,7 @@ // //===----------------------------------------------------------------------===// +import Atomics @testable import AsyncHTTPClient import Dispatch import NIOConcurrencyHelpers @@ -21,14 +22,14 @@ import NIOEmbedded /// An `EventLoopGroup` of `EmbeddedEventLoop`s. final class EmbeddedEventLoopGroup: EventLoopGroup { private let loops: [EmbeddedEventLoop] - private let index = NIOAtomic.makeAtomic(value: 0) + private let index = ManagedAtomic(0) internal init(loops: Int) { self.loops = (0.. EventLoop { - let index: Int = self.index.add(1) + let index: Int = self.index.loadThenWrappingIncrement(ordering: .relaxed) return self.loops[index % self.loops.count] }