diff --git a/Package.resolved b/Package.resolved index d29152f96..0fcf0992c 100644 --- a/Package.resolved +++ b/Package.resolved @@ -51,8 +51,8 @@ "repositoryURL": "https://github.com/apple/swift-protobuf.git", "state": { "branch": null, - "revision": "da75a93ac017534e0028e83c0e4fc4610d2acf48", - "version": "1.7.0" + "revision": "7790acf0a81d08429cb20375bf42a8c7d279c5fe", + "version": "1.8.0" } } ] diff --git a/Sources/GRPC/CallHandlers/_BaseCallHandler.swift b/Sources/GRPC/CallHandlers/_BaseCallHandler.swift index eb7dd8e6b..395bd537c 100644 --- a/Sources/GRPC/CallHandlers/_BaseCallHandler.swift +++ b/Sources/GRPC/CallHandlers/_BaseCallHandler.swift @@ -25,7 +25,10 @@ import Logging /// - Important: This is **NOT** part of the public API. public class _BaseCallHandler: GRPCCallHandler { public func makeGRPCServerCodec() -> ChannelHandler { - return HTTP1ToGRPCServerCodec(logger: self.logger) + return HTTP1ToGRPCServerCodec( + encoding: self.callHandlerContext.encoding, + logger: self.logger + ) } /// Called when the request head has been received. diff --git a/Sources/GRPC/ClientCalls/BaseClientCall.swift b/Sources/GRPC/ClientCalls/BaseClientCall.swift index df0df35e4..8ceaf3eac 100644 --- a/Sources/GRPC/ClientCalls/BaseClientCall.swift +++ b/Sources/GRPC/ClientCalls/BaseClientCall.swift @@ -58,6 +58,7 @@ public class BaseClientCall: Client internal let multiplexer: EventLoopFuture // Note: documentation is inherited from the `ClientCall` protocol. + public let options: CallOptions public let subchannel: EventLoopFuture public let initialMetadata: EventLoopFuture public let trailingMetadata: EventLoopFuture @@ -75,12 +76,14 @@ public class BaseClientCall: Client eventLoop: EventLoop, multiplexer: EventLoopFuture, callType: GRPCCallType, + callOptions: CallOptions, responseHandler: GRPCClientResponseChannelHandler, requestHandler: _ClientRequestChannelHandler, logger: Logger ) { self.logger = logger self.multiplexer = multiplexer + self.options = callOptions let streamPromise = eventLoop.makePromise(of: Channel.self) @@ -135,7 +138,6 @@ extension _GRPCRequestHead { path: String, host: String, requestID: String, - encoding: ClientConnection.Configuration.MessageEncoding, options: CallOptions ) { var customMetadata = options.customMetadata @@ -143,13 +145,6 @@ extension _GRPCRequestHead { customMetadata.add(name: requestIDHeader, value: requestID) } - var encoding = encoding - // Compression is disabled at the RPC level; remove outbound (request) encoding. This will stop - // any 'grpc-encoding' header being sent to the peer. - if options.disableCompression { - encoding.outbound = nil - } - self = _GRPCRequestHead( method: options.cacheable ? "GET" : "POST", scheme: scheme, @@ -157,7 +152,7 @@ extension _GRPCRequestHead { host: host, timeout: options.timeout, customMetadata: customMetadata, - encoding: encoding + encoding: options.messageEncoding ) } } diff --git a/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift b/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift index 0c09cbb49..69bd44511 100644 --- a/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift +++ b/Sources/GRPC/ClientCalls/BidirectionalStreamingCall.swift @@ -60,7 +60,6 @@ public final class BidirectionalStreamingCall { get } @@ -62,10 +65,10 @@ public protocol StreamingRequestClientCall: ClientCall { /// /// - Parameters: /// - message: The message to send. - /// - disableCompression: Whether compression should be disabled for this message. Ignored if - /// compression was not enabled for the connection or RPC. + /// - compression: Whether compression should be used for this message. Ignored if compression + /// was not enabled for the RPC. /// - Returns: A future which will be fullfilled when the message has been sent. - func sendMessage(_ message: RequestPayload, disableCompression: Bool) -> EventLoopFuture + func sendMessage(_ message: RequestPayload, compression: Compression) -> EventLoopFuture /// Sends a message to the service. /// @@ -73,10 +76,10 @@ public protocol StreamingRequestClientCall: ClientCall { /// /// - Parameters: /// - message: The message to send. - /// - disableCompression: Whether compression should be disabled for this message. Ignored if - /// compression was not enabled for the connection or RPC. + /// - compression: Whether compression should be used for this message. Ignored if compression + /// was not enabled for the RPC. /// - promise: A promise to be fulfilled when the message has been sent. - func sendMessage(_ message: RequestPayload, disableCompression: Bool, promise: EventLoopPromise?) + func sendMessage(_ message: RequestPayload, compression: Compression, promise: EventLoopPromise?) /// Sends a sequence of messages to the service. /// @@ -84,9 +87,9 @@ public protocol StreamingRequestClientCall: ClientCall { /// /// - Parameters: /// - messages: The sequence of messages to send. - /// - disableCompression: Whether compression should be disabled for these messages. Ignored if - /// compression was not enabled for the connection or RPC. - func sendMessages(_ messages: S, disableCompression: Bool) -> EventLoopFuture where S.Element == RequestPayload + /// - compression: Whether compression should be used for this message. Ignored if compression + /// was not enabled for the RPC. + func sendMessages(_ messages: S, compression: Compression) -> EventLoopFuture where S.Element == RequestPayload /// Sends a sequence of messages to the service. /// @@ -94,10 +97,10 @@ public protocol StreamingRequestClientCall: ClientCall { /// /// - Parameters: /// - messages: The sequence of messages to send. - /// - disableCompression: Whether compression should be disabled for these messages. Ignored if - /// compression was not enabled for the connection or RPC. + /// - compression: Whether compression should be used for this message. Ignored if compression + /// was not enabled for the RPC. /// - promise: A promise to be fulfilled when all messages have been sent successfully. - func sendMessages(_ messages: S, disableCompression: Bool, promise: EventLoopPromise?) where S.Element == RequestPayload + func sendMessages(_ messages: S, compression: Compression, promise: EventLoopPromise?) where S.Element == RequestPayload /// Returns a future which can be used as a message queue. /// @@ -137,30 +140,42 @@ public protocol UnaryResponseClientCall: ClientCall { extension StreamingRequestClientCall { public func sendMessage( _ message: RequestPayload, - disableCompression: Bool = false + compression: Compression = .deferToCallDefault ) -> EventLoopFuture { return self.subchannel.flatMap { channel in - return channel.writeAndFlush(_GRPCClientRequestPart.message(.init(message, disableCompression: disableCompression))) + let context = _MessageContext( + message, + compressed: compression.isEnabled(enabledOnCall: self.options.messageEncoding.enabledForRequests) + ) + return channel.writeAndFlush(_GRPCClientRequestPart.message(context)) } } public func sendMessage( _ message: RequestPayload, - disableCompression: Bool = false, + compression: Compression = .deferToCallDefault, promise: EventLoopPromise? ) { self.subchannel.whenSuccess { channel in - channel.writeAndFlush(_GRPCClientRequestPart.message(.init(message, disableCompression: disableCompression)), promise: promise) + let context = _MessageContext( + message, + compressed: compression.isEnabled(enabledOnCall: self.options.messageEncoding.enabledForRequests) + ) + channel.writeAndFlush(_GRPCClientRequestPart.message(context), promise: promise) } } public func sendMessages( _ messages: S, - disableCompression: Bool = false + compression: Compression = .deferToCallDefault ) -> EventLoopFuture where S.Element == RequestPayload { return self.subchannel.flatMap { channel -> EventLoopFuture in - let writeFutures = messages.map { message in - channel.write(_GRPCClientRequestPart.message(.init(message, disableCompression: disableCompression))) + let writeFutures = messages.map { message -> EventLoopFuture in + let context = _MessageContext( + message, + compressed: compression.isEnabled(enabledOnCall: self.options.messageEncoding.enabledForRequests) + ) + return channel.write(_GRPCClientRequestPart.message(context)) } channel.flush() return EventLoopFuture.andAllSucceed(writeFutures, on: channel.eventLoop) @@ -169,7 +184,7 @@ extension StreamingRequestClientCall { public func sendMessages( _ messages: S, - disableCompression: Bool = false, + compression: Compression = .deferToCallDefault, promise: EventLoopPromise? ) where S.Element == RequestPayload { if let promise = promise { @@ -177,7 +192,11 @@ extension StreamingRequestClientCall { } else { self.subchannel.whenSuccess { channel in for message in messages { - channel.write(_GRPCClientRequestPart.message(.init(message, disableCompression: disableCompression)), promise: nil) + let context = _MessageContext( + message, + compressed: compression.isEnabled(enabledOnCall: self.options.messageEncoding.enabledForRequests) + ) + channel.write(_GRPCClientRequestPart.message(context), promise: nil) } channel.flush() } diff --git a/Sources/GRPC/ClientCalls/ClientStreamingCall.swift b/Sources/GRPC/ClientCalls/ClientStreamingCall.swift index fb23870d0..379497e57 100644 --- a/Sources/GRPC/ClientCalls/ClientStreamingCall.swift +++ b/Sources/GRPC/ClientCalls/ClientStreamingCall.swift @@ -64,7 +64,6 @@ public final class ClientStreamingCall( requestHead: requestHead, - request: .init(request) + request: .init(request, compressed: callOptions.messageEncoding.enabledForRequests) ) super.init( eventLoop: connection.eventLoop, multiplexer: connection.multiplexer, callType: .serverStreaming, + callOptions: callOptions, responseHandler: responseHandler, requestHandler: requestHandler, logger: logger diff --git a/Sources/GRPC/ClientCalls/UnaryCall.swift b/Sources/GRPC/ClientCalls/UnaryCall.swift index 5808e82d6..4d079b54d 100644 --- a/Sources/GRPC/ClientCalls/UnaryCall.swift +++ b/Sources/GRPC/ClientCalls/UnaryCall.swift @@ -61,19 +61,19 @@ public final class UnaryCall( requestHead: requestHead, - request: .init(request) + request: .init(request, compressed: callOptions.messageEncoding.enabledForRequests) ) super.init( eventLoop: connection.channel.eventLoop, multiplexer: connection.multiplexer, callType: .unary, + callOptions: callOptions, responseHandler: responseHandler, requestHandler: requestHandler, logger: logger diff --git a/Sources/GRPC/ClientConnection.swift b/Sources/GRPC/ClientConnection.swift index f4faf19b4..f4fd4b80a 100644 --- a/Sources/GRPC/ClientConnection.swift +++ b/Sources/GRPC/ClientConnection.swift @@ -476,20 +476,6 @@ extension ClientConnection { /// be `nil`. public var connectionBackoff: ConnectionBackoff? - /// The compression used for requests, and the compression algorithms to advertise as acceptable - /// for the remote peer to use for encoding responses. - /// - /// If compression is enabled for a connection it may be disabled for requests on any RPC by - /// setting `CallOptions.disableCompression` to `true`. - /// - /// Compression may also be disabled at the message-level for streaming requests (i.e. client - /// streaming and bidirectional streaming RPCs) by setting `disableCompression` to `true` in - /// `sendMessage(_:disableCompression)`, `sendMessage(_:disableCompression:promise)`, - /// `sendMessages(_:disableCompression)` or `sendMessages(_:disableCompression:promise)`. - /// - /// Note that disabling compression has no effect if compression is disabled on the connection. - public var messageEncoding: MessageEncoding - /// The HTTP protocol used for this connection. public var httpProtocol: HTTP2ToHTTP1ClientCodec.HTTPProtocol { return self.tls == nil ? .http : .https @@ -511,8 +497,7 @@ extension ClientConnection { errorDelegate: ClientErrorDelegate? = LoggingClientErrorDelegate(), connectivityStateDelegate: ConnectivityStateDelegate? = nil, tls: Configuration.TLS? = nil, - connectionBackoff: ConnectionBackoff? = ConnectionBackoff(), - messageEncoding: MessageEncoding = .none + connectionBackoff: ConnectionBackoff? = ConnectionBackoff() ) { self.target = target self.eventLoopGroup = eventLoopGroup @@ -520,7 +505,6 @@ extension ClientConnection { self.connectivityStateDelegate = connectivityStateDelegate self.tls = tls self.connectionBackoff = connectionBackoff - self.messageEncoding = messageEncoding } } } diff --git a/Sources/GRPC/ClientOptions.swift b/Sources/GRPC/ClientOptions.swift index 8958c9bb2..73a90c743 100644 --- a/Sources/GRPC/ClientOptions.swift +++ b/Sources/GRPC/ClientOptions.swift @@ -26,6 +26,18 @@ public struct CallOptions { /// The call timeout. public var timeout: GRPCTimeout + /// The compression used for requests, and the compression algorithms to advertise as acceptable + /// for the remote peer to use for encoding responses. + /// + /// Compression may also be disabled at the message-level for streaming requests (i.e. client + /// streaming and bidirectional streaming RPCs) by setting `compression` to `.disabled` in + /// `sendMessage(_:compression)`, `sendMessage(_:compression:promise)`, + /// `sendMessages(_:compression)` or `sendMessages(_:compression:promise)`. + /// + /// Note that enabling `compression` via the `sendMessage` or `sendMessages` methods only applies + /// if encoding has been specified in these options. + public var messageEncoding: MessageEncoding + /// Whether the call is cacheable. public var cacheable: Bool @@ -46,24 +58,20 @@ public struct CallOptions { /// messages associated with the call. public var requestIDHeader: String? - /// Disables request compression on this call. Ignored if compression is disabled at the - /// connection level. - public var disableCompression: Bool - public init( customMetadata: HPACKHeaders = HPACKHeaders(), timeout: GRPCTimeout = GRPCTimeout.infinite, - cacheable: Bool = false, + messageEncoding: MessageEncoding = .none, requestIDProvider: RequestIDProvider = .autogenerated, requestIDHeader: String? = nil, - disableCompression: Bool = false + cacheable: Bool = false ) { self.customMetadata = customMetadata self.timeout = timeout - self.cacheable = false + self.messageEncoding = messageEncoding self.requestIDProvider = requestIDProvider self.requestIDHeader = requestIDHeader - self.disableCompression = disableCompression + self.cacheable = false } /// How Request IDs should be provided. diff --git a/Sources/GRPC/Compression/CompressionAlgorithm.swift b/Sources/GRPC/Compression/CompressionAlgorithm.swift index 39130111d..25664108f 100644 --- a/Sources/GRPC/Compression/CompressionAlgorithm.swift +++ b/Sources/GRPC/Compression/CompressionAlgorithm.swift @@ -24,14 +24,15 @@ public struct CompressionAlgorithm: Equatable { public static let deflate = CompressionAlgorithm(.deflate) public static let gzip = CompressionAlgorithm(.gzip) - public static let all = Algorithm.allCases.map(CompressionAlgorithm.init) + // The order here is important: most compression to least. + public static let all: [CompressionAlgorithm] = [.gzip, .deflate, .identity] /// The name of the compression algorithm. public var name: String { return self.algorithm.rawValue } - internal enum Algorithm: String, CaseIterable { + internal enum Algorithm: String { case identity case deflate case gzip diff --git a/Sources/GRPC/Compression/MessageEncoding.swift b/Sources/GRPC/Compression/MessageEncoding.swift index 0e1b55fc2..8d01200ca 100644 --- a/Sources/GRPC/Compression/MessageEncoding.swift +++ b/Sources/GRPC/Compression/MessageEncoding.swift @@ -14,11 +14,39 @@ * limitations under the License. */ -extension ClientConnection.Configuration { + +/// Whether compression should be enabled for the message. +public enum Compression { + /// Enable compression. Note that this will be ignored if compression has not been enabled or is + /// not supported on the call. + case enabled + + /// Disable compression. + case disabled + + /// Defer to the call (the `CallOptions` for the client, and the context for the server) to + /// determine whether compression should be used for the message. + case deferToCallDefault +} + +extension Compression { + func isEnabled(enabledOnCall: Bool) -> Bool { + switch self { + case .enabled: + return enabledOnCall + case .disabled: + return false + case .deferToCallDefault: + return enabledOnCall + } + } +} + +extension CallOptions { public struct MessageEncoding { public init( forRequests outbound: CompressionAlgorithm?, - acceptableForResponses inbound: [CompressionAlgorithm] + acceptableForResponses inbound: [CompressionAlgorithm] = CompressionAlgorithm.all ) { self.outbound = outbound self.inbound = inbound @@ -41,11 +69,36 @@ extension ClientConnection.Configuration { forRequests: .identity, acceptableForResponses: CompressionAlgorithm.all ) + + /// Whether compression is enabled for requests. + internal var enabledForRequests: Bool { + return self.outbound != nil + } } } -extension ClientConnection.Configuration.MessageEncoding { +extension CallOptions.MessageEncoding { var acceptEncodingHeader: String { return self.inbound.map { $0.name }.joined(separator: ",") } } + +extension Server.Configuration { + public struct MessageEncoding { + /// The set of compression algorithms advertised that we will accept from clients. Note that + /// clients may send us messages compressed with algorithms not included in this list; if we + /// support it then we still accept the message. + public var enabled: [CompressionAlgorithm] + + public init(enabled: [CompressionAlgorithm]) { + self.enabled = enabled + } + + // All supported algorithms are enabled. + public static let enabled = MessageEncoding(enabled: CompressionAlgorithm.all) + + /// No compression. + public static let none = MessageEncoding(enabled: [.identity]) + } + +} diff --git a/Sources/GRPC/GRPCClientStateMachine.swift b/Sources/GRPC/GRPCClientStateMachine.swift index f830e2a28..0fa843696 100644 --- a/Sources/GRPC/GRPCClientStateMachine.swift +++ b/Sources/GRPC/GRPCClientStateMachine.swift @@ -215,10 +215,10 @@ struct GRPCClientStateMachine { /// request will be written. mutating func sendRequest( _ message: Request, - disableCompression: Bool, + compressed: Bool, allocator: ByteBufferAllocator ) -> Result { - return self.state.sendRequest(message, disableCompression: disableCompression, allocator: allocator) + return self.state.sendRequest(message, compressed: compressed, allocator: allocator) } /// Closes the request stream. @@ -361,18 +361,18 @@ extension GRPCClientStateMachine.State { /// See `GRPCClientStateMachine.sendRequest(_:allocator:)`. mutating func sendRequest( _ message: Request, - disableCompression: Bool, + compressed: Bool, allocator: ByteBufferAllocator ) -> Result { let result: Result switch self { case .clientActiveServerIdle(var writeState, let readArity): - result = writeState.write(message, disableCompression: disableCompression, allocator: allocator) + result = writeState.write(message, compressed: compressed, allocator: allocator) self = .clientActiveServerIdle(writeState: writeState, readArity: readArity) case .clientActiveServerActive(var writeState, let readState): - result = writeState.write(message, disableCompression: disableCompression, allocator: allocator) + result = writeState.write(message, compressed: compressed, allocator: allocator) self = .clientActiveServerActive(writeState: writeState, readState: readState) case .clientClosedServerIdle, @@ -507,7 +507,7 @@ extension GRPCClientStateMachine.State { path: String, timeout: GRPCTimeout, customMetadata: HPACKHeaders, - compression: ClientConnection.Configuration.MessageEncoding + compression: CallOptions.MessageEncoding ) -> HPACKHeaders { // Note: we don't currently set the 'grpc-encoding' header, if we do we will need to feed that // encoded into the message writer. diff --git a/Sources/GRPC/GRPCServerRequestRoutingHandler.swift b/Sources/GRPC/GRPCServerRequestRoutingHandler.swift index 72727d6ee..cba0c2323 100644 --- a/Sources/GRPC/GRPCServerRequestRoutingHandler.swift +++ b/Sources/GRPC/GRPCServerRequestRoutingHandler.swift @@ -44,6 +44,7 @@ public protocol CallHandlerProvider: class { public struct CallHandlerContext { internal var errorDelegate: ServerErrorDelegate? internal var logger: Logger + internal var encoding: Server.Configuration.MessageEncoding } /// Attempts to route a request to a user-provided call handler. Also validates that the request has @@ -57,6 +58,7 @@ public struct CallHandlerContext { public final class GRPCServerRequestRoutingHandler { private let logger: Logger private let servicesByName: [String: CallHandlerProvider] + private let encoding: Server.Configuration.MessageEncoding private weak var errorDelegate: ServerErrorDelegate? private enum State: Equatable { @@ -66,10 +68,17 @@ public final class GRPCServerRequestRoutingHandler { private var state: State = .notConfigured - public init(servicesByName: [String: CallHandlerProvider], errorDelegate: ServerErrorDelegate?, logger: Logger) { + public init( + servicesByName: [String: CallHandlerProvider], + encoding: Server.Configuration.MessageEncoding, + errorDelegate: ServerErrorDelegate?, + logger: Logger + ) { self.servicesByName = servicesByName + self.encoding = encoding self.errorDelegate = errorDelegate self.logger = logger + } } @@ -215,7 +224,11 @@ extension GRPCServerRequestRoutingHandler: ChannelInboundHandler, RemovableChann // Unset the channel handler: it shouldn't be used for downstream handlers. logger[metadataKey: MetadataKey.channelHandler] = nil - let context = CallHandlerContext(errorDelegate: self.errorDelegate, logger: logger) + let context = CallHandlerContext( + errorDelegate: self.errorDelegate, + logger: logger, + encoding: self.encoding + ) guard uriComponents.count >= 3 && uriComponents[0].isEmpty, let providerForServiceName = servicesByName[uriComponents[1]], diff --git a/Sources/GRPC/HTTP1ToGRPCServerCodec.swift b/Sources/GRPC/HTTP1ToGRPCServerCodec.swift index 38558fa51..ac06e2ce9 100644 --- a/Sources/GRPC/HTTP1ToGRPCServerCodec.swift +++ b/Sources/GRPC/HTTP1ToGRPCServerCodec.swift @@ -33,7 +33,7 @@ public enum _GRPCServerRequestPart { /// - Important: This is **NOT** part of the public API. public enum _GRPCServerResponsePart { case headers(HTTPHeaders) - case message(ResponsePayload) + case message(_MessageContext) case statusAndTrailers(GRPCStatus, HTTPHeaders) } @@ -45,7 +45,8 @@ public enum _GRPCServerResponsePart { /// /// The translation from HTTP2 to HTTP1 is done by `HTTP2ToHTTP1ServerCodec`. public final class HTTP1ToGRPCServerCodec { - public init(logger: Logger) { + public init(encoding: Server.Configuration.MessageEncoding, logger: Logger) { + self.encoding = encoding self.logger = logger var accessLog = Logger(subsystem: .serverAccess) @@ -57,10 +58,14 @@ public final class HTTP1ToGRPCServerCodec RequestEncodingValidation { + guard let algorithm = CompressionAlgorithm(rawValue: requestEncoding) else { + return .unsupported + } + + if self.encoding.enabled.contains(algorithm) { + return .supported(algorithm) + } else { + return .supportedButNotDisclosed(algorithm) + } + } + + /// Makes a 'grpc-accept-encoding' header from the advertised encodings and an additional value + /// if one is specified. + func makeAcceptEncodingHeader(includeExtra extra: CompressionAlgorithm? = nil) -> String? { + switch (self.encoding.enabled.isEmpty, extra) { + case (false, .some(let extra)): + return (self.encoding.enabled + CollectionOfOne(extra)).map { $0.name }.joined(separator: ",") + case (false, .none): + return self.encoding.enabled.map { $0.name }.joined(separator: ",") + case (true, .some(let extra)): + return extra.name + case (true, .none): + return nil + } + } + + /// Selects an appropriate response encoding from the list of encodings sent to us by the client. + /// Returns `nil` if there were no appropriate algorithms, in which case the server will send + /// messages uncompressed. + func selectResponseEncoding(from acceptableEncoding: [Substring]) -> CompressionAlgorithm? { + return acceptableEncoding.compactMap { + CompressionAlgorithm(rawValue: String($0)) + }.first { + self.encoding.enabled.contains($0) + } + } +} diff --git a/Sources/GRPC/LengthPrefixedMessageReader.swift b/Sources/GRPC/LengthPrefixedMessageReader.swift index 3b5691031..6a0862577 100644 --- a/Sources/GRPC/LengthPrefixedMessageReader.swift +++ b/Sources/GRPC/LengthPrefixedMessageReader.swift @@ -31,8 +31,8 @@ import Logging /// - SeeAlso: /// [gRPC Protocol](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md) internal struct LengthPrefixedMessageReader { - var compression: CompressionAlgorithm? - var decompressor: Zlib.Inflate? + let compression: CompressionAlgorithm? + private let decompressor: Zlib.Inflate? init(compression: CompressionAlgorithm? = nil) { self.compression = compression diff --git a/Sources/GRPC/LengthPrefixedMessageWriter.swift b/Sources/GRPC/LengthPrefixedMessageWriter.swift index dfb0f0417..5ac34601d 100644 --- a/Sources/GRPC/LengthPrefixedMessageWriter.swift +++ b/Sources/GRPC/LengthPrefixedMessageWriter.swift @@ -20,7 +20,7 @@ internal struct LengthPrefixedMessageWriter { static let metadataLength = 5 /// The compression algorithm to use, if one should be used. - private let compression: CompressionAlgorithm? + let compression: CompressionAlgorithm? private let compressor: Zlib.Deflate? /// Whether the compression message flag should be set. @@ -49,10 +49,10 @@ internal struct LengthPrefixedMessageWriter { /// - Returns: A `ByteBuffer` containing a gRPC length-prefixed message. /// - Precondition: `compression.supported` is `true`. /// - Note: See `LengthPrefixedMessageReader` for more details on the format. - func write(_ payload: GRPCPayload, into buffer: inout ByteBuffer, disableCompression: Bool = false) throws { + func write(_ payload: GRPCPayload, into buffer: inout ByteBuffer, compressed: Bool = true) throws { buffer.reserveCapacity(buffer.writerIndex + LengthPrefixedMessageWriter.metadataLength) - if !disableCompression, let compressor = self.compressor { + if compressed, let compressor = self.compressor { // Set the compression byte. buffer.writeInteger(UInt8(1)) @@ -72,14 +72,10 @@ internal struct LengthPrefixedMessageWriter { // Finally, the compression context should be reset between messages. compressor.reset() } else { - // 'identity' compression has no compressor but should still set the compression bit set - // unless we explicitly disable compression. - if self.compression?.algorithm == .identity && !disableCompression { - buffer.writeInteger(UInt8(1)) - } else { - buffer.writeInteger(UInt8(0)) - } - + // We could be using 'identity' compression, but since the result is the same we'll just + // say it isn't compressed. + buffer.writeInteger(UInt8(0)) + // Leave a gap for the length, we'll set it in a moment. let payloadSizeIndex = buffer.writerIndex buffer.moveWriterIndex(forwardBy: MemoryLayout.size) diff --git a/Sources/GRPC/ReadWriteStates.swift b/Sources/GRPC/ReadWriteStates.swift index fbe4438a2..88893b5cf 100644 --- a/Sources/GRPC/ReadWriteStates.swift +++ b/Sources/GRPC/ReadWriteStates.swift @@ -55,7 +55,7 @@ enum WriteState { /// written. mutating func write( _ message: GRPCPayload, - disableCompression: Bool, + compressed: Bool, allocator: ByteBufferAllocator ) -> Result { switch self { @@ -66,7 +66,7 @@ enum WriteState { // Zero is fine: the writer will allocate the correct amount of space. var buffer = allocator.buffer(capacity: 0) do { - try writer.write(message, into: &buffer, disableCompression: disableCompression) + try writer.write(message, into: &buffer, compressed: compressed) } catch { self = .notWriting return .failure(.serializationFailed) diff --git a/Sources/GRPC/Server.swift b/Sources/GRPC/Server.swift index 753e8e8ce..f41c37e66 100644 --- a/Sources/GRPC/Server.swift +++ b/Sources/GRPC/Server.swift @@ -103,6 +103,7 @@ public final class Server { let logger = Logger(subsystem: .serverChannelCall, metadata: [MetadataKey.requestID: "\(UUID())"]) let handler = GRPCServerRequestRoutingHandler( servicesByName: configuration.serviceProvidersByName, + encoding: configuration.messageEncoding, errorDelegate: configuration.errorDelegate, logger: logger ) @@ -180,6 +181,16 @@ extension Server { /// TLS configuration for this connection. `nil` if TLS is not desired. public var tls: TLS? + /// The compression configuration for requests and responses. + /// + /// If compression is enabled for the server it may be disabled for responses on any RPC by + /// setting `compressionEnabled` to `false` on the context of the call. + /// + /// Compression may also be disabled at the message-level for streaming responses (i.e. server + /// streaming and bidirectional streaming RPCs) by passing setting `compression` to `.disabled` + /// in `sendResponse(_:compression)`. + public var messageEncoding: MessageEncoding + /// Create a `Configuration` with some pre-defined defaults. /// /// - Parameter target: The target to bind to. @@ -188,18 +199,21 @@ extension Server { /// to handle requests. /// - Parameter errorDelegate: The error delegate, defaulting to a logging delegate. /// - Parameter tlsConfiguration: TLS configuration, defaulting to `nil`. + /// - Parameter messageEncoding: Message compression configuration, defaulting to no compression. public init( target: BindTarget, eventLoopGroup: EventLoopGroup, serviceProviders: [CallHandlerProvider], errorDelegate: ServerErrorDelegate? = LoggingServerErrorDelegate.shared, - tls: TLS? = nil + tls: TLS? = nil, + messageEncoding: MessageEncoding = .none ) { self.target = target self.eventLoopGroup = eventLoopGroup self.serviceProviders = serviceProviders self.errorDelegate = errorDelegate self.tls = tls + self.messageEncoding = messageEncoding } } } diff --git a/Sources/GRPC/ServerCallContexts/ServerCallContext.swift b/Sources/GRPC/ServerCallContexts/ServerCallContext.swift index b211d977b..98cdf25ff 100644 --- a/Sources/GRPC/ServerCallContexts/ServerCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/ServerCallContext.swift @@ -29,6 +29,11 @@ public protocol ServerCallContext: class { /// The logger used for this call. var logger: Logger { get } + + /// Whether compression should be enabled for responses, defaulting to `true`. Note that for + /// this value to take effect compression must have been enabled on the server and a compression + /// algorithm must have been negotiated with the client. + var compressionEnabled: Bool { get set } } /// Base class providing data provided to the framework user for all server calls. @@ -36,6 +41,7 @@ open class ServerCallContextBase: ServerCallContext { public let eventLoop: EventLoop public let request: HTTPRequestHead public let logger: Logger + public var compressionEnabled: Bool = true /// Metadata to return at the end of the RPC. If this is required it should be updated before /// the `responsePromise` or `statusPromise` is fulfilled. diff --git a/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift b/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift index 2ce720447..c47af6e0d 100644 --- a/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/StreamingResponseCallContext.swift @@ -35,7 +35,13 @@ open class StreamingResponseCallContext: ServerCal super.init(eventLoop: eventLoop, request: request, logger: logger) } - open func sendResponse(_ message: ResponsePayload) -> EventLoopFuture { + /// Send a response to the client. + /// + /// - Parameter message: The message to send to the client. + /// - Parameter compression: Whether compression should be used for this response. If compression + /// is enabled in the call context, the value passed here takes precedence. Defaults to deferring + /// to the value set on the call context. + open func sendResponse(_ message: ResponsePayload, compression: Compression = .deferToCallDefault) -> EventLoopFuture { fatalError("needs to be overridden") } } @@ -70,9 +76,10 @@ open class StreamingResponseCallContextImpl: Strea } } - open override func sendResponse(_ message: ResponsePayload) -> EventLoopFuture { + open override func sendResponse(_ message: ResponsePayload, compression: Compression = .deferToCallDefault) -> EventLoopFuture { let promise: EventLoopPromise = eventLoop.makePromise() - channel.writeAndFlush(NIOAny(WrappedResponse.message(message)), promise: promise) + let messageContext = _MessageContext(message, compressed: compression.isEnabled(enabledOnCall: self.compressionEnabled)) + self.channel.writeAndFlush(NIOAny(WrappedResponse.message(messageContext)), promise: promise) return promise.futureResult } } @@ -83,7 +90,7 @@ open class StreamingResponseCallContextImpl: Strea open class StreamingResponseCallContextTestStub: StreamingResponseCallContext { open var recordedResponses: [ResponsePayload] = [] - open override func sendResponse(_ message: ResponsePayload) -> EventLoopFuture { + open override func sendResponse(_ message: ResponsePayload, compression: Compression = .deferToCallDefault) -> EventLoopFuture { recordedResponses.append(message) return eventLoop.makeSucceededFuture(()) } diff --git a/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift b/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift index 88b618713..4fe67c903 100644 --- a/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift +++ b/Sources/GRPC/ServerCallContexts/UnaryResponseCallContext.swift @@ -70,8 +70,9 @@ open class UnaryResponseCallContextImpl: UnaryResp responsePromise.futureResult // Send the response provided to the promise. - .map { responseMessage in - self.channel.writeAndFlush(NIOAny(WrappedResponse.message(responseMessage))) + .map { responseMessage -> EventLoopFuture in + let message = _MessageContext(responseMessage, compressed: self.compressionEnabled) + return self.channel.writeAndFlush(NIOAny(WrappedResponse.message(message))) } .map { _ in self.responseStatus diff --git a/Sources/GRPC/_GRPCClientChannelHandler.swift b/Sources/GRPC/_GRPCClientChannelHandler.swift index 9f1b9f819..0b4ffc43c 100644 --- a/Sources/GRPC/_GRPCClientChannelHandler.swift +++ b/Sources/GRPC/_GRPCClientChannelHandler.swift @@ -44,7 +44,7 @@ public struct _GRPCRequestHead { var path: String var host: String var timeout: GRPCTimeout - var encoding: ClientConnection.Configuration.MessageEncoding + var encoding: CallOptions.MessageEncoding init( method: String, @@ -52,7 +52,7 @@ public struct _GRPCRequestHead { path: String, host: String, timeout: GRPCTimeout, - encoding: ClientConnection.Configuration.MessageEncoding + encoding: CallOptions.MessageEncoding ) { self.method = method self.scheme = scheme @@ -138,7 +138,7 @@ public struct _GRPCRequestHead { } } - internal var compression: ClientConnection.Configuration.MessageEncoding { + internal var compression: CallOptions.MessageEncoding { get { return self._storage.encoding } @@ -157,7 +157,7 @@ public struct _GRPCRequestHead { host: String, timeout: GRPCTimeout, customMetadata: HPACKHeaders, - encoding: ClientConnection.Configuration.MessageEncoding + encoding: CallOptions.MessageEncoding ) { self._storage = .init( method: method, @@ -385,7 +385,11 @@ extension _GRPCClientChannelHandler: ChannelInboundHandler { // Awesome: we got some messages. The state machine guarantees we only get at most a single // message for unary and client-streaming RPCs. for message in messages { - context.fireChannelRead(self.wrapInboundOut(.message(.init(message)))) + // Note: `compressed: false` is currently just a placeholder. This is fine since the message + // context is not currently exposed to the user. If we implement interceptors for the client + // and decide to surface this information then we'll need to extract that information from + // the message reader. + context.fireChannelRead(self.wrapInboundOut(.message(.init(message, compressed: false)))) } case .failure(let error): context.fireErrorCaught(error) @@ -419,7 +423,7 @@ extension _GRPCClientChannelHandler: ChannelOutboundHandler { case .message(let request): // Feed the request message into the state machine: - let result = self.stateMachine.sendRequest(request.message, disableCompression: request.disableCompression, allocator: context.channel.allocator) + let result = self.stateMachine.sendRequest(request.message, compressed: request.compressed, allocator: context.channel.allocator) switch result { case .success(let buffer): // We're clear to send a message; wrap it up in an HTTP/2 frame. diff --git a/Sources/GRPC/_MessageContext.swift b/Sources/GRPC/_MessageContext.swift index 707a6f290..2132e4c19 100644 --- a/Sources/GRPC/_MessageContext.swift +++ b/Sources/GRPC/_MessageContext.swift @@ -15,18 +15,21 @@ */ import SwiftProtobuf -/// Provides a context for Protobuf messages. +/// Provides a context for gRPC payloads. /// /// - Important: This is **NOT** part of the public API. public final class _MessageContext { + /// The message being sent or received. let message: M - let disableCompression: Bool + + /// Whether the message was, or should be compressed. + let compressed: Bool /// Constructs a box for a value. /// /// - Important: This is **NOT** part of the public API. - public init(_ message: M, disableCompression: Bool = false) { + public init(_ message: M, compressed: Bool) { self.message = message - self.disableCompression = disableCompression + self.compressed = compressed } } diff --git a/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestCases.swift b/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestCases.swift index 498e239ce..bd2d909a2 100644 --- a/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestCases.swift +++ b/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestCases.swift @@ -175,12 +175,6 @@ class LargeUnary: InteroperabilityTest { /// - Clients are free to assert that the response payload body contents are zeros and comparing the /// entire response message against a golden response. class ClientCompressedUnary: InteroperabilityTest { - func configure(defaults: ClientConnection.Configuration) -> ClientConnection.Configuration { - var configuration = defaults - configuration.messageEncoding = .init(forRequests: .gzip, acceptableForResponses: CompressionAlgorithm.all) - return configuration - } - func run(using connection: ClientConnection) throws { let client = Grpc_Testing_TestServiceServiceClient(connection: connection) @@ -194,20 +188,19 @@ class ClientCompressedUnary: InteroperabilityTest { uncompressedRequest.expectCompressed = false // For unary RPCs we disable compression at the call level. - var options = CallOptions() - options.disableCompression = true // With compression expected but *disabled*. - let probe = client.unaryCall(compressedRequest, callOptions: options) + let probe = client.unaryCall(compressedRequest) try waitAndAssertEqual(probe.status.map { $0.code }, .invalidArgument) // With compression expected and enabled. - let compressed = client.unaryCall(compressedRequest) + let options = CallOptions(messageEncoding: .init(forRequests: .gzip)) + let compressed = client.unaryCall(compressedRequest, callOptions: options) try waitAndAssertEqual(compressed.response.map { $0.payload }, .zeros(count: 314_159)) try waitAndAssertEqual(compressed.status.map { $0.code }, .ok) // With compression not expected and disabled. - let uncompressed = client.unaryCall(uncompressedRequest, callOptions: options) + let uncompressed = client.unaryCall(uncompressedRequest) try waitAndAssertEqual(uncompressed.response.map { $0.payload }, .zeros(count: 314_159)) try waitAndAssertEqual(uncompressed.status.map { $0.code }, .ok) } @@ -262,12 +255,6 @@ class ClientCompressedUnary: InteroperabilityTest { /// - clients are free to assert that the response payload body contents are zero and comparing the /// entire response message against a golden response class ServerCompressedUnary: InteroperabilityTest { - func configure(defaults: ClientConnection.Configuration) -> ClientConnection.Configuration { - var configuration = defaults - configuration.messageEncoding = .responsesOnly - return configuration - } - func run(using connection: ClientConnection) throws { let client = Grpc_Testing_TestServiceServiceClient(connection: connection) @@ -277,7 +264,8 @@ class ServerCompressedUnary: InteroperabilityTest { request.payload = .zeros(count: 271_828) } - let compressed = client.unaryCall(compressedRequest) + let options = CallOptions(messageEncoding: .responsesOnly) + let compressed = client.unaryCall(compressedRequest, callOptions: options) // We can't verify that the compression bit was set, instead we verify that the encoding header // was sent by the server. This isn't quite the same since as it can still be set but the // compression may be not set. @@ -419,23 +407,22 @@ class ClientStreaming: InteroperabilityTest { /// - Next calls succeeds. /// - Response aggregated payload size is 73086. class ClientCompressedStreaming: InteroperabilityTest { - func configure(defaults: ClientConnection.Configuration) -> ClientConnection.Configuration { - var configuration = defaults - configuration.messageEncoding = .init(forRequests: .gzip, acceptableForResponses: CompressionAlgorithm.all) - return configuration - } - func run(using connection: ClientConnection) throws { let client = Grpc_Testing_TestServiceServiceClient(connection: connection) - // Does the server support this test? + // Does the server support this test? To find out we need to send an uncompressed probe. However + // we need to disable compression at the RPC level as we don't have access to whether the + // compression byte is set on messages. As such the corresponding code in the service + // implementation checks against the 'grpc-encoding' header as a best guess. Disabling + // compression here will stop that header from being sent. let probe = client.streamingInputCall() let probeRequest: Grpc_Testing_StreamingInputCallRequest = .with { request in request.expectCompressed = true request.payload = .zeros(count: 27_182) } - probe.sendMessage(probeRequest, disableCompression: true, promise: nil) + // Compression is disabled at the RPC level. + probe.sendMessage(probeRequest, promise: nil) probe.sendEnd(promise: nil) // We *expect* invalid argument here. If not then the server doesn't support this test. @@ -450,9 +437,10 @@ class ClientCompressedStreaming: InteroperabilityTest { request.payload = .zeros(count: 45_904) } - let streaming = client.streamingInputCall() - streaming.sendMessage(probeRequest, promise: nil) - streaming.sendMessage(secondMessage, disableCompression: true, promise: nil) + let options = CallOptions(messageEncoding: .init(forRequests: .gzip)) + let streaming = client.streamingInputCall(callOptions: options) + streaming.sendMessage(probeRequest, compression: .enabled, promise: nil) + streaming.sendMessage(secondMessage, compression: .disabled, promise: nil) streaming.sendEnd(promise: nil) try waitAndAssertEqual(streaming.response.map { $0.aggregatedPayloadSize }, 73_086) @@ -553,12 +541,6 @@ class ServerStreaming: InteroperabilityTest { /// - clients are free to assert that the response payload body contents are zero and comparing the /// entire response messages against golden responses class ServerCompressedStreaming: InteroperabilityTest { - func configure(defaults: ClientConnection.Configuration) -> ClientConnection.Configuration { - var configuration = defaults - configuration.messageEncoding = .responsesOnly - return configuration - } - func run(using connection: ClientConnection) throws { let client = Grpc_Testing_TestServiceServiceClient(connection: connection) @@ -575,8 +557,9 @@ class ServerCompressedStreaming: InteroperabilityTest { ] } + let options = CallOptions(messageEncoding: .responsesOnly) var payloads: [Grpc_Testing_Payload] = [] - let rpc = client.streamingOutputCall(request) { response in + let rpc = client.streamingOutputCall(request, callOptions: options) { response in payloads.append(response.payload) } diff --git a/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestServer.swift b/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestServer.swift index 72e50bf4e..56d016bb2 100644 --- a/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestServer.swift +++ b/Sources/GRPCInteroperabilityTestsImplementation/InteroperabilityTestServer.swift @@ -40,7 +40,9 @@ public func makeInteroperabilityTestServer( var configuration = Server.Configuration( target: .hostAndPort(host, port), eventLoopGroup: eventLoopGroup, - serviceProviders: serviceProviders) + serviceProviders: serviceProviders, + messageEncoding: .init(enabled: CompressionAlgorithm.all) + ) if useTLS { print("Using the gRPC interop testing CA for TLS; clients should expect the host to be '*.test.google.fr'") diff --git a/Sources/GRPCInteroperabilityTestsImplementation/TestServiceProvider.swift b/Sources/GRPCInteroperabilityTestsImplementation/TestServiceProvider.swift index 8030377fe..1d6838b1c 100644 --- a/Sources/GRPCInteroperabilityTestsImplementation/TestServiceProvider.swift +++ b/Sources/GRPCInteroperabilityTestsImplementation/TestServiceProvider.swift @@ -32,7 +32,7 @@ public class TestServiceProvider: Grpc_Testing_TestServiceProvider { /// /// Some 'features' are methods, whilst others optionally modify the outcome of those methods. The /// specification is not explicit about where these modifying features should be implemented (i.e. - /// which methods should support them) and they are not listed in the individual metdod + /// which methods should support them) and they are not listed in the individual method /// descriptions. As such implementation of these modifying features within each method is /// determined by the features required by each test. public static var implementedFeatures: Set { @@ -42,7 +42,9 @@ public class TestServiceProvider: Grpc_Testing_TestServiceProvider { .streamingOutputCall, .streamingInputCall, .fullDuplexCall, - .echoStatus + .echoStatus, + .compressedResponse, + .compressedRequest ] } @@ -64,6 +66,21 @@ public class TestServiceProvider: Grpc_Testing_TestServiceProvider { request: Grpc_Testing_SimpleRequest, context: StatusOnlyCallContext ) -> EventLoopFuture { + // We can't validate messages at the wire-encoding layer (i.e. where the compression byte is + // set), so we have to check via the encoding header. Note that it is possible for the header + // to be set and for the message to not be compressed. + if request.expectCompressed.value && !context.request.headers.contains(name: "grpc-encoding") { + let status = GRPCStatus( + code: .invalidArgument, + message: "Expected compressed request, but 'grpc-encoding' was missing" + ) + return context.eventLoop.makeFailedFuture(status) + } + + // Should we enable compression? The C++ interoperability client only expects compression if + // explicitly requested; we'll do the same. + context.compressionEnabled = request.responseCompressed.value + if request.shouldEchoStatus { let code = GRPCStatus.Code(rawValue: numericCast(request.responseStatus.code)) ?? .unknown return context.eventLoop.makeFailedFuture(GRPCStatus(code: code, message: request.responseStatus.message)) @@ -125,7 +142,10 @@ public class TestServiceProvider: Grpc_Testing_TestServiceProvider { } } - return context.sendResponse(response) + // Should we enable compression? The C++ interoperability client only expects compression if + // explicitly requested; we'll do the same. + let compression: Compression = responseParameter.compressed.value ? .enabled : .disabled + return context.sendResponse(response, compression: compression) } } @@ -143,7 +163,15 @@ public class TestServiceProvider: Grpc_Testing_TestServiceProvider { return context.eventLoop.makeSucceededFuture({ event in switch event { case .message(let request): - aggregatePayloadSize += request.payload.body.count + if request.expectCompressed.value && !context.request.headers.contains(name: "grpc-encoding") { + context.responseStatus = GRPCStatus( + code: .invalidArgument, + message: "Expected compressed request, but 'grpc-encoding' was missing" + ) + context.responsePromise.fail(context.responseStatus) + } else { + aggregatePayloadSize += request.payload.body.count + } case .end: context.responsePromise.succeed(Grpc_Testing_StreamingInputCallResponse.with { response in diff --git a/Sources/GRPCPerformanceTests/Benchmarks/EmbeddedClientThroughput.swift b/Sources/GRPCPerformanceTests/Benchmarks/EmbeddedClientThroughput.swift index bad473ffa..58532df72 100644 --- a/Sources/GRPCPerformanceTests/Benchmarks/EmbeddedClientThroughput.swift +++ b/Sources/GRPCPerformanceTests/Benchmarks/EmbeddedClientThroughput.swift @@ -63,7 +63,7 @@ class EmbeddedClientThroughput: Benchmark { let channel = EmbeddedChannel() try channel.pipeline.addHandlers([ _GRPCClientChannelHandler(streamID: .init(1), callType: .unary, logger: self.logger), - _UnaryRequestChannelHandler(requestHead: self.requestHead, request: .init(self.request)) + _UnaryRequestChannelHandler(requestHead: self.requestHead, request: .init(self.request, compressed: false)) ]).wait() // Trigger the request handler. diff --git a/Tests/GRPCTests/CompressionTests.swift b/Tests/GRPCTests/CompressionTests.swift new file mode 100644 index 000000000..7be4f0275 --- /dev/null +++ b/Tests/GRPCTests/CompressionTests.swift @@ -0,0 +1,175 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import GRPC +import EchoImplementation +import EchoModel +import NIO +import NIOHPACK +import XCTest + +class MessageCompressionTests: GRPCTestCase { + var group: EventLoopGroup! + var server: Server! + var client: ClientConnection! + var defaultTimeout: TimeInterval = 0.1 + + var echo: Echo_EchoServiceClient! + + override func setUp() { + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + } + + override func tearDown() { + XCTAssertNoThrow(try self.client.close().wait()) + XCTAssertNoThrow(try self.server.close().wait()) + XCTAssertNoThrow(try self.group.syncShutdownGracefully()) + } + + func setupServer(encoding: Server.Configuration.MessageEncoding) throws { + let configuration = Server.Configuration( + target: .hostAndPort("localhost", 0), + eventLoopGroup: self.group, + serviceProviders: [EchoProvider()], + messageEncoding: encoding + ) + + self.server = try Server.start(configuration: configuration).wait() + } + + func setupClient(encoding: CallOptions.MessageEncoding) { + let configuration = ClientConnection.Configuration( + target: .hostAndPort("localhost", self.server.channel.localAddress!.port!), + eventLoopGroup: self.group + ) + + self.client = ClientConnection(configuration: configuration) + self.echo = Echo_EchoServiceClient( + connection: self.client, + defaultCallOptions: CallOptions(messageEncoding: encoding) + ) + } + + func doUnaryRPC() -> UnaryCall { + let get = self.echo.get(.with { $0.text = "foo" }) + return get + } + + func testCompressedRequestsUncompressedResponses() throws { + try self.setupServer(encoding: .none) + self.setupClient(encoding: .init(forRequests: .gzip, acceptableForResponses: [.deflate, .gzip])) + + let get = self.echo.get(.with { $0.text = "foo" }) + + let initialMetadata = self.expectation(description: "received initial metadata") + get.initialMetadata.map { + $0.contains(name: "grpc-encoding") + }.assertEqual(false, fulfill: initialMetadata) + + let status = self.expectation(description: "received status") + get.status.map { + $0.code + }.assertEqual(.ok, fulfill: status) + + self.wait(for: [initialMetadata, status], timeout: self.defaultTimeout) + } + + func testUncompressedRequestsCompressedResponses() throws { + try self.setupServer(encoding: .enabled) + self.setupClient(encoding: .init(forRequests: .none, acceptableForResponses: [.deflate, .gzip])) + + let get = self.echo.get(.with { $0.text = "foo" }) + + let initialMetadata = self.expectation(description: "received initial metadata") + get.initialMetadata.map { + $0.first(name: "grpc-encoding") + }.assertEqual("deflate", fulfill: initialMetadata) + + let status = self.expectation(description: "received status") + get.status.map { + $0.code + }.assertEqual(.ok, fulfill: status) + + self.wait(for: [initialMetadata, status], timeout: self.defaultTimeout) + } + + func testServerCanDecompressNonAdvertisedButSupportedCompression() throws { + // Server should be able to decompress a format it supports but does not advertise. In doing + // so it must also return a "grpc-accept-encoding" header which includes the value it did not + // advertise. + try self.setupServer(encoding: .init(enabled: [.gzip])) + self.setupClient(encoding: .init(forRequests: .deflate, acceptableForResponses: [])) + + let get = self.echo.get(.with { $0.text = "foo" }) + + let initialMetadata = self.expectation(description: "received initial metadata") + get.initialMetadata.map { + $0[canonicalForm: "grpc-accept-encoding"] + }.assertEqual(["gzip", "deflate"], fulfill: initialMetadata) + + let status = self.expectation(description: "received status") + get.status.map { + $0.code + }.assertEqual(.ok, fulfill: status) + + self.wait(for: [initialMetadata, status], timeout: self.defaultTimeout) + } + + func testServerCompressesResponseWithDifferentAlgorithmToRequest() throws { + // Server should be able to compress responses with a different method to the client, providing + // the client supports it. + try self.setupServer(encoding: .init(enabled: [.gzip])) + self.setupClient(encoding: .init(forRequests: .deflate, acceptableForResponses: [.deflate, .gzip])) + + let get = self.echo.get(.with { $0.text = "foo" }) + + let initialMetadata = self.expectation(description: "received initial metadata") + get.initialMetadata.map { + $0.first(name: "grpc-encoding") + }.assertEqual("gzip", fulfill: initialMetadata) + + let status = self.expectation(description: "received status") + get.status.map { + $0.code + }.assertEqual(.ok, fulfill: status) + + self.wait(for: [initialMetadata, status], timeout: self.defaultTimeout) + } + + func testCompressedRequestWithCompressionNotSupportedOnServer() throws { + try self.setupServer(encoding: .init(enabled: [.gzip, .deflate])) + // We can't specify a compression we don't support, so we'll specify no compression and then + // send a 'grpc-encoding' with our initial metadata. + self.setupClient(encoding: .init(forRequests: .none, acceptableForResponses: [.deflate, .gzip])) + + let headers: HPACKHeaders = ["grpc-encoding": "you-don't-support-this"] + let get = self.echo.get(.with { $0.text = "foo" }, callOptions: CallOptions(customMetadata: headers)) + + let response = self.expectation(description: "received response") + get.response.assertError(fulfill: response) + + let trailers = self.expectation(description: "received trailing metadata") + get.trailingMetadata.map { + $0[canonicalForm: "grpc-accept-encoding"] + }.assertEqual(["gzip", "deflate"], fulfill: trailers) + + let status = self.expectation(description: "received status") + get.status.map { + $0.code + }.assertEqual(.unimplemented, fulfill: status) + + self.wait(for: [response, trailers, status], timeout: self.defaultTimeout) + } +} diff --git a/Tests/GRPCTests/GRPCClientStateMachineTests.swift b/Tests/GRPCTests/GRPCClientStateMachineTests.swift index 286e85535..4d58193f7 100644 --- a/Tests/GRPCTests/GRPCClientStateMachineTests.swift +++ b/Tests/GRPCTests/GRPCClientStateMachineTests.swift @@ -125,7 +125,7 @@ extension GRPCClientStateMachineTests { extension GRPCClientStateMachineTests { func doTestSendRequestFromInvalidState(_ state: StateMachine.State, expected: MessageWriteError) { var stateMachine = self.makeStateMachine(state) - stateMachine.sendRequest(.init(text: "Hello!"), disableCompression: false, allocator: self.allocator).assertFailure { + stateMachine.sendRequest(.init(text: "Hello!"), compressed: false, allocator: self.allocator).assertFailure { XCTAssertEqual($0, expected) } } @@ -134,7 +134,7 @@ extension GRPCClientStateMachineTests { var stateMachine = self.makeStateMachine(state) let request: Request = .with { $0.text = "Hello!" } - stateMachine.sendRequest(request, disableCompression: false, allocator: self.allocator).assertSuccess() { buffer in + stateMachine.sendRequest(request, compressed: false, allocator: self.allocator).assertSuccess() { buffer in var buffer = buffer // Remove the length and compression flag prefix. buffer.moveReaderIndex(forwardBy: 5) @@ -458,7 +458,7 @@ extension GRPCClientStateMachineTests { stateMachine.receiveResponseHeaders(self.makeResponseHeaders()).assertSuccess() // Send a request. - stateMachine.sendRequest(.with { $0.text = "Hello!" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "Hello!" }, compressed: false, allocator: self.allocator).assertSuccess() // Close the request stream. stateMachine.sendEndOfRequestStream().assertSuccess() @@ -489,9 +489,9 @@ extension GRPCClientStateMachineTests { stateMachine.receiveResponseHeaders(self.makeResponseHeaders()).assertSuccess() // Send some requests. - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertSuccess() - stateMachine.sendRequest(.with { $0.text = "2" }, disableCompression: false, allocator: self.allocator).assertSuccess() - stateMachine.sendRequest(.with { $0.text = "3" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "2" }, compressed: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "3" }, compressed: false, allocator: self.allocator).assertSuccess() // Close the request stream. stateMachine.sendEndOfRequestStream().assertSuccess() @@ -522,7 +522,7 @@ extension GRPCClientStateMachineTests { stateMachine.receiveResponseHeaders(self.makeResponseHeaders()).assertSuccess() // Send a request. - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertSuccess() // Close the request stream. stateMachine.sendEndOfRequestStream().assertSuccess() @@ -558,15 +558,15 @@ extension GRPCClientStateMachineTests { stateMachine.receiveResponseHeaders(self.makeResponseHeaders()).assertSuccess() // Interleave requests and responses: - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertSuccess() // Receive a response. var firstBuffer = try self.writeMessage(Response.with { $0.text = "1" }) stateMachine.receiveResponseBuffer(&firstBuffer).assertSuccess() // Send two more requests. - stateMachine.sendRequest(.with { $0.text = "2" }, disableCompression: false, allocator: self.allocator).assertSuccess() - stateMachine.sendRequest(.with { $0.text = "3" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "2" }, compressed: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "3" }, compressed: false, allocator: self.allocator).assertSuccess() // Receive two responses in one buffer. var secondBuffer = try self.writeMessage(Response.with { $0.text = "2" }) @@ -589,9 +589,9 @@ extension GRPCClientStateMachineTests { var stateMachine = self.makeStateMachine(.clientActiveServerIdle(writeState: .one(), readArity: messageCount)) // One is fine. - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertSuccess() // Two is not. - stateMachine.sendRequest(.with { $0.text = "2" }, disableCompression: false, allocator: self.allocator).assertFailure { + stateMachine.sendRequest(.with { $0.text = "2" }, compressed: false, allocator: self.allocator).assertFailure { XCTAssertEqual($0, .cardinalityViolation) } } @@ -602,9 +602,9 @@ extension GRPCClientStateMachineTests { var stateMachine = self.makeStateMachine(.clientActiveServerActive(writeState: .one(), readState: readState)) // One is fine. - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertSuccess() + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertSuccess() // Two is not. - stateMachine.sendRequest(.with { $0.text = "2" }, disableCompression: false, allocator: self.allocator).assertFailure { + stateMachine.sendRequest(.with { $0.text = "2" }, compressed: false, allocator: self.allocator).assertFailure { XCTAssertEqual($0, .cardinalityViolation) } } @@ -614,7 +614,7 @@ extension GRPCClientStateMachineTests { var stateMachine = self.makeStateMachine(.clientClosedServerClosed) // No requests allowed! - stateMachine.sendRequest(.with { $0.text = "1" }, disableCompression: false, allocator: self.allocator).assertFailure { + stateMachine.sendRequest(.with { $0.text = "1" }, compressed: false, allocator: self.allocator).assertFailure { XCTAssertEqual($0, .cardinalityViolation) } } diff --git a/Tests/GRPCTests/GRPCInteroperabilityTests.swift b/Tests/GRPCTests/GRPCInteroperabilityTests.swift index f37acfd01..629a6189e 100644 --- a/Tests/GRPCTests/GRPCInteroperabilityTests.swift +++ b/Tests/GRPCTests/GRPCInteroperabilityTests.swift @@ -28,6 +28,7 @@ class GRPCInsecureInteroperabilityTests: GRPCTestCase { var clientEventLoopGroup: EventLoopGroup! var clientConnection: ClientConnection! + var clientDefaults: ClientConnection.Configuration! override func setUp() { super.setUp() @@ -46,7 +47,7 @@ class GRPCInsecureInteroperabilityTests: GRPCTestCase { } self.clientEventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) - self.clientConnection = makeInteroperabilityTestClientConnection( + self.clientDefaults = makeInteroperabilityTestClientConfiguration( host: "localhost", port: serverPort, eventLoopGroup: self.clientEventLoopGroup, @@ -55,8 +56,9 @@ class GRPCInsecureInteroperabilityTests: GRPCTestCase { } override func tearDown() { - XCTAssertNoThrow(try self.clientConnection.close().wait()) + XCTAssertNoThrow(try self.clientConnection?.close().wait()) XCTAssertNoThrow(try self.clientEventLoopGroup.syncShutdownGracefully()) + self.clientDefaults = nil self.clientConnection = nil self.clientEventLoopGroup = nil @@ -78,6 +80,8 @@ class GRPCInsecureInteroperabilityTests: GRPCTestCase { } let test = testCase.makeTest() + let configuration = test.configure(defaults: self.clientDefaults) + self.clientConnection = ClientConnection(configuration: configuration) XCTAssertNoThrow(try test.run(using: self.clientConnection), file: file, line: line) } @@ -93,14 +97,30 @@ class GRPCInsecureInteroperabilityTests: GRPCTestCase { self.doRunTest(.largeUnary) } + func testClientCompressedUnary() { + self.doRunTest(.clientCompressedUnary) + } + + func testServerCompressedUnary() { + self.doRunTest(.serverCompressedUnary) + } + func testClientStreaming() { self.doRunTest(.clientStreaming) } + func testClientCompressedStreaming() { + self.doRunTest(.clientCompressedStreaming) + } + func testServerStreaming() { self.doRunTest(.serverStreaming) } + func testServerCompressedStreaming() { + self.doRunTest(.serverCompressedStreaming) + } + func testPingPong() { self.doRunTest(.pingPong) } diff --git a/Tests/GRPCTests/GRPCServerRequestRoutingHandlerTests.swift b/Tests/GRPCTests/GRPCServerRequestRoutingHandlerTests.swift index bdfed41f4..8863062b8 100644 --- a/Tests/GRPCTests/GRPCServerRequestRoutingHandlerTests.swift +++ b/Tests/GRPCTests/GRPCServerRequestRoutingHandlerTests.swift @@ -32,6 +32,7 @@ class GRPCServerRequestRoutingHandlerTests: GRPCTestCase { let provider = EchoProvider() let handler = GRPCServerRequestRoutingHandler( servicesByName: [provider.serviceName: provider], + encoding: .none, errorDelegate: nil, logger: logger ) diff --git a/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift b/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift index e8512deeb..06a9bb22b 100644 --- a/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift +++ b/Tests/GRPCTests/HTTP1ToGRPCServerCodecTests.swift @@ -29,7 +29,7 @@ class HTTP1ToGRPCServerCodecTests: GRPCTestCase { override func setUp() { super.setUp() let logger = Logger(label: "io.grpc.testing") - let handler = HTTP1ToGRPCServerCodec(logger: logger) + let handler = HTTP1ToGRPCServerCodec(encoding: .none, logger: logger) self.channel = EmbeddedChannel(handler: handler) } diff --git a/Tests/GRPCTests/LengthPrefixedMessageReaderTests.swift b/Tests/GRPCTests/LengthPrefixedMessageReaderTests.swift index f78b61544..379ae8885 100644 --- a/Tests/GRPCTests/LengthPrefixedMessageReaderTests.swift +++ b/Tests/GRPCTests/LengthPrefixedMessageReaderTests.swift @@ -222,10 +222,10 @@ class LengthPrefixedMessageReaderTests: GRPCTestCase { func testNextMessageDoesNotThrowWhenCompressionFlagIsExpectedButNotSet() throws { // `.identity` should always be supported and requires a flag. - reader.compression = .identity + self.reader = LengthPrefixedMessageReader(compression: .identity) var buffer = byteBuffer(withBytes: lengthPrefixedTwoByteMessage()) - reader.append(buffer: &buffer) + self.reader.append(buffer: &buffer) self.assertMessagesEqual(expected: twoByteMessage, actual: try reader.nextMessage()) } diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index 055c25681..5fa002f75 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -324,12 +324,16 @@ extension GRPCInsecureInteroperabilityTests { ("testCacheableUnary", testCacheableUnary), ("testCancelAfterBegin", testCancelAfterBegin), ("testCancelAfterFirstResponse", testCancelAfterFirstResponse), + ("testClientCompressedStreaming", testClientCompressedStreaming), + ("testClientCompressedUnary", testClientCompressedUnary), ("testClientStreaming", testClientStreaming), ("testCustomMetadata", testCustomMetadata), ("testEmptyStream", testEmptyStream), ("testEmptyUnary", testEmptyUnary), ("testLargeUnary", testLargeUnary), ("testPingPong", testPingPong), + ("testServerCompressedStreaming", testServerCompressedStreaming), + ("testServerCompressedUnary", testServerCompressedUnary), ("testServerStreaming", testServerStreaming), ("testSpecialStatusAndMessage", testSpecialStatusAndMessage), ("testStatusCodeAndMessage", testStatusCodeAndMessage), @@ -347,12 +351,16 @@ extension GRPCSecureInteroperabilityTests { ("testCacheableUnary", testCacheableUnary), ("testCancelAfterBegin", testCancelAfterBegin), ("testCancelAfterFirstResponse", testCancelAfterFirstResponse), + ("testClientCompressedStreaming", testClientCompressedStreaming), + ("testClientCompressedUnary", testClientCompressedUnary), ("testClientStreaming", testClientStreaming), ("testCustomMetadata", testCustomMetadata), ("testEmptyStream", testEmptyStream), ("testEmptyUnary", testEmptyUnary), ("testLargeUnary", testLargeUnary), ("testPingPong", testPingPong), + ("testServerCompressedStreaming", testServerCompressedStreaming), + ("testServerCompressedUnary", testServerCompressedUnary), ("testServerStreaming", testServerStreaming), ("testSpecialStatusAndMessage", testSpecialStatusAndMessage), ("testStatusCodeAndMessage", testStatusCodeAndMessage), @@ -482,6 +490,19 @@ extension LengthPrefixedMessageReaderTests { ] } +extension MessageCompressionTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__MessageCompressionTests = [ + ("testCompressedRequestsUncompressedResponses", testCompressedRequestsUncompressedResponses), + ("testCompressedRequestWithCompressionNotSupportedOnServer", testCompressedRequestWithCompressionNotSupportedOnServer), + ("testServerCanDecompressNonAdvertisedButSupportedCompression", testServerCanDecompressNonAdvertisedButSupportedCompression), + ("testServerCompressesResponseWithDifferentAlgorithmToRequest", testServerCompressesResponseWithDifferentAlgorithmToRequest), + ("testUncompressedRequestsCompressedResponses", testUncompressedRequestsCompressedResponses), + ] +} + extension PlatformSupportTests { // DO NOT MODIFY: This is autogenerated, use: // `swift test --generate-linuxmain` @@ -636,6 +657,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(HTTP1ToGRPCServerCodecTests.__allTests__HTTP1ToGRPCServerCodecTests), testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests), testCase(LengthPrefixedMessageReaderTests.__allTests__LengthPrefixedMessageReaderTests), + testCase(MessageCompressionTests.__allTests__MessageCompressionTests), testCase(PlatformSupportTests.__allTests__PlatformSupportTests), testCase(ReadStateTests.__allTests__ReadStateTests), testCase(ServerDelayedThrowingTests.__allTests__ServerDelayedThrowingTests), diff --git a/Tests/GRPCTests/ZlibTests.swift b/Tests/GRPCTests/ZlibTests.swift index f9a401539..ee53b809c 100644 --- a/Tests/GRPCTests/ZlibTests.swift +++ b/Tests/GRPCTests/ZlibTests.swift @@ -80,7 +80,6 @@ class ZlibTests: GRPCTestCase { for format in [Zlib.CompressionFormat.deflate, .gzip] { // Is the compressed size larger than the input size? let compressedSize = try self.doCompressAndDecompress(of: bytes, format: format) - print(compressedSize) XCTAssertGreaterThan(compressedSize, bytes.count) }