From a5efb4e28e345e228b4378ecc3e063d32d859b56 Mon Sep 17 00:00:00 2001 From: Karl Seguin Date: Sun, 1 Sep 2024 11:36:06 +0800 Subject: [PATCH] Rework middleware 1 - Global middleware can now be defined. This allows middleware to execute even when no route is found. This can be useful for things like CORS which should likely answer an OPTIONS request even on a not found. 2 - When adding middleware to a route, you can now pick between an "append" or "replace" strategy. This controls how/if the middleware for a route are appended or replace the middleware for a group and then up the chain to the global. https://github.com/karlseguin/http.zig/issues/67 --- src/httpz.zig | 100 ++++++++++------- src/router.zig | 299 ++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 292 insertions(+), 107 deletions(-) diff --git a/src/httpz.zig b/src/httpz.zig index a801ec5..10ce7c5 100644 --- a/src/httpz.zig +++ b/src/httpz.zig @@ -240,6 +240,10 @@ pub fn Server(comptime H: type) type { const WebsocketHandler = if (Handler != void and comptime @hasDecl(Handler, "WebsocketHandler")) Handler.WebsocketHandler else DummyWebsocketHandler; + const RouterConfig = struct { + middlewares: []const Middleware(H) = &.{}, + }; + return struct { const TP = if (blockingMode()) ThreadPool(worker.Blocking(*Self, WebsocketHandler).handleConnection) else ThreadPool(worker.NonBlocking(*Self, WebsocketHandler).processData); @@ -253,8 +257,9 @@ pub fn Server(comptime H: type) type { _thread_pool: *TP, _signals: []posix.fd_t, _max_request_per_connection: u64, + _middlewares: []const Middleware(H), _websocket_state: websocket.server.WorkerState, - _middlewares: std.SinglyLinkedList(Middleware(H)), + _middleware_registry: std.SinglyLinkedList(Middleware(H)), const Self = @This(); @@ -305,8 +310,9 @@ pub fn Server(comptime H: type) type { .arena = arena.allocator(), ._mut = .{}, ._cond = .{}, - ._middlewares = .{}, + ._middleware_registry = .{}, ._signals = signals, + ._middlewares = &.{}, ._thread_pool = thread_pool, ._websocket_state = websocket_state, ._router = try Router(H, ActionArg).init(arena.allocator(), default_dispatcher, handler), @@ -315,11 +321,10 @@ pub fn Server(comptime H: type) type { } pub fn deinit(self: *Self) void { - self._thread_pool.stop(); self._websocket_state.deinit(); - var node = self._middlewares.first; + var node = self._middleware_registry.first; while (node) |n| { n.data.deinit(); node = n.next; @@ -467,11 +472,14 @@ pub fn Server(comptime H: type) type { } } - pub fn dispatcher(self: *Self, d: Dispatcher(H, ActionArg)) void { - (&self._router).dispatcher(d); - } + pub fn router(self: *Self, config: RouterConfig) *Router(H, ActionArg) { + // we store this in self for us when no route is found (these will + // still be executed). + self._middlewares = config.middlewares; + + // we store this in router to append to add/append to created routes + self._router.middlewares = config.middlewares; - pub fn router(self: *Self) *Router(H, ActionArg) { return &self._router; } @@ -523,7 +531,25 @@ pub fn Server(comptime H: type) type { self.handler.handle(&req, &res); } else { const dispatchable_action = self._router.route(req.method, req.url.path, &req.params); - self.dispatch(dispatchable_action, &req, &res) catch |err| { + + var executor = Executor{ + .index = 0, + .req = &req, + .res = &res, + .handler = self.handler, + .middlewares = undefined, + .dispatchable_action = dispatchable_action, + }; + + if (dispatchable_action) |da| { + req.route_data = da.data; + executor.middlewares = da.middlewares; + } else { + req.route_data = null; + executor.middlewares = self._middlewares; + } + + executor.next() catch |err| { if (comptime std.meta.hasFn(Handler, "uncaughtError")) { self.handler.uncaughtError(&req, &res, err); } else { @@ -539,27 +565,6 @@ pub fn Server(comptime H: type) type { }; } - fn dispatch(self: *const Self, dispatchable_action: ?DispatchableAction(H, ActionArg), req: *Request, res: *Response) !void { - const da = dispatchable_action orelse { - if (comptime std.meta.hasFn(Handler, "notFound")) { - return self.handler.notFound(req, res); - } - res.status = 404; - res.body = "Not Found"; - return; - }; - - req.route_data = da.data; - var executor = Executor{ - .da = da, - .index = 0, - .req = req, - .res = res, - .middlewares = da.middlewares, - }; - return executor.next(); - } - pub fn middleware(self: *Self, comptime M: type, config: M.Config) !Middleware(H) { const arena = self.arena; @@ -579,7 +584,7 @@ pub fn Server(comptime H: type) type { const iface = Middleware(H).init(m); node.data = iface; - self._middlewares.prepend(node); + self._middleware_registry.prepend(node); return iface; } @@ -588,24 +593,35 @@ pub fn Server(comptime H: type) type { index: usize, req: *Request, res: *Response, + handler: H, // pull this out of da since we'll access it a lot (not really, but w/e) middlewares: []const Middleware(H), - da: DispatchableAction(H, ActionArg), + dispatchable_action: ?DispatchableAction(H, ActionArg), pub fn next(self: *Executor) !void { const index = self.index; const middlewares = self.middlewares; - if (index == middlewares.len) { - const da = self.da; + if (index < middlewares.len) { + self.index = index + 1; + return middlewares[index].execute(self.req, self.res, self); + } + + // done executing our middlewares, now we either execute the + // dispatcher or not found. + if (self.dispatchable_action) |da| { if (comptime H == void) { return da.dispatcher(da.action, self.req, self.res); } return da.dispatcher(da.handler, da.action, self.req, self.res); } - self.index = index + 1; - return middlewares[index].execute(self.req, self.res, self); + if (comptime std.meta.hasFn(Handler, "notFound")) { + return self.handler.notFound(self.req, self.res); + } + self.res.status = 404; + self.res.body = "Not Found"; + return; } }; }; @@ -753,7 +769,7 @@ test "tests:beforeAll" { middlewares[0] = try default_server.middleware(TestMiddleware, .{.id = 100}); middlewares[1] = cors[0]; - var router = default_server.router(); + var router = default_server.router(.{}); // router.get("/test/ws", testWS); router.get("/fail", TestDummyHandler.fail, .{}); router.get("/test/json", TestDummyHandler.jsonRes, .{}); @@ -769,7 +785,7 @@ test "tests:beforeAll" { { dispatch_default_server = try Server(*TestHandlerDefaultDispatch).init(ga, .{ .port = 5993 }, &test_handler_default_dispatch1); - var router = dispatch_default_server.router(); + var router = dispatch_default_server.router(.{}); router.get("/", TestHandlerDefaultDispatch.echo, .{}); router.get("/write/*", TestHandlerDefaultDispatch.echoWrite, .{}); router.get("/fail", TestHandlerDefaultDispatch.fail, .{}); @@ -791,14 +807,14 @@ test "tests:beforeAll" { { dispatch_server = try Server(*TestHandlerDispatch).init(ga, .{ .port = 5994 }, &test_handler_dispatch); - var router = dispatch_server.router(); + var router = dispatch_server.router(.{}); router.get("/", TestHandlerDispatch.root, .{}); test_server_threads[2] = try dispatch_server.listenInNewThread(); } { dispatch_action_context_server = try Server(*TestHandlerDispatchContext).init(ga, .{ .port = 5995 }, &test_handler_disaptch_context); - var router = dispatch_action_context_server.router(); + var router = dispatch_action_context_server.router(.{}); router.get("/", TestHandlerDispatchContext.root, .{}); test_server_threads[3] = try dispatch_action_context_server.listenInNewThread(); } @@ -807,7 +823,7 @@ test "tests:beforeAll" { // with only 1 worker, and a min/max conn of 1, each request should // hit our reset path. reuse_server = try Server(void).init(ga, .{ .port = 5996, .workers = .{ .count = 1, .min_conn = 1, .max_conn = 1 } }, {}); - var router = reuse_server.router(); + var router = reuse_server.router(.{}); router.get("/test/writer", TestDummyHandler.reuseWriter, .{}); test_server_threads[4] = try reuse_server.listenInNewThread(); } @@ -819,7 +835,7 @@ test "tests:beforeAll" { { websocket_server = try Server(TestWebsocketHandler).init(ga, .{ .port = 5998 }, TestWebsocketHandler{}); - var router = websocket_server.router(); + var router = websocket_server.router(.{}); router.get("/ws", TestWebsocketHandler.upgrade, .{}); test_server_threads[6] = try websocket_server.listenInNewThread(); } diff --git a/src/router.zig b/src/router.zig index c2ac0e0..cb7eeb8 100644 --- a/src/router.zig +++ b/src/router.zig @@ -8,22 +8,28 @@ const Response = httpz.Response; const Allocator = std.mem.Allocator; const StringHashMap = std.StringHashMap; -pub fn Config(comptime Handler: type, comptime Action: type) type { +fn RouteConfig(comptime Handler: type, comptime Action: type) type { const Dispatcher = httpz.Dispatcher(Handler, Action); return struct { data: ?*const anyopaque = null, handler: ?Handler = null, dispatcher: ?Dispatcher = null, middlewares: ?[]const httpz.Middleware(Handler) = null, + middleware_strategy: ?MiddlewareStrategy = null, }; } +pub const MiddlewareStrategy = enum { + append, + replace, +}; + pub fn Router(comptime Handler: type, comptime Action: type) type { const Dispatcher = httpz.Dispatcher(Handler, Action); const DispatchableAction = httpz.DispatchableAction(Handler, Action); - const C = Config(Handler, Action); const P = Part(DispatchableAction); + const RC = RouteConfig(Handler, Action); return struct { _get: P, @@ -37,7 +43,7 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { _allocator: Allocator, handler: Handler, dispatcher: Dispatcher, - middlewares: ?[]const httpz.Middleware(Handler), + middlewares: []const httpz.Middleware(Handler), const Self = @This(); @@ -47,7 +53,7 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { .handler = handler, ._allocator = allocator, .dispatcher = dispatcher, - .middlewares = null, + .middlewares = &.{}, ._get = try P.init(allocator), ._head = try P.init(allocator), ._post = try P.init(allocator), @@ -56,11 +62,10 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { ._trace = try P.init(allocator), ._delete = try P.init(allocator), ._options = try P.init(allocator), - }; } - pub fn group(self: *Self, prefix: []const u8, config: C) Group(Handler, Action) { + pub fn group(self: *Self, prefix: []const u8, config: RC) Group(Handler, Action) { return Group(Handler, Action).init(self, prefix, config); } @@ -76,66 +81,66 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { }; } - pub fn get(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn get(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryGet(path, action, config) catch @panic("failed to create route"); } - pub fn tryGet(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryGet(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._get, path, action, config); } - pub fn put(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn put(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPut(path, action, config) catch @panic("failed to create route"); } - pub fn tryPut(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPut(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._put, path, action, config); } - pub fn post(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn post(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPost(path, action, config) catch @panic("failed to create route"); } - pub fn tryPost(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPost(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._post, path, action, config); } - pub fn head(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn head(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryHead(path, action, config) catch @panic("failed to create route"); } - pub fn tryHead(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryHead(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._head, path, action, config); } - pub fn patch(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn patch(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPatch(path, action, config) catch @panic("failed to create route"); } - pub fn tryPatch(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPatch(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._patch, path, action, config); } - pub fn trace(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn trace(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryTrace(path, action, config) catch @panic("failed to create route"); } - pub fn tryTrace(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryTrace(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._trace, path, action, config); } - pub fn delete(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn delete(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryDelete(path, action, config) catch @panic("failed to create route"); } - pub fn tryDelete(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryDelete(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._delete, path, action, config); } - pub fn options(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn options(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryOptions(path, action, config) catch @panic("failed to create route"); } - pub fn tryOptions(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryOptions(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._options, path, action, config); } - pub fn all(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn all(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryAll(path, action, config) catch @panic("failed to create route"); } - pub fn tryAll(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryAll(self: *Self, path: []const u8, action: Action, config: RC) !void { try self.tryGet(path, action, config); try self.tryPut(path, action, config); try self.tryPost(path, action, config); @@ -146,13 +151,13 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { try self.tryOptions(path, action, config); } - fn addRoute(self: *Self, root: *P, path: []const u8, action: Action, config: C) !void { + fn addRoute(self: *Self, root: *P, path: []const u8, action: Action, config: RC) !void { const da = DispatchableAction{ .action = action, .data = config.data, .handler = config.handler orelse self.handler, .dispatcher = config.dispatcher orelse self.dispatcher, - .middlewares = config.middlewares orelse self.middlewares orelse &.{}, + .middlewares = try self.mergeMiddleware(self.middlewares, config), }; if (path.len == 0 or (path.len == 1 and path[0] == '/')) { @@ -229,20 +234,35 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { route_part.action = da; } + + fn mergeMiddleware(self: *Self, parent_middlewares: []const httpz.Middleware(Handler), config: RC) ![]const httpz.Middleware(Handler) { + const route_middlewares = config.middlewares orelse return parent_middlewares; + + const strategy = config.middleware_strategy orelse .append; + if (strategy == .replace or parent_middlewares.len == 0) { + return route_middlewares; + } + + // allocator is an arena + const merged = try self._allocator.alloc(httpz.Middleware(Handler), route_middlewares.len + parent_middlewares.len); + @memcpy(merged[0..parent_middlewares.len], parent_middlewares); + @memcpy(merged[parent_middlewares.len..], route_middlewares); + return merged; + } }; } pub fn Group(comptime Handler: type, comptime Action: type) type { - const C = Config(Handler, Action); + const RC = RouteConfig(Handler, Action); return struct { _allocator: Allocator, _prefix: []const u8, _router: *Router(Handler, Action), - _config: Config(Handler, Action), + _config: RouteConfig(Handler, Action), const Self = @This(); - fn init(router: *Router(Handler, Action), prefix: []const u8, config: C) Self { + fn init(router: *Router(Handler, Action), prefix: []const u8, config: RC) Self { return .{ ._prefix = prefix, ._router = router, @@ -251,82 +271,73 @@ pub fn Group(comptime Handler: type, comptime Action: type) type { }; } - pub fn get(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn get(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.get(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryGet(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryGet(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryGet(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryGet(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn put(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn put(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.put(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPut(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPut(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPut(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPut(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn post(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn post(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.post(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPost(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPost(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPost(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPost(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn head(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn head(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.head(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryHead(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryHead(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryHead(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryHead(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn patch(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn patch(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.patch(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPatch(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPatch(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPatch(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPatch(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn trace(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn trace(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.trace(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryTrace(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryTrace(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryTrace(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryTrace(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn delete(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn delete(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.delete(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryDelete(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryDelete(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryDelete(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryDelete(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn options(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn options(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.options(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryOptions(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryOptions(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryOptions(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryOptions(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn all(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn all(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.all(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryAll(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryAll(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryAll(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryAll(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } fn createPath(self: *Self, path: []const u8) []const u8 { return self.tryCreatePath(path) catch unreachable; } - fn mergeConfig(self: *const Self, override: C) C { - return .{ - .data = override.data orelse self._config.data, - .handler = override.handler orelse self._config.handler, - .dispatcher = override.dispatcher orelse self._config.dispatcher, - .middlewares = override.middlewares orelse self._config.middlewares, - }; - } - fn tryCreatePath(self: *Self, path: []const u8) ![]const u8 { var prefix = self._prefix; if (prefix.len == 0) { @@ -348,6 +359,20 @@ pub fn Group(comptime Handler: type, comptime Action: type) type { @memcpy(joined[prefix.len..], path); return joined; } + + fn mergeConfig(self: *Self, override: RC) RC { + return self.tryMergeConfig(override) catch unreachable; + } + + fn tryMergeConfig(self: *Self, override: RC) !RC { + return .{ + .data = override.data orelse self._config.data, + .handler = override.handler orelse self._config.handler, + .dispatcher = override.dispatcher orelse self._config.dispatcher, + .middlewares = try self._router.mergeMiddleware(self._config.middlewares orelse &.{}, override), + .middleware_strategy = override.middleware_strategy orelse self._config.middleware_strategy, + }; + } }; } @@ -624,6 +649,150 @@ test "route: glob" { } } +test "route: middlewares no global" { + defer t.reset(); + + const m1 = fakeMiddleware(&.{.id = 1}); + const m2 = fakeMiddleware(&.{.id = 2}); + const m3 = fakeMiddleware(&.{.id = 3}); + + var params = try Params.init(t.arena.allocator(), 5); + var router = Router(void, httpz.Action(void)).init(t.arena.allocator(), testDispatcher1, {}) catch unreachable; + + { + router.get("/1", testRoute1, .{}); + router.get("/2", testRoute1, .{.middlewares = &.{m1}}); + router.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + router.get("/4", testRoute1, .{.middlewares = &.{m1, m2}}); + router.get("/5", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/1", &.{}); + try assertMiddlewares(&router, ¶ms, "/2", &.{1}); + try assertMiddlewares(&router, ¶ms, "/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/4", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/5", &.{1, 2}); + } + + { + // group with no group-level middleware + var group = router.group("/g1", .{}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m1}}); + group.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g1/1", &.{}); + try assertMiddlewares(&router, ¶ms, "/g1/2", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g1/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g1/4", &.{1, 2}); + } + + { + // group with group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } +} + +test "route: middlewares with global" { + defer t.reset(); + + const m1 = fakeMiddleware(&.{.id = 1}); + const m2 = fakeMiddleware(&.{.id = 2}); + const m3 = fakeMiddleware(&.{.id = 3}); + const m4 = fakeMiddleware(&.{.id = 4}); + const m5 = fakeMiddleware(&.{.id = 5}); + + var params = try Params.init(t.arena.allocator(), 5); + var router = Router(void, httpz.Action(void)).init(t.arena.allocator(), testDispatcher1, {}) catch unreachable; + router.middlewares = &.{m4, m5}; + + { + router.get("/1", testRoute1, .{}); + router.get("/2", testRoute1, .{.middlewares = &.{m1}}); + router.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + router.get("/4", testRoute1, .{.middlewares = &.{m1, m2}}); + router.get("/5", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/1", &.{4, 5}); + try assertMiddlewares(&router, ¶ms, "/2", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/4", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/5", &.{1, 2}); + } + + { + // group with no group-level middleware + var group = router.group("/g1", .{}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m1}}); + group.get("/3", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g1/1", &.{4, 5}); + try assertMiddlewares(&router, ¶ms, "/g1/2", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/g1/3", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/g1/4", &.{1, 2}); + } + + { + // group with appended group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}, .middleware_strategy = .append}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{4, 5, 1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } + + { + // group with replace group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{4, 5, 1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } +} + +fn assertMiddlewares(router: anytype, params: *Params, path: []const u8, expected: []const u32) !void { + const middlewares = router.route(httpz.Method.GET, path, params).?.middlewares; + try t.expectEqual(expected.len, middlewares.len); + for (expected, middlewares) |e, m| { + const impl: *const FakeMiddlewareImpl = @ptrCast(@alignCast(m.ptr)); + try t.expectEqual(e, impl.id); + } +} + +fn fakeMiddleware(impl: *const FakeMiddlewareImpl) httpz.Middleware(void) { + return .{ + .ptr = @constCast(impl), + .deinitFn = undefined, + .executeFn = undefined, + }; +} + +const FakeMiddlewareImpl = struct { + id: u32, +}; + // TODO: this functionality isn't implemented because I can't think of a way // to do it which isn't relatively expensive (e.g. recursively or keeping a // stack of (a) parts and (b) url segments and trying to rematch every possible