diff --git a/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift b/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift index d54def98d..1f653c17d 100644 --- a/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift +++ b/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift @@ -79,16 +79,17 @@ internal final class ClientInterceptorPipeline { internal let _errorDelegate: ClientErrorDelegate? @usableFromInline - internal let _onError: (Error) -> Void + internal private(set) var _onError: ((Error) -> Void)? @usableFromInline - internal let _onCancel: (EventLoopPromise?) -> Void + internal private(set) var _onCancel: ((EventLoopPromise?) -> Void)? @usableFromInline - internal let _onRequestPart: (GRPCClientRequestPart, EventLoopPromise?) -> Void + internal private(set) var _onRequestPart: + ((GRPCClientRequestPart, EventLoopPromise?) -> Void)? @usableFromInline - internal let _onResponsePart: (GRPCClientResponsePart) -> Void + internal private(set) var _onResponsePart: ((GRPCClientResponsePart) -> Void)? /// The index after the last user interceptor context index. (i.e. `_userContexts.endIndex`). @usableFromInline @@ -217,9 +218,13 @@ internal final class ClientInterceptorPipeline { case self._tailIndex: if part.isEnd { + // Update our state before handling the response part. + self._isOpen = false + self._onResponsePart?(part) self.close() + } else { + self._onResponsePart?(part) } - self._onResponsePart(part) default: self._userContexts[index].invokeReceive(part) @@ -275,9 +280,8 @@ internal final class ClientInterceptorPipeline { /// Handles a caught error which has traversed the interceptor pipeline. @usableFromInline internal func _errorCaught(_ error: Error) { - // We're about to complete, close the pipeline. - self.close() - + // We're about to call out to an error handler: update our state first. + self._isOpen = false var unwrappedError: Error // Unwrap the error, if possible. @@ -295,7 +299,10 @@ internal final class ClientInterceptorPipeline { } // Emit the unwrapped error. - self._onError(unwrappedError) + self._onError?(unwrappedError) + + // Close the pipeline. + self.close() } /// Writes a request message into the interceptor pipeline. @@ -351,7 +358,7 @@ internal final class ClientInterceptorPipeline { ) { switch index { case self._headIndex: - self._onRequestPart(part, promise) + self._onRequestPart?(part, promise) case self._tailIndex: self._invokeSend( @@ -407,7 +414,7 @@ internal final class ClientInterceptorPipeline { ) { switch index { case self._headIndex: - self._onCancel(promise) + self._onCancel?(promise) case self._tailIndex: self._invokeCancel( @@ -425,7 +432,7 @@ internal final class ClientInterceptorPipeline { extension ClientInterceptorPipeline { /// Closes the pipeline. This should be called once, by the tail interceptor, to indicate that - /// the RPC has completed. + /// the RPC has completed. If this is not called, we will leak. /// - Important: This *must* to be called from the `eventLoop`. @inlinable internal func close() { @@ -437,7 +444,14 @@ extension ClientInterceptorPipeline { self._scheduledClose = nil // Cancel the transport. - self._onCancel(nil) + self._onCancel?(nil) + + // `ClientTransport` holds a reference to us and references to itself via these callbacks. Break + // these references now by replacing the callbacks. + self._onError = nil + self._onCancel = nil + self._onRequestPart = nil + self._onResponsePart = nil } /// Sets up a deadline for the pipeline. diff --git a/Sources/GRPC/Interceptor/ClientTransport.swift b/Sources/GRPC/Interceptor/ClientTransport.swift index 82eaba221..b15c33367 100644 --- a/Sources/GRPC/Interceptor/ClientTransport.swift +++ b/Sources/GRPC/Interceptor/ClientTransport.swift @@ -85,8 +85,8 @@ internal final class ClientTransport { // trailers here and only forward them when we receive the status. private var trailers: HPACKHeaders? - /// The interceptor pipeline connected to this transport. This must be set to `nil` when removed - /// from the `ChannelPipeline` in order to break reference cycles. + /// The interceptor pipeline connected to this transport. The pipeline also holds references + /// to `self` which are dropped when the interceptor pipeline is closed. @usableFromInline internal var _pipeline: ClientInterceptorPipeline? @@ -118,6 +118,7 @@ internal final class ClientTransport { self.logger = logger self.serializer = serializer self.deserializer = deserializer + // The references to self held by the pipeline are dropped when it is closed. self._pipeline = ClientInterceptorPipeline( eventLoop: eventLoop, details: details, @@ -241,7 +242,8 @@ extension ClientTransport { if self.state.cancel() { let error = GRPCError.RPCCancelledByClient() - self.forwardErrorToInterceptors(error) + let status = error.makeGRPCStatus() + self.forwardToInterceptors(.end(status, [:])) self.failBufferedWrites(with: error) self.channel?.close(mode: .all, promise: nil) self.channelPromise?.fail(error) @@ -363,11 +365,9 @@ extension ClientTransport { private func dropReferences() { if self.callEventLoop.inEventLoop { self.channel = nil - self._pipeline = nil } else { self.callEventLoop.execute { self.channel = nil - self._pipeline = nil } } } diff --git a/Tests/GRPCTests/ClientCallTests.swift b/Tests/GRPCTests/ClientCallTests.swift index dbce89775..fce9d3f15 100644 --- a/Tests/GRPCTests/ClientCallTests.swift +++ b/Tests/GRPCTests/ClientCallTests.swift @@ -197,11 +197,7 @@ class ClientCallTests: GRPCTestCase { // Cancellation should succeed. assertThat(try get.cancel().wait(), .doesNotThrow()) - // The status promise will fail. - assertThat( - try promise.futureResult.wait(), - .throws(.instanceOf(GRPCError.RPCCancelledByClient.self)) - ) + assertThat(try promise.futureResult.wait(), .hasCode(.cancelled)) // Cancellation should now fail, we've already cancelled. assertThat(try get.cancel().wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self))) diff --git a/Tests/GRPCTests/ClientCancellingTests.swift b/Tests/GRPCTests/ClientCancellingTests.swift index 5724b0046..a12d9b828 100644 --- a/Tests/GRPCTests/ClientCancellingTests.swift +++ b/Tests/GRPCTests/ClientCancellingTests.swift @@ -27,7 +27,7 @@ class ClientCancellingTests: EchoTestCaseBase { call.cancel(promise: nil) call.response.whenFailure { error in - XCTAssertTrue(error is GRPCError.RPCCancelledByClient) + XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled) responseReceived.fulfill() } @@ -47,7 +47,7 @@ class ClientCancellingTests: EchoTestCaseBase { call.cancel(promise: nil) call.response.whenFailure { error in - XCTAssertTrue(error is GRPCError.RPCCancelledByClient) + XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled) responseReceived.fulfill() } diff --git a/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift b/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift new file mode 100644 index 000000000..203cab499 --- /dev/null +++ b/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift @@ -0,0 +1,87 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoModel +import GRPC + +// MARK: - Client + +internal final class EchoClientInterceptors: Echo_EchoClientInterceptorFactoryProtocol { + internal typealias Factory = () -> ClientInterceptor + private var factories: [Factory] = [] + + internal init(_ factories: Factory...) { + self.factories = factories + } + + internal func register(_ factory: @escaping Factory) { + self.factories.append(factory) + } + + private func makeInterceptors() -> [ClientInterceptor] { + return self.factories.map { $0() } + } + + func makeGetInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeExpandInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeCollectInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeUpdateInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } +} + +// MARK: - Server + +internal final class EchoServerInterceptors: Echo_EchoServerInterceptorFactoryProtocol { + internal typealias Factory = () -> ServerInterceptor + private var factories: [Factory] = [] + + internal init(_ factories: Factory...) { + self.factories = factories + } + + internal func register(_ factory: @escaping Factory) { + self.factories.append(factory) + } + + private func makeInterceptors() -> [ServerInterceptor] { + return self.factories.map { $0() } + } + + func makeGetInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeExpandInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeCollectInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeUpdateInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } +} diff --git a/Tests/GRPCTests/InterceptedRPCCancellationTests.swift b/Tests/GRPCTests/InterceptedRPCCancellationTests.swift new file mode 100644 index 000000000..b82ff293e --- /dev/null +++ b/Tests/GRPCTests/InterceptedRPCCancellationTests.swift @@ -0,0 +1,202 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoImplementation +import EchoModel +@testable import GRPC +import NIOCore +import NIOPosix +import protocol SwiftProtobuf.Message +import XCTest + +final class InterceptedRPCCancellationTests: GRPCTestCase { + func testCancellationWithinInterceptedRPC() throws { + // This test validates that when using interceptors to replay an RPC that the lifecycle of + // the interceptor pipeline is correctly managed. That is, the transport maintains a reference + // to the pipeline for as long as the call is alive (rather than dropping the reference when + // the RPC ends). + let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + // Interceptor checks that a "magic" header is present. + let serverInterceptors = EchoServerInterceptors(MagicRequiredServerInterceptor.init) + let server = try Server.insecure(group: group) + .withLogger(self.serverLogger) + .withServiceProviders([EchoProvider(interceptors: serverInterceptors)]) + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try server.close().wait()) + } + + let connection = ClientConnection.insecure(group: group) + .withBackgroundActivityLogger(self.clientLogger) + .connect(host: "127.0.0.1", port: server.channel.localAddress!.port!) + defer { + XCTAssertNoThrow(try connection.close().wait()) + } + + let clientInterceptors = EchoClientInterceptors() + // Retries an RPC with a "magic" header if it fails with the permission denied status code. + clientInterceptors.register { + MagicAddingClientInterceptor(channel: connection) + } + + let echo = Echo_EchoClient(channel: connection, interceptors: clientInterceptors) + + let receivedFirstResponse = connection.eventLoop.makePromise(of: Void.self) + let update = echo.update { _ in + receivedFirstResponse.succeed(()) + } + + XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "ping" }).wait()) + // Wait for the pong: it means the second RPC is up and running and the first should have + // completed. + XCTAssertNoThrow(try receivedFirstResponse.futureResult.wait()) + XCTAssertNoThrow(try update.cancel().wait()) + + let status = try update.status.wait() + XCTAssertEqual(status.code, .cancelled) + } +} + +final class MagicRequiredServerInterceptor< + Request: Message, + Response: Message +>: ServerInterceptor { + override func receive( + _ part: GRPCServerRequestPart, + context: ServerInterceptorContext + ) { + switch part { + case let .metadata(metadata): + if metadata.contains(name: "magic") { + context.log.debug("metadata contains magic; accepting rpc") + context.receive(part) + } else { + context.log.debug("metadata does not contains magic; rejecting rpc") + let status = GRPCStatus(code: .permissionDenied, message: nil) + context.send(.end(status, [:]), promise: nil) + } + case .message, .end: + context.receive(part) + } + } +} + +final class MagicAddingClientInterceptor< + Request: Message, + Response: Message +>: ClientInterceptor { + private let channel: GRPCChannel + private var requestParts = CircularBuffer>() + private var retry: Call? + + init(channel: GRPCChannel) { + self.channel = channel + } + + override func cancel( + promise: EventLoopPromise?, + context: ClientInterceptorContext + ) { + if let retry = self.retry { + context.log.debug("cancelling retry RPC") + retry.cancel(promise: promise) + } else { + context.cancel(promise: promise) + } + } + + override func send( + _ part: GRPCClientRequestPart, + promise: EventLoopPromise?, + context: ClientInterceptorContext + ) { + if let retry = self.retry { + context.log.debug("retrying part \(part)") + retry.send(part, promise: promise) + } else { + switch part { + case .metadata: + // Replace the metadata with the magic words. + self.requestParts.append(.metadata(["magic": "it's real!"])) + case .message, .end: + self.requestParts.append(part) + } + context.send(part, promise: promise) + } + } + + override func receive( + _ part: GRPCClientResponsePart, + context: ClientInterceptorContext + ) { + switch part { + case .metadata, .message: + XCTFail("Unexpected response part \(part)") + context.receive(part) + + case let .end(status, _): + guard status.code == .permissionDenied else { + XCTFail("Unexpected status code \(status)") + context.receive(part) + return + } + + XCTAssertNil(self.retry) + + context.log.debug("initial rpc failed, retrying") + + self.retry = self.channel.makeCall( + path: context.path, + type: context.type, + callOptions: CallOptions(logger: context.logger), + interceptors: [] + ) + + self.retry!.invoke(onError: { + context.log.debug("intercepting error from retried rpc") + context.errorCaught($0) + }) { responsePart in + context.log.debug("intercepting response part from retried rpc") + context.receive(responsePart) + } + + while let requestPart = self.requestParts.popFirst() { + context.log.debug("replaying \(requestPart) on new rpc") + self.retry!.send(requestPart, promise: nil) + } + } + } +} + +// MARK: - GRPC Logger + +// Our tests also check the "Source" of a logger is "GRPC". That assertion fails when we log from +// tests so we'll use our internal logger instead. +extension ClientInterceptorContext { + var log: GRPCLogger { + return GRPCLogger(wrapping: self.logger) + } +} + +extension ServerInterceptorContext { + var log: GRPCLogger { + return GRPCLogger(wrapping: self.logger) + } +}