From 03010c784c0bc6199130983381044f7b19fb6be6 Mon Sep 17 00:00:00 2001 From: George Barnett Date: Thu, 16 Sep 2021 12:38:20 +0100 Subject: [PATCH] Extend lifetime of client interceptor pipeline (#1265) Motivation: A client call (i.e. the object the user holds) may live longer than the transport associated with it (roughly speaking, the http/2 stream channel). An example of this is when interceptors are use to retry and RPC and redirect responses back to the original call. However, the interceptor pipeline is held by the transport and is currently set to nil when the transport is removed from the channel. This means events invoked from the call object (such as cancellation) which go via the transport (holding the interceptor pipeline) are incorrectly failed. Modifications: - Have the client interceptor pipeline break the ref cycle between the transport and itself when the interceptor pipeline closes rather than when the transport is closed - Emit a cancellation status rater than error on cancellation - Update the ordering of when close is called in the interceptor pipeline. - Add and update tests Result: "sub"-RPCs may be cancelled. --- .../ClientInterceptorPipeline.swift | 40 ++-- .../GRPC/Interceptor/ClientTransport.swift | 10 +- Tests/GRPCTests/ClientCallTests.swift | 6 +- Tests/GRPCTests/ClientCancellingTests.swift | 4 +- .../EchoInterceptorFactories.swift | 87 ++++++++ .../InterceptedRPCCancellationTests.swift | 202 ++++++++++++++++++ 6 files changed, 324 insertions(+), 25 deletions(-) create mode 100644 Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift create mode 100644 Tests/GRPCTests/InterceptedRPCCancellationTests.swift diff --git a/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift b/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift index d54def98d..1f653c17d 100644 --- a/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift +++ b/Sources/GRPC/Interceptor/ClientInterceptorPipeline.swift @@ -79,16 +79,17 @@ internal final class ClientInterceptorPipeline { internal let _errorDelegate: ClientErrorDelegate? @usableFromInline - internal let _onError: (Error) -> Void + internal private(set) var _onError: ((Error) -> Void)? @usableFromInline - internal let _onCancel: (EventLoopPromise?) -> Void + internal private(set) var _onCancel: ((EventLoopPromise?) -> Void)? @usableFromInline - internal let _onRequestPart: (GRPCClientRequestPart, EventLoopPromise?) -> Void + internal private(set) var _onRequestPart: + ((GRPCClientRequestPart, EventLoopPromise?) -> Void)? @usableFromInline - internal let _onResponsePart: (GRPCClientResponsePart) -> Void + internal private(set) var _onResponsePart: ((GRPCClientResponsePart) -> Void)? /// The index after the last user interceptor context index. (i.e. `_userContexts.endIndex`). @usableFromInline @@ -217,9 +218,13 @@ internal final class ClientInterceptorPipeline { case self._tailIndex: if part.isEnd { + // Update our state before handling the response part. + self._isOpen = false + self._onResponsePart?(part) self.close() + } else { + self._onResponsePart?(part) } - self._onResponsePart(part) default: self._userContexts[index].invokeReceive(part) @@ -275,9 +280,8 @@ internal final class ClientInterceptorPipeline { /// Handles a caught error which has traversed the interceptor pipeline. @usableFromInline internal func _errorCaught(_ error: Error) { - // We're about to complete, close the pipeline. - self.close() - + // We're about to call out to an error handler: update our state first. + self._isOpen = false var unwrappedError: Error // Unwrap the error, if possible. @@ -295,7 +299,10 @@ internal final class ClientInterceptorPipeline { } // Emit the unwrapped error. - self._onError(unwrappedError) + self._onError?(unwrappedError) + + // Close the pipeline. + self.close() } /// Writes a request message into the interceptor pipeline. @@ -351,7 +358,7 @@ internal final class ClientInterceptorPipeline { ) { switch index { case self._headIndex: - self._onRequestPart(part, promise) + self._onRequestPart?(part, promise) case self._tailIndex: self._invokeSend( @@ -407,7 +414,7 @@ internal final class ClientInterceptorPipeline { ) { switch index { case self._headIndex: - self._onCancel(promise) + self._onCancel?(promise) case self._tailIndex: self._invokeCancel( @@ -425,7 +432,7 @@ internal final class ClientInterceptorPipeline { extension ClientInterceptorPipeline { /// Closes the pipeline. This should be called once, by the tail interceptor, to indicate that - /// the RPC has completed. + /// the RPC has completed. If this is not called, we will leak. /// - Important: This *must* to be called from the `eventLoop`. @inlinable internal func close() { @@ -437,7 +444,14 @@ extension ClientInterceptorPipeline { self._scheduledClose = nil // Cancel the transport. - self._onCancel(nil) + self._onCancel?(nil) + + // `ClientTransport` holds a reference to us and references to itself via these callbacks. Break + // these references now by replacing the callbacks. + self._onError = nil + self._onCancel = nil + self._onRequestPart = nil + self._onResponsePart = nil } /// Sets up a deadline for the pipeline. diff --git a/Sources/GRPC/Interceptor/ClientTransport.swift b/Sources/GRPC/Interceptor/ClientTransport.swift index 82eaba221..b15c33367 100644 --- a/Sources/GRPC/Interceptor/ClientTransport.swift +++ b/Sources/GRPC/Interceptor/ClientTransport.swift @@ -85,8 +85,8 @@ internal final class ClientTransport { // trailers here and only forward them when we receive the status. private var trailers: HPACKHeaders? - /// The interceptor pipeline connected to this transport. This must be set to `nil` when removed - /// from the `ChannelPipeline` in order to break reference cycles. + /// The interceptor pipeline connected to this transport. The pipeline also holds references + /// to `self` which are dropped when the interceptor pipeline is closed. @usableFromInline internal var _pipeline: ClientInterceptorPipeline? @@ -118,6 +118,7 @@ internal final class ClientTransport { self.logger = logger self.serializer = serializer self.deserializer = deserializer + // The references to self held by the pipeline are dropped when it is closed. self._pipeline = ClientInterceptorPipeline( eventLoop: eventLoop, details: details, @@ -241,7 +242,8 @@ extension ClientTransport { if self.state.cancel() { let error = GRPCError.RPCCancelledByClient() - self.forwardErrorToInterceptors(error) + let status = error.makeGRPCStatus() + self.forwardToInterceptors(.end(status, [:])) self.failBufferedWrites(with: error) self.channel?.close(mode: .all, promise: nil) self.channelPromise?.fail(error) @@ -363,11 +365,9 @@ extension ClientTransport { private func dropReferences() { if self.callEventLoop.inEventLoop { self.channel = nil - self._pipeline = nil } else { self.callEventLoop.execute { self.channel = nil - self._pipeline = nil } } } diff --git a/Tests/GRPCTests/ClientCallTests.swift b/Tests/GRPCTests/ClientCallTests.swift index dbce89775..fce9d3f15 100644 --- a/Tests/GRPCTests/ClientCallTests.swift +++ b/Tests/GRPCTests/ClientCallTests.swift @@ -197,11 +197,7 @@ class ClientCallTests: GRPCTestCase { // Cancellation should succeed. assertThat(try get.cancel().wait(), .doesNotThrow()) - // The status promise will fail. - assertThat( - try promise.futureResult.wait(), - .throws(.instanceOf(GRPCError.RPCCancelledByClient.self)) - ) + assertThat(try promise.futureResult.wait(), .hasCode(.cancelled)) // Cancellation should now fail, we've already cancelled. assertThat(try get.cancel().wait(), .throws(.instanceOf(GRPCError.AlreadyComplete.self))) diff --git a/Tests/GRPCTests/ClientCancellingTests.swift b/Tests/GRPCTests/ClientCancellingTests.swift index 5724b0046..a12d9b828 100644 --- a/Tests/GRPCTests/ClientCancellingTests.swift +++ b/Tests/GRPCTests/ClientCancellingTests.swift @@ -27,7 +27,7 @@ class ClientCancellingTests: EchoTestCaseBase { call.cancel(promise: nil) call.response.whenFailure { error in - XCTAssertTrue(error is GRPCError.RPCCancelledByClient) + XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled) responseReceived.fulfill() } @@ -47,7 +47,7 @@ class ClientCancellingTests: EchoTestCaseBase { call.cancel(promise: nil) call.response.whenFailure { error in - XCTAssertTrue(error is GRPCError.RPCCancelledByClient) + XCTAssertEqual((error as? GRPCStatus)?.code, .cancelled) responseReceived.fulfill() } diff --git a/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift b/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift new file mode 100644 index 000000000..203cab499 --- /dev/null +++ b/Tests/GRPCTests/EchoHelpers/Interceptors/EchoInterceptorFactories.swift @@ -0,0 +1,87 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoModel +import GRPC + +// MARK: - Client + +internal final class EchoClientInterceptors: Echo_EchoClientInterceptorFactoryProtocol { + internal typealias Factory = () -> ClientInterceptor + private var factories: [Factory] = [] + + internal init(_ factories: Factory...) { + self.factories = factories + } + + internal func register(_ factory: @escaping Factory) { + self.factories.append(factory) + } + + private func makeInterceptors() -> [ClientInterceptor] { + return self.factories.map { $0() } + } + + func makeGetInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeExpandInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeCollectInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } + + func makeUpdateInterceptors() -> [ClientInterceptor] { + return self.makeInterceptors() + } +} + +// MARK: - Server + +internal final class EchoServerInterceptors: Echo_EchoServerInterceptorFactoryProtocol { + internal typealias Factory = () -> ServerInterceptor + private var factories: [Factory] = [] + + internal init(_ factories: Factory...) { + self.factories = factories + } + + internal func register(_ factory: @escaping Factory) { + self.factories.append(factory) + } + + private func makeInterceptors() -> [ServerInterceptor] { + return self.factories.map { $0() } + } + + func makeGetInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeExpandInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeCollectInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } + + func makeUpdateInterceptors() -> [ServerInterceptor] { + return self.makeInterceptors() + } +} diff --git a/Tests/GRPCTests/InterceptedRPCCancellationTests.swift b/Tests/GRPCTests/InterceptedRPCCancellationTests.swift new file mode 100644 index 000000000..b82ff293e --- /dev/null +++ b/Tests/GRPCTests/InterceptedRPCCancellationTests.swift @@ -0,0 +1,202 @@ +/* + * Copyright 2021, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoImplementation +import EchoModel +@testable import GRPC +import NIOCore +import NIOPosix +import protocol SwiftProtobuf.Message +import XCTest + +final class InterceptedRPCCancellationTests: GRPCTestCase { + func testCancellationWithinInterceptedRPC() throws { + // This test validates that when using interceptors to replay an RPC that the lifecycle of + // the interceptor pipeline is correctly managed. That is, the transport maintains a reference + // to the pipeline for as long as the call is alive (rather than dropping the reference when + // the RPC ends). + let group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + + // Interceptor checks that a "magic" header is present. + let serverInterceptors = EchoServerInterceptors(MagicRequiredServerInterceptor.init) + let server = try Server.insecure(group: group) + .withLogger(self.serverLogger) + .withServiceProviders([EchoProvider(interceptors: serverInterceptors)]) + .bind(host: "127.0.0.1", port: 0) + .wait() + defer { + XCTAssertNoThrow(try server.close().wait()) + } + + let connection = ClientConnection.insecure(group: group) + .withBackgroundActivityLogger(self.clientLogger) + .connect(host: "127.0.0.1", port: server.channel.localAddress!.port!) + defer { + XCTAssertNoThrow(try connection.close().wait()) + } + + let clientInterceptors = EchoClientInterceptors() + // Retries an RPC with a "magic" header if it fails with the permission denied status code. + clientInterceptors.register { + MagicAddingClientInterceptor(channel: connection) + } + + let echo = Echo_EchoClient(channel: connection, interceptors: clientInterceptors) + + let receivedFirstResponse = connection.eventLoop.makePromise(of: Void.self) + let update = echo.update { _ in + receivedFirstResponse.succeed(()) + } + + XCTAssertNoThrow(try update.sendMessage(.with { $0.text = "ping" }).wait()) + // Wait for the pong: it means the second RPC is up and running and the first should have + // completed. + XCTAssertNoThrow(try receivedFirstResponse.futureResult.wait()) + XCTAssertNoThrow(try update.cancel().wait()) + + let status = try update.status.wait() + XCTAssertEqual(status.code, .cancelled) + } +} + +final class MagicRequiredServerInterceptor< + Request: Message, + Response: Message +>: ServerInterceptor { + override func receive( + _ part: GRPCServerRequestPart, + context: ServerInterceptorContext + ) { + switch part { + case let .metadata(metadata): + if metadata.contains(name: "magic") { + context.log.debug("metadata contains magic; accepting rpc") + context.receive(part) + } else { + context.log.debug("metadata does not contains magic; rejecting rpc") + let status = GRPCStatus(code: .permissionDenied, message: nil) + context.send(.end(status, [:]), promise: nil) + } + case .message, .end: + context.receive(part) + } + } +} + +final class MagicAddingClientInterceptor< + Request: Message, + Response: Message +>: ClientInterceptor { + private let channel: GRPCChannel + private var requestParts = CircularBuffer>() + private var retry: Call? + + init(channel: GRPCChannel) { + self.channel = channel + } + + override func cancel( + promise: EventLoopPromise?, + context: ClientInterceptorContext + ) { + if let retry = self.retry { + context.log.debug("cancelling retry RPC") + retry.cancel(promise: promise) + } else { + context.cancel(promise: promise) + } + } + + override func send( + _ part: GRPCClientRequestPart, + promise: EventLoopPromise?, + context: ClientInterceptorContext + ) { + if let retry = self.retry { + context.log.debug("retrying part \(part)") + retry.send(part, promise: promise) + } else { + switch part { + case .metadata: + // Replace the metadata with the magic words. + self.requestParts.append(.metadata(["magic": "it's real!"])) + case .message, .end: + self.requestParts.append(part) + } + context.send(part, promise: promise) + } + } + + override func receive( + _ part: GRPCClientResponsePart, + context: ClientInterceptorContext + ) { + switch part { + case .metadata, .message: + XCTFail("Unexpected response part \(part)") + context.receive(part) + + case let .end(status, _): + guard status.code == .permissionDenied else { + XCTFail("Unexpected status code \(status)") + context.receive(part) + return + } + + XCTAssertNil(self.retry) + + context.log.debug("initial rpc failed, retrying") + + self.retry = self.channel.makeCall( + path: context.path, + type: context.type, + callOptions: CallOptions(logger: context.logger), + interceptors: [] + ) + + self.retry!.invoke(onError: { + context.log.debug("intercepting error from retried rpc") + context.errorCaught($0) + }) { responsePart in + context.log.debug("intercepting response part from retried rpc") + context.receive(responsePart) + } + + while let requestPart = self.requestParts.popFirst() { + context.log.debug("replaying \(requestPart) on new rpc") + self.retry!.send(requestPart, promise: nil) + } + } + } +} + +// MARK: - GRPC Logger + +// Our tests also check the "Source" of a logger is "GRPC". That assertion fails when we log from +// tests so we'll use our internal logger instead. +extension ClientInterceptorContext { + var log: GRPCLogger { + return GRPCLogger(wrapping: self.logger) + } +} + +extension ServerInterceptorContext { + var log: GRPCLogger { + return GRPCLogger(wrapping: self.logger) + } +}