Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce request cardinality for unary-request calls also for the case of zero request messages being sent. #392

Merged
merged 7 commits into from
Mar 6, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions Sources/SwiftGRPCNIO/CallHandlers/BaseCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
/// Called when the client has half-closed the stream, indicating that they won't send any further data.
///
/// Overridden by subclasses if the "end-of-stream" event is relevant.
public func endOfStreamReceived() { }
public func endOfStreamReceived() throws { }

/// Whether this handler can still write messages to the client.
private var serverCanWrite = true
Expand All @@ -30,6 +30,12 @@ public class BaseCallHandler<RequestMessage: Message, ResponseMessage: Message>:
public init(errorDelegate: ServerErrorDelegate?) {
self.errorDelegate = errorDelegate
}

/// Sends an error status to the client while ensuring that all call context promises are fulfilled.
/// Because only the concrete call subclass knows which promises need to be fulfilled, this method needs to be overridden.
func sendErrorStatus(_ status: GRPCStatus) {
fatalError("needs to be overridden")
}
}

extension BaseCallHandler: ChannelInboundHandler {
Expand All @@ -40,10 +46,10 @@ extension BaseCallHandler: ChannelInboundHandler {
/// return a status with code `.internalError`.
public func errorCaught(ctx: ChannelHandlerContext, error: Error) {
errorDelegate?.observe(error)

MrMage marked this conversation as resolved.
Show resolved Hide resolved
let transformed = errorDelegate?.transform(error) ?? error
let status = (transformed as? GRPCStatusTransformable)?.asGRPCStatus() ?? GRPCStatus.processingError
self.write(ctx: ctx, data: NIOAny(GRPCServerResponsePart<ResponseMessage>.status(status)), promise: nil)
sendErrorStatus(status)
}

public func channelRead(ctx: ChannelHandlerContext, data: NIOAny) {
Expand All @@ -60,7 +66,11 @@ extension BaseCallHandler: ChannelInboundHandler {
}

case .end:
endOfStreamReceived()
do {
try endOfStreamReceived()
} catch {
self.errorCaught(ctx: ctx, error: error)
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,13 @@ public class BidirectionalStreamingCallHandler<RequestMessage: Message, Response
}
}

public override func endOfStreamReceived() {
public override func endOfStreamReceived() throws {
eventObserver?.whenSuccess { observer in
observer(.end)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.statusPromise.fail(error: status)
MrMage marked this conversation as resolved.
Show resolved Hide resolved
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ public class ClientStreamingCallHandler<RequestMessage: Message, ResponseMessage
}
}

public override func endOfStreamReceived() {
public override func endOfStreamReceived() throws {
eventObserver?.whenSuccess { observer in
observer(.end)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.responsePromise.fail(error: status)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
throw GRPCError.server(.requestCardinalityViolation)
throw GRPCError.server(.requestCardinalityViolationTooManyRequests)
}

let resultFuture = eventObserver(message)
Expand All @@ -37,4 +37,14 @@ public class ServerStreamingCallHandler<RequestMessage: Message, ResponseMessage
.cascade(promise: context.statusPromise)
self.eventObserver = nil
}

public override func endOfStreamReceived() throws {
if self.eventObserver != nil {
throw GRPCError.server(.requestCardinalityViolationTooFewRequests)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.statusPromise.fail(error: status)
}
}
12 changes: 11 additions & 1 deletion Sources/SwiftGRPCNIO/CallHandlers/UnaryCallHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
public override func processMessage(_ message: RequestMessage) throws {
guard let eventObserver = self.eventObserver,
let context = self.context else {
throw GRPCError.server(.requestCardinalityViolation)
throw GRPCError.server(.requestCardinalityViolationTooManyRequests)
}

let resultFuture = eventObserver(message)
Expand All @@ -38,4 +38,14 @@ public class UnaryCallHandler<RequestMessage: Message, ResponseMessage: Message>
.cascade(promise: context.responsePromise)
self.eventObserver = nil
}

public override func endOfStreamReceived() throws {
if self.eventObserver != nil {
throw GRPCError.server(.requestCardinalityViolationTooFewRequests)
}
}

override func sendErrorStatus(_ status: GRPCStatus) {
context?.responsePromise.fail(error: status)
}
}
10 changes: 8 additions & 2 deletions Sources/SwiftGRPCNIO/GRPCError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ public enum GRPCServerError: Error, Equatable {
/// It was not possible to serialize the response protobuf.
case responseProtoSerializationFailure

/// Zero requests were sent for a unary-request call.
rebello95 marked this conversation as resolved.
Show resolved Hide resolved
case requestCardinalityViolationTooFewRequests

/// More than one request was sent for a unary-request call.
case requestCardinalityViolation
case requestCardinalityViolationTooManyRequests

/// The server received a message when it was not in a writable state.
case serverNotWritable
Expand Down Expand Up @@ -143,7 +146,10 @@ extension GRPCServerError: GRPCStatusTransformable {
case .responseProtoSerializationFailure:
return GRPCStatus(code: .internalError, message: "could not serialize response proto")

case .requestCardinalityViolation:
case .requestCardinalityViolationTooFewRequests:
return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent none")

case .requestCardinalityViolationTooManyRequests:
return GRPCStatus(code: .unimplemented, message: "request cardinality violation; method requires exactly one request but client sent more")

case .serverNotWritable:
Expand Down
36 changes: 30 additions & 6 deletions Tests/SwiftGRPCNIOTests/NIOServerWebTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,30 @@ class NIOServerWebTests: NIOBasicEchoTestCase {
data.insert(UInt8(0), at: 0)
return data
}

private func gRPCWebOKTrailers() -> Data {
var data = "grpc-status: 0\r\ngrpc-message: OK".data(using: .utf8)!
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there's still trailing whitespace here

private func gRPCWebTrailers(status: Int, message: String) -> Data {
var data = "grpc-status: \(status)\r\ngrpc-message: \(message)".data(using: .utf8)!
// Add the gRPC prefix with the compression byte and the 4 length bytes.
for i in 0..<4 {
data.insert(UInt8((data.count >> (i * 8)) & 0xFF), at: 0)
}
data.insert(UInt8(0x80), at: 0)
return data
}

private func gRPCWebOKTrailers() -> Data {
return gRPCWebTrailers(status: 0, message: "OK")
MrMage marked this conversation as resolved.
Show resolved Hide resolved
}

private func sendOverHTTP1(rpcMethod: String, message: String, handler: @escaping (Data?, Error?) -> Void) {
private func sendOverHTTP1(rpcMethod: String, message: String?, handler: @escaping (Data?, Error?) -> Void) {
let serverURL = URL(string: "http://localhost:5050/echo.Echo/\(rpcMethod)")!
var request = URLRequest(url: serverURL)
request.httpMethod = "POST"
request.setValue("application/grpc-web-text", forHTTPHeaderField: "content-type")

request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
if let message = message {
request.httpBody = gRPCEncodedEchoRequest(message).base64EncodedData()
}

let sem = DispatchSemaphore(value: 0)
URLSession.shared.dataTask(with: request) { (data, response, error) in
Expand All @@ -85,7 +91,25 @@ extension NIOServerWebTests {
completionHandlerExpectation.fulfill()
}
}


waitForExpectations(timeout: defaultTestTimeout)
}

func testUnaryWithoutRequestMessage() {
MrMage marked this conversation as resolved.
Show resolved Hide resolved
let expectedData = gRPCWebTrailers(
status: 12, message: "request cardinality violation; method requires exactly one request but client sent none")
let expectedResponse = expectedData.base64EncodedString()

let completionHandlerExpectation = expectation(description: "completion handler called")

sendOverHTTP1(rpcMethod: "Get", message: nil) { data, error in
XCTAssertNil(error)
if let data = data {
XCTAssertEqual(String(data: data, encoding: .utf8), expectedResponse)
completionHandlerExpectation.fulfill()
}
MrMage marked this conversation as resolved.
Show resolved Hide resolved
}

waitForExpectations(timeout: defaultTestTimeout)
}

Expand Down