From 4eae52837ee9fe47b91c718f3a2a37beb9b3f119 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 8 Oct 2024 14:18:11 +0100 Subject: [PATCH 1/2] Add support for RPC cancellation Motivation: grpc-swift has support for RPC cancellation via a cancellation handler. It should be supported here as well. Modifications: - Modify the server stream handler so that it holds the cancellation handle and triggers cancellation appropriately. - Modify the server stream handling to set a cancellation handle on the appropriate stream handler. - Update tests Result: HTTP/2 server transport respects stream cancellation --- Package.swift | 4 +- .../Internal/NIOChannelPipeline+GRPC.swift | 3 +- .../Server/CommonHTTP2ServerTransport.swift | 36 ++- .../Server/GRPCServerStreamHandler.swift | 49 ++++ .../Connection/Utilities/ConnectionTest.swift | 3 +- .../Connection/Utilities/TestServer.swift | 3 +- .../Server/GRPCServerStreamHandlerTests.swift | 242 +++++++++--------- .../ControlClient.swift | 19 +- .../ControlMessages.swift | 8 + .../ControlService.swift | 35 +++ .../HTTP2TransportTests.swift | 102 +++++--- 11 files changed, 327 insertions(+), 177 deletions(-) diff --git a/Package.swift b/Package.swift index f5817fb..145e2c0 100644 --- a/Package.swift +++ b/Package.swift @@ -35,7 +35,7 @@ let products: [Product] = [ let dependencies: [Package.Dependency] = [ .package( url: "https://github.com/grpc/grpc-swift.git", - exact: "2.0.0-alpha.1" + branch: "main" ), .package( url: "https://github.com/apple/swift-nio.git", @@ -43,7 +43,7 @@ let dependencies: [Package.Dependency] = [ ), .package( url: "https://github.com/apple/swift-nio-http2.git", - from: "1.32.0" + from: "1.34.1" ), .package( url: "https://github.com/apple/swift-nio-transport-services.git", diff --git a/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift b/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift index ba993be..1254147 100644 --- a/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift +++ b/Sources/GRPCNIOTransportCore/Internal/NIOChannelPipeline+GRPC.swift @@ -85,7 +85,8 @@ extension ChannelPipeline.SynchronousOperations { scheme: scheme, acceptedEncodings: compressionConfig.enabledAlgorithms, maxPayloadSize: rpcConfig.maxRequestPayloadSize, - methodDescriptorPromise: methodDescriptorPromise + methodDescriptorPromise: methodDescriptorPromise, + eventLoop: streamChannel.eventLoop ) try streamChannel.pipeline.syncOperations.addHandler(streamHandler) diff --git a/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift b/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift index 762c3e5..b9e9a51 100644 --- a/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift +++ b/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift @@ -241,19 +241,35 @@ package final class CommonHTTP2ServerTransport< return } - let rpcStream = RPCStream( - descriptor: descriptor, - inbound: RPCAsyncSequence(wrapping: inbound), - outbound: RPCWriter.Closable( - wrapping: ServerConnection.Stream.Outbound( - responseWriter: outbound, - http2Stream: stream + await withServerContextRPCCancellationHandle { handle in + stream.channel.eventLoop.execute { + // Sync is safe: this is on the right event loop. + let sync = stream.channel.pipeline.syncOperations + + // Looking up the handler can fail if the channel is already closed, in which case + // cancel the handle directly. + do { + let handler = try sync.handler(type: GRPCServerStreamHandler.self) + handler.setCancellationHandle(handle) + } catch { + handle.cancel() + } + } + + let rpcStream = RPCStream( + descriptor: descriptor, + inbound: RPCAsyncSequence(wrapping: inbound), + outbound: RPCWriter.Closable( + wrapping: ServerConnection.Stream.Outbound( + responseWriter: outbound, + http2Stream: stream + ) ) ) - ) - let context = ServerContext(descriptor: descriptor) - await streamHandler(rpcStream, context) + let context = ServerContext(descriptor: descriptor, cancellation: handle) + await streamHandler(rpcStream, context) + } } } diff --git a/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift b/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift index 78d6144..22c7eef 100644 --- a/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift +++ b/Sources/GRPCNIOTransportCore/Server/GRPCServerStreamHandler.swift @@ -26,9 +26,11 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan package typealias OutboundOut = HTTP2Frame.FramePayload private var stateMachine: GRPCStreamStateMachine + private let eventLoop: any EventLoop private var isReading = false private var flushPending = false + private var isCancelled = false // We buffer the final status + trailers to avoid reordering issues (i.e., // if there are messages still not written into the channel because flush has @@ -38,6 +40,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan private let methodDescriptorPromise: EventLoopPromise + private var cancellationHandle: Optional + // Existential errors unconditionally allocate, avoid this per-use allocation by doing it // statically. private static let handlerRemovedBeforeDescriptorResolved: any Error = RPCError( @@ -50,6 +54,8 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan acceptedEncodings: CompressionAlgorithmSet, maxPayloadSize: Int, methodDescriptorPromise: EventLoopPromise, + eventLoop: any EventLoop, + cancellationHandler: ServerContext.RPCCancellationHandle? = nil, skipStateMachineAssertions: Bool = false ) { self.stateMachine = .init( @@ -58,12 +64,54 @@ package final class GRPCServerStreamHandler: ChannelDuplexHandler, RemovableChan skipAssertions: skipStateMachineAssertions ) self.methodDescriptorPromise = methodDescriptorPromise + self.cancellationHandle = cancellationHandler + self.eventLoop = eventLoop + } + + package func setCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) { + if self.eventLoop.inEventLoop { + self.syncSetCancellationHandle(handle) + } else { + let loopBoundSelf = NIOLoopBound(self, eventLoop: self.eventLoop) + self.eventLoop.execute { + loopBoundSelf.value.syncSetCancellationHandle(handle) + } + } + } + + private func syncSetCancellationHandle(_ handle: ServerContext.RPCCancellationHandle) { + assert(self.cancellationHandle == nil, "\(#function) must only be called once") + + if self.isCancelled { + handle.cancel() + } else { + self.cancellationHandle = handle + } + } + + private func cancelRPC() { + if let handle = self.cancellationHandle.take() { + handle.cancel() + } else { + self.isCancelled = true + } } } // - MARK: ChannelInboundHandler extension GRPCServerStreamHandler { + package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch event { + case is ChannelShouldQuiesceEvent: + self.cancelRPC() + default: + () + } + + context.fireUserInboundEventTriggered(event) + } + package func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.isReading = true let frame = self.unwrapInboundIn(data) @@ -186,6 +234,7 @@ extension GRPCServerStreamHandler { ) { switch self.stateMachine.unexpectedInboundClose(reason: reason) { case .fireError_serverOnly(let wrappedError): + self.cancelRPC() context.fireErrorCaught(wrappedError) case .doNothing: () diff --git a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift index 1cf707a..6653f06 100644 --- a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift +++ b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/ConnectionTest.swift @@ -117,7 +117,8 @@ extension ConnectionTest { scheme: .http, acceptedEncodings: .none, maxPayloadSize: .max, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) + methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), + eventLoop: stream.eventLoop ) return stream.eventLoop.makeCompletedFuture { diff --git a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift index f5b04af..01ec6ad 100644 --- a/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift +++ b/Tests/GRPCNIOTransportCoreTests/Client/Connection/Utilities/TestServer.swift @@ -74,7 +74,8 @@ final class TestServer: Sendable { scheme: .http, acceptedEncodings: .all, maxPayloadSize: .max, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) + methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), + eventLoop: stream.eventLoop ) try stream.pipeline.syncOperations.addHandlers(handler) diff --git a/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift b/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift index 2b805d3..f6828d7 100644 --- a/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift +++ b/Tests/GRPCNIOTransportCoreTests/Server/GRPCServerStreamHandlerTests.swift @@ -19,19 +19,33 @@ import NIOCore import NIOEmbedded import NIOHPACK import NIOHTTP2 +import Testing import XCTest @testable import GRPCNIOTransportCore final class GRPCServerStreamHandlerTests: XCTestCase { + private func makeServerStreamHandler( + channel: any Channel, + scheme: Scheme = .http, + acceptedEncodings: CompressionAlgorithmSet = [], + maxPayloadSize: Int = .max, + descriptorPromise: EventLoopPromise? = nil, + disableAssertions: Bool = false + ) -> GRPCServerStreamHandler { + return GRPCServerStreamHandler( + scheme: scheme, + acceptedEncodings: acceptedEncodings, + maxPayloadSize: maxPayloadSize, + methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(), + eventLoop: channel.eventLoop, + skipStateMachineAssertions: disableAssertions + ) + } + func testH2FramesAreIgnored() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) let framesToBeIgnored: [HTTP2Frame.FramePayload] = [ @@ -56,12 +70,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testClientInitialMetadataWithoutContentTypeResultsInRejectedRPC() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata without content-type @@ -86,12 +95,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testClientInitialMetadataWithoutMethodResultsInRejectedRPC() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata without :method @@ -125,12 +129,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testClientInitialMetadataWithoutSchemeResultsInRejectedRPC() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata without :scheme @@ -164,12 +163,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testClientInitialMetadataWithoutPathResultsInRejectedRPC() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata without :path @@ -202,12 +196,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testNotAcceptedEncodingResultsInRejectedRPC() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -244,12 +233,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testOverMaximumPayloadSize() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel, maxPayloadSize: 1) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -313,13 +297,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testClientEndsStream() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 1, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), - skipStateMachineAssertions: true - ) + let handler = self.makeServerStreamHandler(channel: channel, disableAssertions: true) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata with end stream set @@ -379,13 +357,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testNormalFlow() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 42, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self), - skipStateMachineAssertions: true - ) + let handler = self.makeServerStreamHandler(channel: channel, disableAssertions: true) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -489,12 +461,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testReceiveMessageSplitAcrossMultipleBuffers() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -585,12 +552,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testReceiveMultipleHeaders() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata let clientInitialMetadata: HPACKHeaders = [ @@ -625,12 +587,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testSendMultipleMessagesInSingleBuffer() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -703,12 +660,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testMessageAndStatusAreNotReordered() throws { let channel = EmbeddedChannel() - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: channel.eventLoop.makePromise(of: MethodDescriptor.self) - ) + let handler = self.makeServerStreamHandler(channel: channel) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -785,13 +737,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testMethodDescriptorPromiseSucceeds() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true - ) + let handler = self.makeServerStreamHandler(channel: channel, descriptorPromise: promise) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -824,15 +770,8 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testMethodDescriptorPromiseIsFailedWhenHandlerRemoved() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true - ) + let handler = self.makeServerStreamHandler(channel: channel, descriptorPromise: promise) try channel.pipeline.syncOperations.addHandler(handler) - try channel.pipeline.syncOperations.removeHandler(handler).wait() XCTAssertThrowsError( @@ -847,13 +786,7 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testMethodDescriptorPromiseIsFailedIfRPCRejected() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true - ) + let handler = self.makeServerStreamHandler(channel: channel, descriptorPromise: promise) try channel.pipeline.syncOperations.addHandler(handler) // Receive client's initial metadata @@ -882,12 +815,10 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testUnexpectedStreamClose_ErrorFired() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true + let handler = self.makeServerStreamHandler( + channel: channel, + descriptorPromise: promise, + disableAssertions: true ) try channel.pipeline.syncOperations.addHandler(handler) @@ -937,12 +868,10 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testUnexpectedStreamClose_ChannelInactive() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true + let handler = self.makeServerStreamHandler( + channel: channel, + descriptorPromise: promise, + disableAssertions: true ) try channel.pipeline.syncOperations.addHandler(handler) @@ -992,12 +921,10 @@ final class GRPCServerStreamHandlerTests: XCTestCase { func testUnexpectedStreamClose_ResetStreamFrame() throws { let channel = EmbeddedChannel() let promise = channel.eventLoop.makePromise(of: MethodDescriptor.self) - let handler = GRPCServerStreamHandler( - scheme: .http, - acceptedEncodings: [], - maxPayloadSize: 100, - methodDescriptorPromise: promise, - skipStateMachineAssertions: true + let handler = self.makeServerStreamHandler( + channel: channel, + descriptorPromise: promise, + disableAssertions: true ) try channel.pipeline.syncOperations.addHandler(handler) @@ -1043,6 +970,81 @@ final class GRPCServerStreamHandlerTests: XCTestCase { XCTAssertEqual(error.message, "Invalid state") } } + +} + +struct ServerStreamHandlerTests { + private func makeServerStreamHandler( + channel: any Channel, + scheme: Scheme = .http, + acceptedEncodings: CompressionAlgorithmSet = [], + maxPayloadSize: Int = .max, + descriptorPromise: EventLoopPromise? = nil, + disableAssertions: Bool = false + ) -> GRPCServerStreamHandler { + return GRPCServerStreamHandler( + scheme: scheme, + acceptedEncodings: acceptedEncodings, + maxPayloadSize: maxPayloadSize, + methodDescriptorPromise: descriptorPromise ?? channel.eventLoop.makePromise(), + eventLoop: channel.eventLoop, + skipStateMachineAssertions: disableAssertions + ) + } + + @Test("ChannelShouldQuiesceEvent is buffered and turns into RPC cancellation") + func shouldQuiesceEventIsBufferedBeforeHandleIsSet() async throws { + let channel = EmbeddedChannel() + let handler = self.makeServerStreamHandler(channel: channel) + try channel.pipeline.syncOperations.addHandler(handler) + channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + + await withServerContextRPCCancellationHandle { handle in + handler.setCancellationHandle(handle) + #expect(handle.isCancelled) + } + + // Throwing is fine: the channel is closed abruptly, errors are expected. + _ = try? channel.finish() + } + + @Test("ChannelShouldQuiesceEvent turns into RPC cancellation") + func shouldQuiesceEventTriggersCancellation() async throws { + let channel = EmbeddedChannel() + let handler = self.makeServerStreamHandler(channel: channel) + try channel.pipeline.syncOperations.addHandler(handler) + + await withServerContextRPCCancellationHandle { handle in + handler.setCancellationHandle(handle) + #expect(!handle.isCancelled) + channel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent()) + #expect(handle.isCancelled) + } + + // Throwing is fine: the channel is closed abruptly, errors are expected. + _ = try? channel.finish() + } + + @Test("RST_STREAM turns into RPC cancellation") + func rstStreamTriggersCancellation() async throws { + let channel = EmbeddedChannel() + let handler = self.makeServerStreamHandler(channel: channel) + try channel.pipeline.syncOperations.addHandler(handler) + + await withServerContextRPCCancellationHandle { handle in + handler.setCancellationHandle(handle) + #expect(!handle.isCancelled) + + let rstStream: HTTP2Frame.FramePayload = .rstStream(.cancel) + channel.pipeline.fireChannelRead(NIOAny(rstStream)) + + #expect(handle.isCancelled) + } + + // Throwing is fine: the channel is closed abruptly, errors are expected. + _ = try? channel.finish() + } + } extension EmbeddedChannel { diff --git a/Tests/GRPCNIOTransportHTTP2Tests/ControlClient.swift b/Tests/GRPCNIOTransportHTTP2Tests/ControlClient.swift index 9516b8e..39748d4 100644 --- a/Tests/GRPCNIOTransportHTTP2Tests/ControlClient.swift +++ b/Tests/GRPCNIOTransportHTTP2Tests/ControlClient.swift @@ -17,7 +17,7 @@ import GRPCCore internal struct ControlClient { - private let client: GRPCCore.GRPCClient + internal let client: GRPCCore.GRPCClient internal init(wrapping client: GRPCCore.GRPCClient) { self.client = client @@ -88,4 +88,21 @@ internal struct ControlClient { handler: body ) } + + internal func waitForCancellation( + request: GRPCCore.ClientRequest, + options: GRPCCore.CallOptions = .defaults, + _ body: @Sendable @escaping ( + _ response: GRPCCore.StreamingClientResponse + ) async throws -> R + ) async throws -> R where R: Sendable { + try await self.client.serverStreaming( + request: request, + descriptor: MethodDescriptor(service: "Control", method: "WaitForCancellation"), + serializer: JSONSerializer(), + deserializer: JSONDeserializer(), + options: options, + handler: body + ) + } } diff --git a/Tests/GRPCNIOTransportHTTP2Tests/ControlMessages.swift b/Tests/GRPCNIOTransportHTTP2Tests/ControlMessages.swift index a935f82..cb5cfbd 100644 --- a/Tests/GRPCNIOTransportHTTP2Tests/ControlMessages.swift +++ b/Tests/GRPCNIOTransportHTTP2Tests/ControlMessages.swift @@ -120,6 +120,14 @@ struct ControlOutput: Codable { var payload: Data } +enum CancellationKind: Codable { + case awaitCancelled + case withCancellationHandler +} + +struct Empty: Codable { +} + struct JSONSerializer: MessageSerializer { private let encoder = JSONEncoder() diff --git a/Tests/GRPCNIOTransportHTTP2Tests/ControlService.swift b/Tests/GRPCNIOTransportHTTP2Tests/ControlService.swift index eee1296..cefae8e 100644 --- a/Tests/GRPCNIOTransportHTTP2Tests/ControlService.swift +++ b/Tests/GRPCNIOTransportHTTP2Tests/ControlService.swift @@ -51,10 +51,45 @@ struct ControlService: RegistrableRPCService { return try await self.handle(request: request) } ) + router.registerHandler( + forMethod: MethodDescriptor(service: "Control", method: "WaitForCancellation"), + deserializer: JSONDeserializer(), + serializer: JSONSerializer(), + handler: { request, context in + return try await self.waitForCancellation( + request: ServerRequest(stream: request), + context: context + ) + } + ) } } extension ControlService { + private func waitForCancellation( + request: ServerRequest, + context: ServerContext + ) async throws -> StreamingServerResponse { + switch request.message { + case .awaitCancelled: + return StreamingServerResponse { _ in + try await context.cancellation.cancelled + return [:] + } + + case .withCancellationHandler: + let signal = AsyncStream.makeStream(of: Void.self) + return StreamingServerResponse { _ in + await withRPCCancellationHandler { + for await _ in signal.stream {} + return [:] + } onCancelRPC: { + signal.continuation.finish() + } + } + } + } + private func handle( request: StreamingServerRequest ) async throws -> StreamingServerResponse { diff --git a/Tests/GRPCNIOTransportHTTP2Tests/HTTP2TransportTests.swift b/Tests/GRPCNIOTransportHTTP2Tests/HTTP2TransportTests.swift index 6d268ed..342b5a3 100644 --- a/Tests/GRPCNIOTransportHTTP2Tests/HTTP2TransportTests.swift +++ b/Tests/GRPCNIOTransportHTTP2Tests/HTTP2TransportTests.swift @@ -51,7 +51,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: CompressionAlgorithm = .none, clientEnabledCompression: CompressionAlgorithmSet = .none, serverCompression: CompressionAlgorithmSet = .none, - _ execute: (ControlClient, Transport) async throws -> Void + _ execute: (ControlClient, GRPCServer, Transport) async throws -> Void ) async throws { for pair in transport { try await withThrowingTaskGroup(of: Void.self) { group in @@ -87,7 +87,7 @@ final class HTTP2TransportTests: XCTestCase { do { let control = ControlClient(wrapping: client) - try await execute(control, pair) + try await execute(control, server, pair) } catch { XCTFail("Unexpected error: '\(error)' (\(pair))") } @@ -227,7 +227,7 @@ final class HTTP2TransportTests: XCTestCase { func testUnaryOK() async throws { // Client sends one message, server sends back metadata, a single message, and an ok status with // trailing metadata. - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let input = ControlInput.with { $0.echoMetadataInHeaders = true $0.echoMetadataInTrailers = true @@ -257,7 +257,7 @@ final class HTTP2TransportTests: XCTestCase { func testUnaryNotOK() async throws { // Client sends one message, server sends back metadata, a single message, and a non-ok status // with trailing metadata. - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let input = ControlInput.with { $0.echoMetadataInTrailers = true $0.numberOfMessages = 1 @@ -291,7 +291,7 @@ final class HTTP2TransportTests: XCTestCase { func testUnaryRejected() async throws { // Client sends one message, server sends non-ok status with trailing metadata. - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = ClientRequest( message: .trailersOnly(code: .aborted, message: "\(#function)", echoMetadata: true), @@ -317,7 +317,7 @@ final class HTTP2TransportTests: XCTestCase { } func testClientStreamingOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -348,7 +348,7 @@ final class HTTP2TransportTests: XCTestCase { } func testClientStreamingNotOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -385,7 +385,7 @@ final class HTTP2TransportTests: XCTestCase { func testClientStreamingRejected() async throws { // Client sends one message, server sends non-ok status with trailing metadata. - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -419,7 +419,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let input = ControlInput.with { $0.echoMetadataInHeaders = true @@ -458,7 +458,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingEmptyOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] // Echo back metadata, but don't send any messages. let input = ControlInput.with { @@ -489,7 +489,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingNotOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let input = ControlInput.with { $0.echoMetadataInHeaders = true @@ -539,7 +539,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingEmptyNotOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let input = ControlInput.with { $0.echoMetadataInHeaders = true @@ -575,7 +575,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingRejected() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = ClientRequest( message: .trailersOnly(code: .aborted, message: "\(#function)", echoMetadata: true), @@ -596,7 +596,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -636,7 +636,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingEmptyOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { _ in } try await control.bidiStream(request: request) { response in switch response.accepted { @@ -659,7 +659,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingNotOK() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -707,7 +707,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingRejected() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let metadata: Metadata = ["test-key": "test-value"] let request = StreamingClientRequest( of: ControlInput.self, @@ -738,7 +738,7 @@ final class HTTP2TransportTests: XCTestCase { // MARK: - Not Implemented func testUnaryNotImplemented() async throws { - try await self.forEachTransportPair(enableControlService: false) { control, pair in + try await self.forEachTransportPair(enableControlService: false) { control, _, pair in let request = ClientRequest(message: ControlInput()) try await control.unary(request: request) { response in XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in @@ -749,7 +749,7 @@ final class HTTP2TransportTests: XCTestCase { } func testClientStreamingNotImplemented() async throws { - try await self.forEachTransportPair(enableControlService: false) { control, pair in + try await self.forEachTransportPair(enableControlService: false) { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { _ in } try await control.clientStream(request: request) { response in XCTAssertThrowsError(ofType: RPCError.self, try response.message) { error in @@ -760,7 +760,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingNotImplemented() async throws { - try await self.forEachTransportPair(enableControlService: false) { control, pair in + try await self.forEachTransportPair(enableControlService: false) { control, _, pair in let request = ClientRequest(message: ControlInput()) try await control.serverStream(request: request) { response in XCTAssertThrowsError(ofType: RPCError.self, try response.accepted.get()) { error in @@ -771,7 +771,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingNotImplemented() async throws { - try await self.forEachTransportPair(enableControlService: false) { control, pair in + try await self.forEachTransportPair(enableControlService: false) { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { _ in } try await control.bidiStream(request: request) { response in XCTAssertThrowsError(ofType: RPCError.self, try response.accepted.get()) { error in @@ -980,7 +980,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .deflate, clientEnabledCompression: .deflate, serverCompression: .deflate - ) { control, pair in + ) { control, _, pair in try await self.testUnaryCompression( client: .deflate, server: .deflate, @@ -995,7 +995,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .gzip, clientEnabledCompression: .gzip, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in try await self.testUnaryCompression( client: .gzip, server: .gzip, @@ -1010,7 +1010,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .deflate, clientEnabledCompression: .deflate, serverCompression: .deflate - ) { control, pair in + ) { control, _, pair in try await self.testClientStreamingCompression( client: .deflate, server: .deflate, @@ -1025,7 +1025,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .gzip, clientEnabledCompression: .gzip, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in try await self.testClientStreamingCompression( client: .gzip, server: .gzip, @@ -1040,7 +1040,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .deflate, clientEnabledCompression: .deflate, serverCompression: .deflate - ) { control, pair in + ) { control, _, pair in try await self.testServerStreamingCompression( client: .deflate, server: .deflate, @@ -1055,7 +1055,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .gzip, clientEnabledCompression: .gzip, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in try await self.testServerStreamingCompression( client: .gzip, server: .gzip, @@ -1070,7 +1070,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .deflate, clientEnabledCompression: .deflate, serverCompression: .deflate - ) { control, pair in + ) { control, _, pair in try await self.testBidiStreamingCompression( client: .deflate, server: .deflate, @@ -1085,7 +1085,7 @@ final class HTTP2TransportTests: XCTestCase { clientCompression: .gzip, clientEnabledCompression: .gzip, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in try await self.testBidiStreamingCompression( client: .gzip, server: .gzip, @@ -1099,7 +1099,7 @@ final class HTTP2TransportTests: XCTestCase { try await self.forEachTransportPair( clientEnabledCompression: .all, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in let message = ControlInput.with { $0.numberOfMessages = 1 $0.payloadParameters = .with { @@ -1129,7 +1129,7 @@ final class HTTP2TransportTests: XCTestCase { try await self.forEachTransportPair( clientEnabledCompression: .all, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in try await writer.write(.noOp) } @@ -1154,7 +1154,7 @@ final class HTTP2TransportTests: XCTestCase { try await self.forEachTransportPair( clientEnabledCompression: .all, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in let message = ControlInput.with { $0.numberOfMessages = 1 $0.payloadParameters = .with { @@ -1184,7 +1184,7 @@ final class HTTP2TransportTests: XCTestCase { try await self.forEachTransportPair( clientEnabledCompression: .all, serverCompression: .gzip - ) { control, pair in + ) { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in try await writer.write(.noOp) } @@ -1206,7 +1206,7 @@ final class HTTP2TransportTests: XCTestCase { } func testUnaryTimeoutPropagatedToServer() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let message = ControlInput.with { $0.echoMetadataInHeaders = true $0.numberOfMessages = 1 @@ -1223,7 +1223,7 @@ final class HTTP2TransportTests: XCTestCase { } func testClientStreamingTimeoutPropagatedToServer() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in let message = ControlInput.with { $0.echoMetadataInHeaders = true @@ -1242,7 +1242,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingTimeoutPropagatedToServer() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let message = ControlInput.with { $0.echoMetadataInHeaders = true $0.numberOfMessages = 1 @@ -1259,7 +1259,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingTimeoutPropagatedToServer() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in try await writer.write(.echoMetadata) } @@ -1372,7 +1372,7 @@ final class HTTP2TransportTests: XCTestCase { } func testUnaryScheme() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let input = ControlInput.with { $0.echoMetadataInHeaders = true $0.numberOfMessages = 1 @@ -1385,7 +1385,7 @@ final class HTTP2TransportTests: XCTestCase { } func testServerStreamingScheme() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let input = ControlInput.with { $0.echoMetadataInHeaders = true $0.numberOfMessages = 1 @@ -1398,7 +1398,7 @@ final class HTTP2TransportTests: XCTestCase { } func testClientStreamingScheme() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in let input = ControlInput.with { $0.echoMetadataInHeaders = true @@ -1413,7 +1413,7 @@ final class HTTP2TransportTests: XCTestCase { } func testBidiStreamingScheme() async throws { - try await self.forEachTransportPair { control, pair in + try await self.forEachTransportPair { control, _, pair in let request = StreamingClientRequest(of: ControlInput.self) { writer in let input = ControlInput.with { $0.echoMetadataInHeaders = true @@ -1426,6 +1426,26 @@ final class HTTP2TransportTests: XCTestCase { } } } + + func testServerCancellation() async throws { + for kind in [CancellationKind.awaitCancelled, .withCancellationHandler] { + try await self.forEachTransportPair { control, server, pair in + let request = ClientRequest(message: kind) + try await control.waitForCancellation(request: request) { response in + // Shutdown the client so that it doesn't attempt to reconnect when the server closes. + control.client.beginGracefulShutdown() + + // Shutdown the server to cancel the RPC. + server.beginGracefulShutdown() + + // The RPC should complete without any error or response. + let responses = try await response.messages.reduce(into: []) { $0.append($1) } + XCTAssert(responses.isEmpty) + } + + } + } + } } extension [HTTP2TransportTests.Transport] { From d9ed15496f598841f4179ab93ddd82b75bd5c16c Mon Sep 17 00:00:00 2001 From: George Barnett Date: Tue, 15 Oct 2024 11:23:10 +0100 Subject: [PATCH 2/2] Dont execute the RPC if the stream is already closed --- .../Server/CommonHTTP2ServerTransport.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift b/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift index b9e9a51..db06969 100644 --- a/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift +++ b/Sources/GRPCNIOTransportCore/Server/CommonHTTP2ServerTransport.swift @@ -246,13 +246,13 @@ package final class CommonHTTP2ServerTransport< // Sync is safe: this is on the right event loop. let sync = stream.channel.pipeline.syncOperations - // Looking up the handler can fail if the channel is already closed, in which case - // cancel the handle directly. do { let handler = try sync.handler(type: GRPCServerStreamHandler.self) handler.setCancellationHandle(handle) } catch { - handle.cancel() + // Looking up the handler can fail if the channel is already closed, in which case + // don't execute the RPC, just return early. + return } }