diff --git a/kong/router.lua b/kong/router.lua index b6332847b16..c214cfe228d 100644 --- a/kong/router.lua +++ b/kong/router.lua @@ -23,6 +23,7 @@ local ipairs = ipairs local pairs = pairs local error = error local type = type +local max = math.max local band = bit.band local bor = bit.bor @@ -71,6 +72,10 @@ local MATCH_RULES = { DST = 0x00000001, } +local MATCH_SUBRULES = { + HAS_REGEX_URI = 0x01, + PLAIN_HOSTS_ONLY = 0x02, +} local EMPTY_T = {} @@ -168,6 +173,8 @@ local function marshall_route(r) preserve_host = route.preserve_host == true, match_rules = 0x00, match_weight = 0, + match_weight2 = 0, + max_uri_length = 0, hosts = {}, uris = {}, methods = {}, @@ -204,6 +211,8 @@ local function marshall_route(r) route_t.match_rules = bor(route_t.match_rules, MATCH_RULES.HOST) route_t.match_weight = route_t.match_weight + 1 + local has_wildcard + for _, host_value in ipairs(host_values) do if find(host_value, "*", nil, true) then -- wildcard host matching @@ -215,6 +224,8 @@ local function marshall_route(r) regex = wildcard_host_regex, }) + has_wildcard = true + else insert(route_t.hosts, { value = host_value, @@ -223,6 +234,11 @@ local function marshall_route(r) route_t.hosts[host_value] = host_value end + + if not has_wildcard then + route_t.match_weight2 = bor(route_t.match_weight2, + MATCH_SUBRULES.PLAIN_HOSTS_ONLY) + end end end @@ -250,6 +266,7 @@ local function marshall_route(r) route_t.uris[path] = uri_t insert(route_t.uris, uri_t) + route_t.max_uri_length = max(route_t.max_uri_length, #path) else -- regex URI @@ -266,6 +283,9 @@ local function marshall_route(r) route_t.uris[path] = uri_t insert(route_t.uris, uri_t) + + route_t.match_weight2 = bor(route_t.match_weight2, + MATCH_SUBRULES.HAS_REGEX_URI) end end end @@ -782,7 +802,9 @@ do end, [MATCH_RULES.URI] = function(category, ctx) - return category.routes_by_uris[ctx.hits.uri or ctx.req_uri] + -- no ctx.req_uri indexing since regex URIs have a higher priority than + -- plain URIs + return category.routes_by_uris[ctx.hits.uri] end, [MATCH_RULES.METHOD] = function(category, ctx) @@ -913,16 +935,55 @@ function _M.new(routes) -- index routes + do + local marshalled_routes = {} - for i = 1, #routes do - local route_t, err = marshall_route(routes[i]) - if not route_t then - return nil, err + for i = 1, #routes do + local route_t, err = marshall_route(routes[i]) + if not route_t then + return nil, err + end + + marshalled_routes[i] = route_t end - categorize_route_t(route_t, route_t.match_rules, categories) - index_route_t(route_t, plain_indexes, prefix_uris, regex_uris, - wildcard_hosts, src_trust_funcs, dst_trust_funcs) + -- sort wildcard hosts and uri regexes since those rules + -- don't have their own matching category + -- + -- * plain hosts > wildcard hosts + -- * regex uris > plain uris + -- * longer plain URIs > shorter plain URIs + + sort(marshalled_routes, function(r1, r2) + if r1.match_weight2 ~= r2.match_weight2 then + return r1.match_weight2 > r2.match_weight2 + end + + do + local rp1 = r1.route.regex_priority or 0 + local rp2 = r2.route.regex_priority or 0 + + if rp1 ~= rp2 then + return rp1 > rp2 + end + end + + if r1.max_uri_length ~= r2.max_uri_length then + return r1.max_uri_length > r2.max_uri_length + end + + if r1.route.created_at ~= nil and r2.route.created_at ~= nil then + return r1.route.created_at < r2.route.created_at + end + end) + + for i = 1, #marshalled_routes do + local route_t = marshalled_routes[i] + + categorize_route_t(route_t, route_t.match_rules, categories) + index_route_t(route_t, plain_indexes, prefix_uris, regex_uris, + wildcard_hosts, src_trust_funcs, dst_trust_funcs) + end end diff --git a/kong/runloop/handler.lua b/kong/runloop/handler.lua index 48376f10ae8..a98642b8df6 100644 --- a/kong/runloop/handler.lua +++ b/kong/runloop/handler.lua @@ -30,7 +30,6 @@ local sub = string.sub local find = string.find local lower = string.lower local fmt = string.format -local sort = table.sort local ngx = ngx local arg = ngx.arg local var = ngx.var @@ -609,19 +608,6 @@ do end end - sort(routes, function(r1, r2) - r1, r2 = r1.route, r2.route - - local rp1 = r1.regex_priority or 0 - local rp2 = r2.regex_priority or 0 - - if rp1 == rp2 then - return r1.created_at < r2.created_at - end - - return rp1 > rp2 - end) - local new_router, err = Router.new(routes) if not new_router then return nil, "could not create router: " .. err diff --git a/spec/01-unit/08-router_spec.lua b/spec/01-unit/08-router_spec.lua index 41eaff0f586..a098b7fb6c2 100644 --- a/spec/01-unit/08-router_spec.lua +++ b/spec/01-unit/08-router_spec.lua @@ -493,6 +493,74 @@ describe("Router", function() assert.same([[/route/persons/\d+/profile]], match_t.matches.uri) assert.same(nil, match_t.matches.uri_captures) end) + + it("matches a [uri regex] even if a [uri] got an exact match", function() + local use_case = { + { + service = service, + route = { + paths = { "/route/fixture" }, + }, + }, + { + service = service, + route = { + paths = { "/route/(fixture)" }, + }, + }, + } + + local router = assert(Router.new(use_case)) + + local match_t = router.select("GET", "/route/fixture", "domain.org") + assert.truthy(match_t) + assert.equal(use_case[2].route, match_t.route) + assert.same(nil, match_t.matches.host) + assert.same(nil, match_t.matches.method) + assert.same("/route/(fixture)", match_t.matches.uri) + end) + + it("matches a [uri regex + host] even if a [prefix uri] got a match", function() + local use_case = { + { + service = service, + route = { + paths = { "/pat" }, + }, + headers = { + host = { "route.com" }, + }, + }, + { + service = service, + route = { + paths = { "/path" }, + methods = { "POST" }, + }, + headers = { + host = { "route.com" }, + }, + }, + { + service = service, + route = { + paths = { "/(path)" }, + }, + headers = { + host = { "route.com" }, + }, + }, + } + + local router = assert(Router.new(use_case)) + + local match_t = router.select("GET", "/path", "route.com") + assert.truthy(match_t) + assert.equal(use_case[3].route, match_t.route) + assert.same("route.com", match_t.matches.host) + assert.same(nil, match_t.matches.method) + assert.same("/(path)", match_t.matches.uri) + end) end) describe("[wildcard host]", function() @@ -533,9 +601,7 @@ describe("Router", function() assert.equal(use_case[2].route, match_t.route) end) - pending("does not take precedence over a plain host", function() - -- Pending: temporarily pending in the current commit, awaiting a fix - -- in a subsequent commit. + it("does not take precedence over a plain host", function() table.insert(use_case, 1, { service = service, route = {