diff --git a/src/httpz.zig b/src/httpz.zig index a801ec5..10ce7c5 100644 --- a/src/httpz.zig +++ b/src/httpz.zig @@ -240,6 +240,10 @@ pub fn Server(comptime H: type) type { const WebsocketHandler = if (Handler != void and comptime @hasDecl(Handler, "WebsocketHandler")) Handler.WebsocketHandler else DummyWebsocketHandler; + const RouterConfig = struct { + middlewares: []const Middleware(H) = &.{}, + }; + return struct { const TP = if (blockingMode()) ThreadPool(worker.Blocking(*Self, WebsocketHandler).handleConnection) else ThreadPool(worker.NonBlocking(*Self, WebsocketHandler).processData); @@ -253,8 +257,9 @@ pub fn Server(comptime H: type) type { _thread_pool: *TP, _signals: []posix.fd_t, _max_request_per_connection: u64, + _middlewares: []const Middleware(H), _websocket_state: websocket.server.WorkerState, - _middlewares: std.SinglyLinkedList(Middleware(H)), + _middleware_registry: std.SinglyLinkedList(Middleware(H)), const Self = @This(); @@ -305,8 +310,9 @@ pub fn Server(comptime H: type) type { .arena = arena.allocator(), ._mut = .{}, ._cond = .{}, - ._middlewares = .{}, + ._middleware_registry = .{}, ._signals = signals, + ._middlewares = &.{}, ._thread_pool = thread_pool, ._websocket_state = websocket_state, ._router = try Router(H, ActionArg).init(arena.allocator(), default_dispatcher, handler), @@ -315,11 +321,10 @@ pub fn Server(comptime H: type) type { } pub fn deinit(self: *Self) void { - self._thread_pool.stop(); self._websocket_state.deinit(); - var node = self._middlewares.first; + var node = self._middleware_registry.first; while (node) |n| { n.data.deinit(); node = n.next; @@ -467,11 +472,14 @@ pub fn Server(comptime H: type) type { } } - pub fn dispatcher(self: *Self, d: Dispatcher(H, ActionArg)) void { - (&self._router).dispatcher(d); - } + pub fn router(self: *Self, config: RouterConfig) *Router(H, ActionArg) { + // we store this in self for us when no route is found (these will + // still be executed). + self._middlewares = config.middlewares; + + // we store this in router to append to add/append to created routes + self._router.middlewares = config.middlewares; - pub fn router(self: *Self) *Router(H, ActionArg) { return &self._router; } @@ -523,7 +531,25 @@ pub fn Server(comptime H: type) type { self.handler.handle(&req, &res); } else { const dispatchable_action = self._router.route(req.method, req.url.path, &req.params); - self.dispatch(dispatchable_action, &req, &res) catch |err| { + + var executor = Executor{ + .index = 0, + .req = &req, + .res = &res, + .handler = self.handler, + .middlewares = undefined, + .dispatchable_action = dispatchable_action, + }; + + if (dispatchable_action) |da| { + req.route_data = da.data; + executor.middlewares = da.middlewares; + } else { + req.route_data = null; + executor.middlewares = self._middlewares; + } + + executor.next() catch |err| { if (comptime std.meta.hasFn(Handler, "uncaughtError")) { self.handler.uncaughtError(&req, &res, err); } else { @@ -539,27 +565,6 @@ pub fn Server(comptime H: type) type { }; } - fn dispatch(self: *const Self, dispatchable_action: ?DispatchableAction(H, ActionArg), req: *Request, res: *Response) !void { - const da = dispatchable_action orelse { - if (comptime std.meta.hasFn(Handler, "notFound")) { - return self.handler.notFound(req, res); - } - res.status = 404; - res.body = "Not Found"; - return; - }; - - req.route_data = da.data; - var executor = Executor{ - .da = da, - .index = 0, - .req = req, - .res = res, - .middlewares = da.middlewares, - }; - return executor.next(); - } - pub fn middleware(self: *Self, comptime M: type, config: M.Config) !Middleware(H) { const arena = self.arena; @@ -579,7 +584,7 @@ pub fn Server(comptime H: type) type { const iface = Middleware(H).init(m); node.data = iface; - self._middlewares.prepend(node); + self._middleware_registry.prepend(node); return iface; } @@ -588,24 +593,35 @@ pub fn Server(comptime H: type) type { index: usize, req: *Request, res: *Response, + handler: H, // pull this out of da since we'll access it a lot (not really, but w/e) middlewares: []const Middleware(H), - da: DispatchableAction(H, ActionArg), + dispatchable_action: ?DispatchableAction(H, ActionArg), pub fn next(self: *Executor) !void { const index = self.index; const middlewares = self.middlewares; - if (index == middlewares.len) { - const da = self.da; + if (index < middlewares.len) { + self.index = index + 1; + return middlewares[index].execute(self.req, self.res, self); + } + + // done executing our middlewares, now we either execute the + // dispatcher or not found. + if (self.dispatchable_action) |da| { if (comptime H == void) { return da.dispatcher(da.action, self.req, self.res); } return da.dispatcher(da.handler, da.action, self.req, self.res); } - self.index = index + 1; - return middlewares[index].execute(self.req, self.res, self); + if (comptime std.meta.hasFn(Handler, "notFound")) { + return self.handler.notFound(self.req, self.res); + } + self.res.status = 404; + self.res.body = "Not Found"; + return; } }; }; @@ -753,7 +769,7 @@ test "tests:beforeAll" { middlewares[0] = try default_server.middleware(TestMiddleware, .{.id = 100}); middlewares[1] = cors[0]; - var router = default_server.router(); + var router = default_server.router(.{}); // router.get("/test/ws", testWS); router.get("/fail", TestDummyHandler.fail, .{}); router.get("/test/json", TestDummyHandler.jsonRes, .{}); @@ -769,7 +785,7 @@ test "tests:beforeAll" { { dispatch_default_server = try Server(*TestHandlerDefaultDispatch).init(ga, .{ .port = 5993 }, &test_handler_default_dispatch1); - var router = dispatch_default_server.router(); + var router = dispatch_default_server.router(.{}); router.get("/", TestHandlerDefaultDispatch.echo, .{}); router.get("/write/*", TestHandlerDefaultDispatch.echoWrite, .{}); router.get("/fail", TestHandlerDefaultDispatch.fail, .{}); @@ -791,14 +807,14 @@ test "tests:beforeAll" { { dispatch_server = try Server(*TestHandlerDispatch).init(ga, .{ .port = 5994 }, &test_handler_dispatch); - var router = dispatch_server.router(); + var router = dispatch_server.router(.{}); router.get("/", TestHandlerDispatch.root, .{}); test_server_threads[2] = try dispatch_server.listenInNewThread(); } { dispatch_action_context_server = try Server(*TestHandlerDispatchContext).init(ga, .{ .port = 5995 }, &test_handler_disaptch_context); - var router = dispatch_action_context_server.router(); + var router = dispatch_action_context_server.router(.{}); router.get("/", TestHandlerDispatchContext.root, .{}); test_server_threads[3] = try dispatch_action_context_server.listenInNewThread(); } @@ -807,7 +823,7 @@ test "tests:beforeAll" { // with only 1 worker, and a min/max conn of 1, each request should // hit our reset path. reuse_server = try Server(void).init(ga, .{ .port = 5996, .workers = .{ .count = 1, .min_conn = 1, .max_conn = 1 } }, {}); - var router = reuse_server.router(); + var router = reuse_server.router(.{}); router.get("/test/writer", TestDummyHandler.reuseWriter, .{}); test_server_threads[4] = try reuse_server.listenInNewThread(); } @@ -819,7 +835,7 @@ test "tests:beforeAll" { { websocket_server = try Server(TestWebsocketHandler).init(ga, .{ .port = 5998 }, TestWebsocketHandler{}); - var router = websocket_server.router(); + var router = websocket_server.router(.{}); router.get("/ws", TestWebsocketHandler.upgrade, .{}); test_server_threads[6] = try websocket_server.listenInNewThread(); } diff --git a/src/router.zig b/src/router.zig index c2ac0e0..cb7eeb8 100644 --- a/src/router.zig +++ b/src/router.zig @@ -8,22 +8,28 @@ const Response = httpz.Response; const Allocator = std.mem.Allocator; const StringHashMap = std.StringHashMap; -pub fn Config(comptime Handler: type, comptime Action: type) type { +fn RouteConfig(comptime Handler: type, comptime Action: type) type { const Dispatcher = httpz.Dispatcher(Handler, Action); return struct { data: ?*const anyopaque = null, handler: ?Handler = null, dispatcher: ?Dispatcher = null, middlewares: ?[]const httpz.Middleware(Handler) = null, + middleware_strategy: ?MiddlewareStrategy = null, }; } +pub const MiddlewareStrategy = enum { + append, + replace, +}; + pub fn Router(comptime Handler: type, comptime Action: type) type { const Dispatcher = httpz.Dispatcher(Handler, Action); const DispatchableAction = httpz.DispatchableAction(Handler, Action); - const C = Config(Handler, Action); const P = Part(DispatchableAction); + const RC = RouteConfig(Handler, Action); return struct { _get: P, @@ -37,7 +43,7 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { _allocator: Allocator, handler: Handler, dispatcher: Dispatcher, - middlewares: ?[]const httpz.Middleware(Handler), + middlewares: []const httpz.Middleware(Handler), const Self = @This(); @@ -47,7 +53,7 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { .handler = handler, ._allocator = allocator, .dispatcher = dispatcher, - .middlewares = null, + .middlewares = &.{}, ._get = try P.init(allocator), ._head = try P.init(allocator), ._post = try P.init(allocator), @@ -56,11 +62,10 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { ._trace = try P.init(allocator), ._delete = try P.init(allocator), ._options = try P.init(allocator), - }; } - pub fn group(self: *Self, prefix: []const u8, config: C) Group(Handler, Action) { + pub fn group(self: *Self, prefix: []const u8, config: RC) Group(Handler, Action) { return Group(Handler, Action).init(self, prefix, config); } @@ -76,66 +81,66 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { }; } - pub fn get(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn get(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryGet(path, action, config) catch @panic("failed to create route"); } - pub fn tryGet(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryGet(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._get, path, action, config); } - pub fn put(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn put(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPut(path, action, config) catch @panic("failed to create route"); } - pub fn tryPut(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPut(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._put, path, action, config); } - pub fn post(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn post(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPost(path, action, config) catch @panic("failed to create route"); } - pub fn tryPost(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPost(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._post, path, action, config); } - pub fn head(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn head(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryHead(path, action, config) catch @panic("failed to create route"); } - pub fn tryHead(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryHead(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._head, path, action, config); } - pub fn patch(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn patch(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryPatch(path, action, config) catch @panic("failed to create route"); } - pub fn tryPatch(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryPatch(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._patch, path, action, config); } - pub fn trace(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn trace(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryTrace(path, action, config) catch @panic("failed to create route"); } - pub fn tryTrace(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryTrace(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._trace, path, action, config); } - pub fn delete(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn delete(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryDelete(path, action, config) catch @panic("failed to create route"); } - pub fn tryDelete(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryDelete(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._delete, path, action, config); } - pub fn options(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn options(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryOptions(path, action, config) catch @panic("failed to create route"); } - pub fn tryOptions(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryOptions(self: *Self, path: []const u8, action: Action, config: RC) !void { return self.addRoute(&self._options, path, action, config); } - pub fn all(self: *Self, path: []const u8, action: Action, config: C) void { + pub fn all(self: *Self, path: []const u8, action: Action, config: RC) void { self.tryAll(path, action, config) catch @panic("failed to create route"); } - pub fn tryAll(self: *Self, path: []const u8, action: Action, config: C) !void { + pub fn tryAll(self: *Self, path: []const u8, action: Action, config: RC) !void { try self.tryGet(path, action, config); try self.tryPut(path, action, config); try self.tryPost(path, action, config); @@ -146,13 +151,13 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { try self.tryOptions(path, action, config); } - fn addRoute(self: *Self, root: *P, path: []const u8, action: Action, config: C) !void { + fn addRoute(self: *Self, root: *P, path: []const u8, action: Action, config: RC) !void { const da = DispatchableAction{ .action = action, .data = config.data, .handler = config.handler orelse self.handler, .dispatcher = config.dispatcher orelse self.dispatcher, - .middlewares = config.middlewares orelse self.middlewares orelse &.{}, + .middlewares = try self.mergeMiddleware(self.middlewares, config), }; if (path.len == 0 or (path.len == 1 and path[0] == '/')) { @@ -229,20 +234,35 @@ pub fn Router(comptime Handler: type, comptime Action: type) type { route_part.action = da; } + + fn mergeMiddleware(self: *Self, parent_middlewares: []const httpz.Middleware(Handler), config: RC) ![]const httpz.Middleware(Handler) { + const route_middlewares = config.middlewares orelse return parent_middlewares; + + const strategy = config.middleware_strategy orelse .append; + if (strategy == .replace or parent_middlewares.len == 0) { + return route_middlewares; + } + + // allocator is an arena + const merged = try self._allocator.alloc(httpz.Middleware(Handler), route_middlewares.len + parent_middlewares.len); + @memcpy(merged[0..parent_middlewares.len], parent_middlewares); + @memcpy(merged[parent_middlewares.len..], route_middlewares); + return merged; + } }; } pub fn Group(comptime Handler: type, comptime Action: type) type { - const C = Config(Handler, Action); + const RC = RouteConfig(Handler, Action); return struct { _allocator: Allocator, _prefix: []const u8, _router: *Router(Handler, Action), - _config: Config(Handler, Action), + _config: RouteConfig(Handler, Action), const Self = @This(); - fn init(router: *Router(Handler, Action), prefix: []const u8, config: C) Self { + fn init(router: *Router(Handler, Action), prefix: []const u8, config: RC) Self { return .{ ._prefix = prefix, ._router = router, @@ -251,82 +271,73 @@ pub fn Group(comptime Handler: type, comptime Action: type) type { }; } - pub fn get(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn get(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.get(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryGet(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryGet(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryGet(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryGet(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn put(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn put(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.put(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPut(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPut(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPut(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPut(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn post(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn post(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.post(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPost(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPost(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPost(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPost(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn head(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn head(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.head(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryHead(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryHead(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryHead(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryHead(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn patch(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn patch(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.patch(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryPatch(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryPatch(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryPatch(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryPatch(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn trace(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn trace(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.trace(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryTrace(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryTrace(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryTrace(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryTrace(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn delete(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn delete(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.delete(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryDelete(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryDelete(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryDelete(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryDelete(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn options(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn options(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.options(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryOptions(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryOptions(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryOptions(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryOptions(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } - pub fn all(self: *Self, path: []const u8, action: Action, override: C) void { + pub fn all(self: *Self, path: []const u8, action: Action, override: RC) void { self._router.all(self.createPath(path), action, self.mergeConfig(override)); } - pub fn tryAll(self: *Self, path: []const u8, action: Action, override: C) !void { - return self._router.tryAll(self.tryCreatePath(path), action, self.mergeConfig(override)); + pub fn tryAll(self: *Self, path: []const u8, action: Action, override: RC) !void { + return self._router.tryAll(self.tryCreatePath(path), action, self.tryMergeConfig(override)); } fn createPath(self: *Self, path: []const u8) []const u8 { return self.tryCreatePath(path) catch unreachable; } - fn mergeConfig(self: *const Self, override: C) C { - return .{ - .data = override.data orelse self._config.data, - .handler = override.handler orelse self._config.handler, - .dispatcher = override.dispatcher orelse self._config.dispatcher, - .middlewares = override.middlewares orelse self._config.middlewares, - }; - } - fn tryCreatePath(self: *Self, path: []const u8) ![]const u8 { var prefix = self._prefix; if (prefix.len == 0) { @@ -348,6 +359,20 @@ pub fn Group(comptime Handler: type, comptime Action: type) type { @memcpy(joined[prefix.len..], path); return joined; } + + fn mergeConfig(self: *Self, override: RC) RC { + return self.tryMergeConfig(override) catch unreachable; + } + + fn tryMergeConfig(self: *Self, override: RC) !RC { + return .{ + .data = override.data orelse self._config.data, + .handler = override.handler orelse self._config.handler, + .dispatcher = override.dispatcher orelse self._config.dispatcher, + .middlewares = try self._router.mergeMiddleware(self._config.middlewares orelse &.{}, override), + .middleware_strategy = override.middleware_strategy orelse self._config.middleware_strategy, + }; + } }; } @@ -624,6 +649,150 @@ test "route: glob" { } } +test "route: middlewares no global" { + defer t.reset(); + + const m1 = fakeMiddleware(&.{.id = 1}); + const m2 = fakeMiddleware(&.{.id = 2}); + const m3 = fakeMiddleware(&.{.id = 3}); + + var params = try Params.init(t.arena.allocator(), 5); + var router = Router(void, httpz.Action(void)).init(t.arena.allocator(), testDispatcher1, {}) catch unreachable; + + { + router.get("/1", testRoute1, .{}); + router.get("/2", testRoute1, .{.middlewares = &.{m1}}); + router.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + router.get("/4", testRoute1, .{.middlewares = &.{m1, m2}}); + router.get("/5", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/1", &.{}); + try assertMiddlewares(&router, ¶ms, "/2", &.{1}); + try assertMiddlewares(&router, ¶ms, "/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/4", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/5", &.{1, 2}); + } + + { + // group with no group-level middleware + var group = router.group("/g1", .{}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m1}}); + group.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g1/1", &.{}); + try assertMiddlewares(&router, ¶ms, "/g1/2", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g1/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g1/4", &.{1, 2}); + } + + { + // group with group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } +} + +test "route: middlewares with global" { + defer t.reset(); + + const m1 = fakeMiddleware(&.{.id = 1}); + const m2 = fakeMiddleware(&.{.id = 2}); + const m3 = fakeMiddleware(&.{.id = 3}); + const m4 = fakeMiddleware(&.{.id = 4}); + const m5 = fakeMiddleware(&.{.id = 5}); + + var params = try Params.init(t.arena.allocator(), 5); + var router = Router(void, httpz.Action(void)).init(t.arena.allocator(), testDispatcher1, {}) catch unreachable; + router.middlewares = &.{m4, m5}; + + { + router.get("/1", testRoute1, .{}); + router.get("/2", testRoute1, .{.middlewares = &.{m1}}); + router.get("/3", testRoute1, .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + router.get("/4", testRoute1, .{.middlewares = &.{m1, m2}}); + router.get("/5", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/1", &.{4, 5}); + try assertMiddlewares(&router, ¶ms, "/2", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/3", &.{1}); + try assertMiddlewares(&router, ¶ms, "/4", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/5", &.{1, 2}); + } + + { + // group with no group-level middleware + var group = router.group("/g1", .{}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m1}}); + group.get("/3", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m1, m2}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g1/1", &.{4, 5}); + try assertMiddlewares(&router, ¶ms, "/g1/2", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/g1/3", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/g1/4", &.{1, 2}); + } + + { + // group with appended group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}, .middleware_strategy = .append}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{4, 5, 1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{4, 5, 1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{4, 5, 1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } + + { + // group with replace group-level middleware + var group = router.group("/g2", .{.middlewares = &.{m1}, .middleware_strategy = .replace}); + group.get("/1", testRoute1, .{}); + group.get("/2", testRoute1, .{.middlewares = &.{m2}}); + group.get("/3", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .append}); + group.get("/4", testRoute1, .{.middlewares = &.{m2, m3}, .middleware_strategy = .replace}); + + try assertMiddlewares(&router, ¶ms, "/g2/1", &.{1}); + try assertMiddlewares(&router, ¶ms, "/g2/2", &.{1, 2}); + try assertMiddlewares(&router, ¶ms, "/g2/3", &.{4, 5, 1, 2, 3}); + try assertMiddlewares(&router, ¶ms, "/g2/4", &.{2, 3}); + } +} + +fn assertMiddlewares(router: anytype, params: *Params, path: []const u8, expected: []const u32) !void { + const middlewares = router.route(httpz.Method.GET, path, params).?.middlewares; + try t.expectEqual(expected.len, middlewares.len); + for (expected, middlewares) |e, m| { + const impl: *const FakeMiddlewareImpl = @ptrCast(@alignCast(m.ptr)); + try t.expectEqual(e, impl.id); + } +} + +fn fakeMiddleware(impl: *const FakeMiddlewareImpl) httpz.Middleware(void) { + return .{ + .ptr = @constCast(impl), + .deinitFn = undefined, + .executeFn = undefined, + }; +} + +const FakeMiddlewareImpl = struct { + id: u32, +}; + // TODO: this functionality isn't implemented because I can't think of a way // to do it which isn't relatively expensive (e.g. recursively or keeping a // stack of (a) parts and (b) url segments and trying to rematch every possible