diff --git a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift index d7ba2b552..fb3bf311a 100644 --- a/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift +++ b/Sources/GRPC/AsyncAwaitSupport/GRPCAsyncServerHandler.swift @@ -283,6 +283,7 @@ internal final class AsyncServerHandler< callType: callType, remoteAddress: context.remoteAddress, userInfoRef: self.userInfoRef, + closeFuture: context.closeFuture, interceptors: interceptors, onRequestPart: self.receiveInterceptedPart(_:), onResponsePart: self.sendInterceptedPart(_:promise:) diff --git a/Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift b/Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift index 11081d7da..9fd5a434f 100644 --- a/Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift +++ b/Sources/GRPC/CallHandlers/BidirectionalStreamingServerHandler.swift @@ -92,6 +92,7 @@ public final class BidirectionalStreamingServerHandler< callType: .bidirectionalStreaming, remoteAddress: context.remoteAddress, userInfoRef: userInfoRef, + closeFuture: context.closeFuture, interceptors: interceptors, onRequestPart: self.receiveInterceptedPart(_:), onResponsePart: self.sendInterceptedPart(_:promise:) diff --git a/Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift b/Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift index f82c8e136..e7929aa2f 100644 --- a/Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift +++ b/Sources/GRPC/CallHandlers/ClientStreamingServerHandler.swift @@ -93,6 +93,7 @@ public final class ClientStreamingServerHandler< callType: .clientStreaming, remoteAddress: context.remoteAddress, userInfoRef: userInfoRef, + closeFuture: context.closeFuture, interceptors: interceptors, onRequestPart: self.receiveInterceptedPart(_:), onResponsePart: self.sendInterceptedPart(_:promise:) diff --git a/Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift b/Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift index 8526e6388..6cb3b4bbf 100644 --- a/Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift +++ b/Sources/GRPC/CallHandlers/ServerStreamingServerHandler.swift @@ -89,6 +89,7 @@ public final class ServerStreamingServerHandler< callType: .serverStreaming, remoteAddress: context.remoteAddress, userInfoRef: userInfoRef, + closeFuture: context.closeFuture, interceptors: interceptors, onRequestPart: self.receiveInterceptedPart(_:), onResponsePart: self.sendInterceptedPart(_:promise:) diff --git a/Sources/GRPC/CallHandlers/UnaryServerHandler.swift b/Sources/GRPC/CallHandlers/UnaryServerHandler.swift index 6da422eec..f90a2545f 100644 --- a/Sources/GRPC/CallHandlers/UnaryServerHandler.swift +++ b/Sources/GRPC/CallHandlers/UnaryServerHandler.swift @@ -87,6 +87,7 @@ public final class UnaryServerHandler< callType: .unary, remoteAddress: context.remoteAddress, userInfoRef: userInfoRef, + closeFuture: context.closeFuture, interceptors: interceptors, onRequestPart: self.receiveInterceptedPart(_:), onResponsePart: self.sendInterceptedPart(_:promise:) diff --git a/Sources/GRPC/Interceptor/ServerInterceptorContext.swift b/Sources/GRPC/Interceptor/ServerInterceptorContext.swift index 632bee450..5543b417f 100644 --- a/Sources/GRPC/Interceptor/ServerInterceptorContext.swift +++ b/Sources/GRPC/Interceptor/ServerInterceptorContext.swift @@ -54,6 +54,12 @@ public struct ServerInterceptorContext { return self._pipeline.remoteAddress } + /// A future which completes when the call closes. This may be used to register callbacks which + /// free up resources used by the interceptor. + public var closeFuture: EventLoopFuture { + return self._pipeline.closeFuture + } + /// A 'UserInfo' dictionary. /// /// - Important: While `UserInfo` has value-semantics, this property retrieves from, and sets a diff --git a/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift b/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift index 82acb3bf5..0000c3fc7 100644 --- a/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift +++ b/Sources/GRPC/Interceptor/ServerInterceptorPipeline.swift @@ -42,6 +42,11 @@ internal final class ServerInterceptorPipeline { @usableFromInline internal let userInfoRef: Ref + /// A future which completes when the call closes. This may be used to register callbacks which + /// free up resources used by the interceptor. + @usableFromInline + internal let closeFuture: EventLoopFuture + /// Called when a response part has traversed the interceptor pipeline. @usableFromInline internal let _onResponsePart: (GRPCServerResponsePart, EventLoopPromise?) -> Void @@ -99,6 +104,7 @@ internal final class ServerInterceptorPipeline { callType: GRPCCallType, remoteAddress: SocketAddress?, userInfoRef: Ref, + closeFuture: EventLoopFuture, interceptors: [ServerInterceptor], onRequestPart: @escaping (GRPCServerRequestPart) -> Void, onResponsePart: @escaping (GRPCServerResponsePart, EventLoopPromise?) -> Void @@ -109,6 +115,7 @@ internal final class ServerInterceptorPipeline { self.type = callType self.remoteAddress = remoteAddress self.userInfoRef = userInfoRef + self.closeFuture = closeFuture self._onResponsePart = onResponsePart self._onRequestPart = onRequestPart diff --git a/Tests/GRPCTests/InterceptorsTests.swift b/Tests/GRPCTests/InterceptorsTests.swift index d1241e2f6..1d8176abb 100644 --- a/Tests/GRPCTests/InterceptorsTests.swift +++ b/Tests/GRPCTests/InterceptorsTests.swift @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +import Atomics import EchoImplementation import EchoModel import GRPC @@ -28,6 +29,7 @@ class InterceptorsTests: GRPCTestCase { private var server: Server! private var connection: ClientConnection! private var echo: Echo_EchoNIOClient! + private let onCloseCounter = ManagedAtomic(0) override func setUp() { super.setUp() @@ -35,7 +37,7 @@ class InterceptorsTests: GRPCTestCase { self.server = try! Server.insecure(group: self.group) .withServiceProviders([ - EchoProvider(), + EchoProvider(interceptors: CountOnCloseInterceptors(counter: self.onCloseCounter)), HelloWorldProvider(interceptors: HelloWorldServerInterceptorFactory()), ]) .withLogger(self.serverLogger) @@ -64,6 +66,8 @@ class InterceptorsTests: GRPCTestCase { let get = self.echo.get(.with { $0.text = "hello" }) assertThat(try get.response.wait(), .is(.with { $0.text = "hello :teg ohce tfiwS" })) assertThat(try get.status.wait(), .hasCode(.ok)) + + XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1) } func testCollect() { @@ -73,6 +77,8 @@ class InterceptorsTests: GRPCTestCase { collect.sendEnd(promise: nil) assertThat(try collect.response.wait(), .is(.with { $0.text = "3 4 1 2 :tcelloc ohce tfiwS" })) assertThat(try collect.status.wait(), .hasCode(.ok)) + + XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1) } func testExpand() { @@ -81,6 +87,8 @@ class InterceptorsTests: GRPCTestCase { assertThat(response, .is(.with { $0.text = "hello :)0( dnapxe ohce tfiwS" })) } assertThat(try expand.status.wait(), .hasCode(.ok)) + + XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1) } func testUpdate() { @@ -91,6 +99,8 @@ class InterceptorsTests: GRPCTestCase { update.sendMessage(.with { $0.text = "hello" }, promise: nil) update.sendEnd(promise: nil) assertThat(try update.status.wait(), .hasCode(.ok)) + + XCTAssertEqual(self.onCloseCounter.load(ordering: .sequentiallyConsistent), 1) } func testSayHello() { @@ -360,6 +370,54 @@ final class ReversingInterceptors: Echo_EchoClientInterceptorFactoryProtocol { } } +final class CountOnCloseInterceptors: Echo_EchoServerInterceptorFactoryProtocol { + // This interceptor is stateless, let's just share it. + private let interceptors: [ServerInterceptor] + + init(counter: ManagedAtomic) { + self.interceptors = [CountOnCloseServerInterceptor(counter: counter)] + } + + func makeGetInterceptors() -> [ServerInterceptor] { + return self.interceptors + } + + func makeExpandInterceptors() -> [ServerInterceptor] { + return self.interceptors + } + + func makeCollectInterceptors() -> [ServerInterceptor] { + return self.interceptors + } + + func makeUpdateInterceptors() -> [ServerInterceptor] { + return self.interceptors + } +} + +final class CountOnCloseServerInterceptor: ServerInterceptor { + private let counter: ManagedAtomic + + init(counter: ManagedAtomic) { + self.counter = counter + } + + override func receive( + _ part: GRPCServerRequestPart, + context: ServerInterceptorContext + ) { + switch part { + case .metadata: + context.closeFuture.whenComplete { _ in + self.counter.wrappingIncrement(ordering: .sequentiallyConsistent) + } + default: + () + } + context.receive(part) + } +} + private enum MagicKey: UserInfo.Key { typealias Value = String } diff --git a/Tests/GRPCTests/ServerInterceptorPipelineTests.swift b/Tests/GRPCTests/ServerInterceptorPipelineTests.swift index 0e0b75f4a..9143e0339 100644 --- a/Tests/GRPCTests/ServerInterceptorPipelineTests.swift +++ b/Tests/GRPCTests/ServerInterceptorPipelineTests.swift @@ -43,6 +43,7 @@ class ServerInterceptorPipelineTests: GRPCTestCase { callType: callType, remoteAddress: nil, userInfoRef: Ref(UserInfo()), + closeFuture: self.embeddedEventLoop.makeSucceededVoidFuture(), interceptors: interceptors, onRequestPart: onRequestPart, onResponsePart: onResponsePart