Skip to content

Commit

Permalink
Create ErrorMiddleware to handle errors or return internal server error
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhrobkhakimov committed Aug 12, 2024
1 parent 0ac58be commit b2a4632
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 39 deletions.
8 changes: 4 additions & 4 deletions Sources/HTTP/Middleware/CORSMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ public struct CORSMiddleware: Middleware {

public func handle(
request: Request,
nextHandler: @escaping (Request) async -> Response
) async -> Response {
guard request.headers.get(.origin) != nil else { return await nextHandler(request) }
var response = request.isPreflight ? Response(status: .noContent) : await nextHandler(request)
responder: @escaping Responder
) async throws -> Response {
guard request.headers.get(.origin) != nil else { return try await responder(request) }
var response = request.isPreflight ? Response(status: .noContent) : try await responder(request)
setAllowCredentialsHeader(response: &response)
setAllowHeadersHeader(request: request, response: &response)
setAllowMethodsHeader(response: &response)
Expand Down
6 changes: 3 additions & 3 deletions Sources/HTTP/Middleware/HTTPMethodOverrideMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ public struct HTTPMethodOverrideMiddleware: Middleware {

public func handle(
request: Request,
nextHandler: @escaping (Request) async -> Response
) async -> Response {
responder: @escaping Responder
) async throws -> Response {
var request = request

if let methodName: String = request.getParameter("_method"),
Expand All @@ -16,6 +16,6 @@ public struct HTTPMethodOverrideMiddleware: Middleware {
request.method = method
}

return await nextHandler(request)
return try await responder(request)
}
}
34 changes: 32 additions & 2 deletions Sources/HTTP/Middleware/Middleware.swift
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
public typealias Responder = (Request) async throws -> Response
public typealias ErrorResponder = (Request, Error) async throws -> Response

public protocol Middleware {
func handle(
request: Request,
nextHandler: @escaping (Request) async -> Response
) async -> Response
responder: @escaping Responder
) async throws -> Response
}

public extension Middleware {
func handle(
request: Request,
responder: @escaping Responder
) async throws -> Response {
try await responder(request)
}
}

public protocol ErrorMiddleware {
func handle(
request: Request,
error: Error,
responder: @escaping ErrorResponder
) async throws -> Response
}

public extension ErrorMiddleware {
func handle(
request: Request,
error: Error,
responder: @escaping ErrorResponder
) async throws -> Response {
try await responder(request, error)
}
}
156 changes: 127 additions & 29 deletions Sources/HTTP/Server/Handler/RequestResponseHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,91 @@ final class RequestResponseHandler: ChannelInboundHandler {
response.body = .init()
}

prepareAndWrite(response: response, for: request, in: context)
prepareAndWrite(
response: response,
for: request,
in: context
)
}
}

private func prepareAndWrite(response: Response, for request: Request, in context: ChannelHandlerContext) {
let future = handle(request: request, response: response, middleware: server.middleware)
future.whenSuccess { [self] request, response in
write(response: response, for: request, in: context)
extension RequestResponseHandler {
private func prepareAndWrite(
response: Response,
for request: Request,
in context: ChannelHandlerContext
) {
let future = processMiddleware(
server.middleware,
request: request,
response: response
)
future.whenSuccess { [weak self] request, response in
self?.write(
response: response,
for: request,
in: context
)
}
future.whenFailure { [weak self] error in
guard let self else { return }
let future = processMiddleware(
server.errorMiddleware,
request: request,
response: response,
error: error
)
future.whenSuccess { [weak self] request, response in
self?.write(
response: response,
for: request,
in: context
)
}
future.whenFailure { [weak self] error in
self?.server.logger.error("Server error: \(error)")
self?.write(
response: .init(status: .internalServerError),
for: request,
in: context
)
}
}
}

private func handle(request: Request, response: Response) async -> Response {
private func write(
response: Response,
for request: Request,
in context: ChannelHandlerContext
) {
if request.version.major >= Version.Major.two.rawValue {
context.write(
wrapOutboundOut(response),
promise: nil
)
} else {
let future = context.write(wrapOutboundOut(response))
future.whenComplete { _ in
if response.headers.has("close") {
context.close(
mode: .output,
promise: nil
)
}
}
}
}
}

extension RequestResponseHandler {
private func handle(
request: Request,
response: Response
) async throws -> Response {
var response = response

if let onReceive = server.onReceive {
let result = await onReceive(request)
let result = try await onReceive(request)

if let result = result as? Response {
response = result
Expand All @@ -64,10 +134,10 @@ final class RequestResponseHandler: ChannelInboundHandler {
return response
}

private func handle(
private func processMiddleware(
_ middleware: [Middleware],
request: Request,
response: Response,
middleware: [Middleware],
nextIndex index: Int = 0
) -> EventLoopFuture<(Request, Response)> {
let promise = request.eventLoop.makePromise(of: (Request, Response).self)
Expand All @@ -76,27 +146,29 @@ final class RequestResponseHandler: ChannelInboundHandler {
let lastIndex = middleware.count - 1

if index > lastIndex {
let response = await handle(request: request, response: response)
let response = try await handle(
request: request,
response: response
)
return (request, response)
}

let response = await middleware[index].handle(request: request) { [weak self] request in
let response = try await middleware[index].handle(request: request) { [weak self] request in
guard let self else { return response }

if index == lastIndex {
return await handle(request: request, response: response)
}

do {
return try await handle(
request: request,
response: response,
middleware: middleware,
nextIndex: index + 1
).get().1
} catch {
return response
response: response
)
}

return try await processMiddleware(
middleware,
request: request,
response: response,
nextIndex: index + 1
).get().1
}

return (request, response)
Expand All @@ -105,17 +177,43 @@ final class RequestResponseHandler: ChannelInboundHandler {
return promise.futureResult
}

private func write(response: Response, for request: Request, in context: ChannelHandlerContext) {
if request.version.major >= Version.Major.two.rawValue {
context.write(wrapOutboundOut(response), promise: nil)
} else {
let future = context.write(wrapOutboundOut(response))
private func processMiddleware(
_ middleware: [ErrorMiddleware],
request: Request,
response: Response,
error: Error,
nextIndex index: Int = 0
) -> EventLoopFuture<(Request, Response)> {
let promise = request.eventLoop.makePromise(of: (Request, Response).self)
promise.completeWithTask {
let lastIndex = middleware.count - 1

if index > lastIndex {
throw error
}

if response.headers.has("close") {
future.whenComplete { _ in
context.close(mode: .output, promise: nil)
let response = try await middleware[index].handle(
request: request,
error: error
) { [weak self] request, error in
if index == lastIndex {
throw error
}

guard let self else { return response }

return try await processMiddleware(
middleware,
request: request,
response: response,
error: error,
nextIndex: index + 1
).get().1
}

return (request, response)
}

return promise.futureResult
}
}
3 changes: 2 additions & 1 deletion Sources/HTTP/Server/Server.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ public final class Server {
public var onStart: ((EventLoop) -> Void)?
public var onStop: (() -> Void)?
public var onError: ((Error, EventLoop) -> Void)?
public var onReceive: ((Request) async -> Encodable)?
public var onReceive: ((Request) async throws -> Encodable)?
public var middleware: [Middleware] = .init()
public var errorMiddleware: [ErrorMiddleware] = .init()

public init(configuration: Configuration = .init()) {
self.configuration = configuration
Expand Down

0 comments on commit b2a4632

Please sign in to comment.