Skip to content

Commit b94be9c

Browse files
authored
Make newly added Routes take precedence over old ones (#3066) (#3337)
* Make newly added Routes take precedence over old ones (#3066) * Ensure uniqueness of routes when creating look up tree
1 parent 487b1b6 commit b94be9c

File tree

8 files changed

+256
-111
lines changed

8 files changed

+256
-111
lines changed

zio-http-testkit/src/main/scala/zio/http/TestClient.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ final case class TestClient(
7272
r <- ZIO.environment[R]
7373
provided = route.provideEnvironment(r)
7474
_ <- behavior.update(_ :+ provided)
75+
_ <- behavior.get.debug("Added route")
7576
} yield ()
7677

7778
/**
@@ -121,7 +122,7 @@ final case class TestClient(
121122
proxy: Option[Proxy],
122123
)(implicit trace: Trace): ZIO[Scope, Throwable, Response] = {
123124
for {
124-
currentBehavior <- behavior.get.map(_ :+ Method.ANY / trailing -> handler(Response.notFound))
125+
currentBehavior <- behavior.get
125126
request = Request(
126127
body = body,
127128
headers = headers,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
package zio.http
2+
3+
import zio._
4+
import zio.test.TestAspect.shrinks
5+
import zio.test._
6+
7+
import zio.http.endpoint.{AuthType, Endpoint}
8+
import zio.http.netty.NettyConfig
9+
import zio.http.netty.server.NettyDriver
10+
11+
object RoutesPrecedentsSpec extends ZIOSpecDefault {
12+
13+
trait MyService {
14+
def code: UIO[Int]
15+
}
16+
object MyService {
17+
def live(code: Int): ULayer[MyService] = ZLayer.succeed(new MyServiceLive(code))
18+
}
19+
final class MyServiceLive(_code: Int) extends MyService {
20+
def code: UIO[Int] = ZIO.succeed(_code)
21+
}
22+
23+
val endpoint: Endpoint[Unit, String, ZNothing, Int, AuthType.None] =
24+
Endpoint(RoutePattern.POST / "api").in[String].out[Int]
25+
26+
val api = endpoint.implement(_ => ZIO.serviceWithZIO[MyService](_.code))
27+
28+
// when adding the same route multiple times to the server, the last one should take precedence
29+
override def spec: Spec[TestEnvironment & Scope, Any] =
30+
test("test") {
31+
check(Gen.fromIterable(List(1, 2, 3, 4, 5))) { code =>
32+
(
33+
for {
34+
client <- ZIO.service[Client]
35+
port <- ZIO.serviceWithZIO[Server](_.port)
36+
url = URL.root.port(port) / "api"
37+
request = Request
38+
.post(url = url, body = Body.fromString(""""this is some input""""))
39+
.addHeader(Header.Accept(MediaType.application.json))
40+
_ <- TestServer.addRoutes(api.toRoutes)
41+
result <- client.batched(request)
42+
output <- result.body.asString
43+
} yield assertTrue(output == code.toString)
44+
).provideSome[TestServer & Client](
45+
ZLayer.succeed(new MyServiceLive(code)),
46+
)
47+
}.provide(
48+
ZLayer.succeed(Server.Config.default.onAnyOpenPort),
49+
TestServer.layer,
50+
Client.default,
51+
NettyDriver.customized,
52+
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
53+
)
54+
} @@ shrinks(0)
55+
}

zio-http-testkit/src/test/scala/zio/http/TestClientSpec.scala

+5-5
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ object TestClientSpec extends ZIOHttpSpec {
2222
_ <- TestClient.addRequestResponse(request2, Response.ok)
2323
goodResponse2 <- client(request)
2424
badResponse2 <- client(request2)
25-
} yield assertTrue(extractStatus(goodResponse) == Status.Ok) && assertTrue(
25+
} yield assertTrue(
26+
extractStatus(goodResponse) == Status.Ok,
2627
extractStatus(badResponse) == Status.NotFound,
27-
) &&
28-
assertTrue(extractStatus(goodResponse2) == Status.Ok) && assertTrue(
29-
extractStatus(badResponse2) == Status.Ok,
30-
)
28+
extractStatus(goodResponse2) == Status.Ok,
29+
extractStatus(badResponse2) == Status.Ok,
30+
)
3131
},
3232
),
3333
suite("addHandler")(

zio-http/jvm/src/test/scala/zio/http/RoutePatternSpec.scala

+27-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.util.UUID
2121
import zio.Chunk
2222
import zio.test._
2323

24+
import zio.http.codec.Doc
2425
import zio.http.{int => _, uuid => _}
2526

2627
object RoutePatternSpec extends ZIOHttpSpec {
@@ -233,7 +234,6 @@ object RoutePatternSpec extends ZIOHttpSpec {
233234
val pattern2 = Method.GET / "users" / trailing / "123"
234235

235236
tree = tree.add(pattern2, 2)
236-
println(tree.get(Method.GET, Path("/users/bla/123")))
237237
tree = tree.add(pattern1, 1)
238238

239239
assertTrue(tree.get(Method.GET, Path("/users/123")).contains(1))
@@ -249,7 +249,7 @@ object RoutePatternSpec extends ZIOHttpSpec {
249249
tree = tree.add(pattern2, 2)
250250
tree = tree.add(pattern3, 3)
251251

252-
assertTrue(tree.get(Method.OPTIONS, Path("/users")) == Chunk(2, 1, 3))
252+
assertTrue(tree.get(Method.OPTIONS, Path("/users")) == Chunk(2))
253253
},
254254
test("multiple routes") {
255255
var tree: Tree[Unit] = RoutePattern.Tree.empty
@@ -497,11 +497,36 @@ object RoutePatternSpec extends ZIOHttpSpec {
497497
},
498498
)
499499

500+
def structureEquals = suite("structure equals")(
501+
test("equals") {
502+
val routePattern = Method.GET / "users" / int("user-id") / "posts" / string("post-id")
503+
504+
assertTrue(routePattern.structureEquals(routePattern))
505+
},
506+
test("equals with docs") {
507+
val routePattern = Method.GET / "users" / int("user-id") / "posts" / string("post-id")
508+
509+
assertTrue(
510+
routePattern.structureEquals(routePattern ?? Doc.p("docs")),
511+
)
512+
},
513+
test("equals with mapping") {
514+
val routePattern = Method.GET / "users" / int("user-id") / "posts" / string("post-id")
515+
val routePattern1 =
516+
Method.GET / "users" / int("user-id").transform(_.toString())(_.toInt) / "posts" / string("post-id")
517+
518+
assertTrue(
519+
routePattern.structureEquals(routePattern1),
520+
)
521+
},
522+
)
523+
500524
def spec =
501525
suite("RoutePatternSpec")(
502526
decoding,
503527
rendering,
504528
formatting,
505529
tree,
530+
structureEquals,
506531
)
507532
}

zio-http/jvm/src/test/scala/zio/http/RoutesSpec.scala

+26-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ package zio.http
1818

1919
import zio.test._
2020

21-
import zio.http.codec.PathCodec
21+
import zio.http.codec.{PathCodec, SegmentCodec}
2222

2323
object RoutesSpec extends ZIOHttpSpec {
2424
def extractStatus(response: Response): Status = response.status
@@ -112,5 +112,30 @@ object RoutesSpec extends ZIOHttpSpec {
112112
)
113113
}
114114
},
115+
test("overlapping routes with different segment types") {
116+
val app = Routes(
117+
Method.GET / "foo" / string("id") -> Handler.status(Status.NoContent),
118+
Method.GET / "foo" / string("id") -> Handler.ok,
119+
Method.GET / "foo" / (SegmentCodec.literal("prefix") ~ string("rest")) -> Handler.ok,
120+
Method.GET / "foo" / int("id") -> Handler.ok,
121+
)
122+
123+
for {
124+
stringId <- app.runZIO(Request.get("/foo/123"))
125+
stringPrefix <- app.runZIO(Request.get("/foo/prefix123"))
126+
intId <- app.runZIO(Request.get("/foo/123"))
127+
notFound <- app.runZIO(Request.get("/foo/123/456"))
128+
logs <- ZTestLogger.logOutput.map { logs => logs.map(_.message()) }
129+
} yield {
130+
println(logs)
131+
assertTrue(
132+
logs.contains("Duplicate routes detected:\nGET /foo/{id}\nThe last route of each path will be used."),
133+
extractStatus(stringId) == Status.Ok,
134+
extractStatus(stringPrefix) == Status.Ok,
135+
extractStatus(intId) == Status.Ok,
136+
extractStatus(notFound) == Status.NotFound,
137+
)
138+
}
139+
},
115140
)
116141
}

zio-http/jvm/src/test/scala/zio/http/security/UserDataSpec.scala

+77-73
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package zio.http.security
33
import zio._
44
import zio.test._
55

6-
import zio.schema._
7-
86
import zio.http._
97
import zio.http.codec._
108
import zio.http.endpoint._
@@ -40,16 +38,16 @@ object UserDataSpec extends ZIOSpecDefault {
4038

4139
val spec = suite("UserDataSpec")(
4240
test("No sanitation and write to server") {
41+
val endpoint = Endpoint(Method.GET / "test")
42+
.query(HttpCodec.query[String]("data"))
43+
.out[String]
44+
val route = endpoint.implementHandler(Handler.fromFunction { (s: String) =>
45+
// writeToServer or other actions
46+
s
47+
})
4348
// this is not a bug but could be a vulnerability used wrong
4449
check(tuples.zip(functions)) { case (_, msg, expectedResponse, _) =>
45-
val endpoint = Endpoint(Method.GET / "test")
46-
.query(HttpCodec.query[String]("data"))
47-
.out[String]
48-
val route = endpoint.implementHandler(Handler.fromFunction { case (s: String) =>
49-
// writeToServer or other actions
50-
s
51-
})
52-
val request =
50+
val request =
5351
Request.get(URL(Path.root / "test", queryParams = QueryParams(("data", msg))))
5452
for {
5553
response <- route.toRoutes.runZIO(request)
@@ -58,13 +56,13 @@ object UserDataSpec extends ZIOSpecDefault {
5856
}
5957
} @@ TestAspect.failing,
6058
test("No sanitation using Dom") {
59+
val endpoint = Endpoint(Method.GET / "test")
60+
.in[Dom]
61+
.out[Dom]
62+
val route = endpoint.implementHandler(Handler.fromFunction(identity))
6163
// this is not a bug but could be a vulnerability used wrong
6264
check(tuples.zip(functions)) { case (_, msg, expectedResponse, _) =>
63-
val endpoint = Endpoint(Method.GET / "test")
64-
.in[Dom]
65-
.out[Dom]
66-
val route = endpoint.implementHandler(Handler.fromFunction(identity))
67-
val request =
65+
val request =
6866
Request.post(URL(Path.root / "test"), Body.fromString(msg))
6967
for {
7068
response <- route.toRoutes.runZIO(request)
@@ -73,28 +71,29 @@ object UserDataSpec extends ZIOSpecDefault {
7371
}
7472
} @@ TestAspect.failing,
7573
test("Header injection") {
76-
check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, f) =>
77-
val endpoint = Endpoint(Method.GET / "test")
78-
.query(HttpCodec.query[String]("data"))
79-
.out[Dom]
80-
val route = endpoint.implementHandler(Handler.fromFunction { case (s: String) => f(s) })
81-
val request =
82-
Request
83-
.get(URL(Path.root / "test", queryParams = QueryParams(("data", msg))))
84-
.addHeader(Header.Accept(mediaType))
85-
for {
86-
response <- route.toRoutes.runZIO(request)
87-
body <- response.body.asString
88-
} yield assertTrue(body.contains(expectedResponse))
74+
check(tuples.zip(functions).zip(Gen.alphaNumericStringBounded(1, 50))) {
75+
case (mediaType, msg, expectedResponse, f, suffix) =>
76+
val endpoint = Endpoint(Method.GET / "test" / suffix)
77+
.query(HttpCodec.query[String]("data"))
78+
.out[Dom]
79+
val route = endpoint.implementHandler(Handler.fromFunction { (s: String) => f(s) })
80+
val request =
81+
Request
82+
.get(URL(Path.root / "test" / suffix, queryParams = QueryParams(("data", msg))))
83+
.addHeader(Header.Accept(mediaType))
84+
for {
85+
response <- route.toRoutes.runZIO(request)
86+
body <- response.body.asString
87+
} yield assertTrue(body.contains(expectedResponse))
8988
}
9089
},
9190
test("Header injection DOM") {
91+
val endpoint = Endpoint(Method.GET / "test")
92+
.query(HttpCodec.query[Dom]("data"))
93+
.out[Dom]
94+
val route = endpoint.implementHandler(Handler.fromFunction(s => s))
9295
check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, _) =>
93-
val endpoint = Endpoint(Method.GET / "test")
94-
.query(HttpCodec.query[Dom]("data"))
95-
.out[Dom]
96-
val route = endpoint.implementHandler(Handler.fromFunction(s => s))
97-
val request =
96+
val request =
9897
Request
9998
.get(URL(Path.root / "test", queryParams = QueryParams(("data", msg))))
10099
.addHeader(Header.Accept(mediaType))
@@ -105,52 +104,57 @@ object UserDataSpec extends ZIOSpecDefault {
105104
}
106105
} @@ TestAspect.failing,
107106
test("Path injection") {
108-
check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, f) =>
109-
val request = Request.get(URL(Path.root / "test" / msg)).addHeader(Header.Accept(mediaType))
110-
val route = Routes(
111-
Endpoint(Method.GET / "test" / string("message"))
112-
.out[Dom]
113-
.implementHandler(Handler.fromFunction { case (s: String) =>
114-
f(s)
115-
}),
116-
)
117-
for {
118-
response <- route.runZIO(request)
119-
body <- response.body.asString
120-
} yield assertTrue(body.contains(expectedResponse))
107+
check(tuples.zip(functions).zip(Gen.alphaNumericStringBounded(1, 50))) {
108+
case (mediaType, msg, expectedResponse, f, suffix) =>
109+
val request = Request.get(URL(Path.root / "test" / suffix / msg)).addHeader(Header.Accept(mediaType))
110+
val route = Routes(
111+
Endpoint(Method.GET / "test" / suffix / string("message"))
112+
.out[Dom]
113+
.implementHandler(Handler.fromFunction { (s: String) =>
114+
f(s)
115+
}),
116+
)
117+
for {
118+
response <- route.runZIO(request)
119+
body <- response.body.asString
120+
} yield assertTrue(body.contains(expectedResponse))
121121
}
122122
},
123123
test("Body injection") {
124-
check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, f) =>
125-
val body = Body.fromArray(msg.getBytes())
126-
val request = Request.post("/test", body).addHeader(Header.Accept(mediaType))
127-
val route = Routes(Method.POST / "test" -> handler { (req: Request) =>
124+
check(tuples.zip(functions).zip(Gen.alphaNumericStringBounded(1, 50))) {
125+
case (mediaType, msg, expectedResponse, f, suffix) =>
126+
val body = Body.fromArray(msg.getBytes())
127+
val request = Request.post(url"/test/$suffix", body).addHeader(Header.Accept(mediaType))
128+
val route = Routes(Method.POST / "test" / suffix -> handler { (req: Request) =>
129+
for {
130+
msg <- req.body.asString.orDie
131+
} yield Response.text(f(msg).encode)
132+
})
128133
for {
129-
msg <- req.body.asString.orDie
130-
} yield Response.text(f(msg).encode)
131-
})
132-
for {
133-
response <- route.runZIO(request)
134-
body <- response.body.asString
135-
} yield assertTrue(body.contains(expectedResponse))
134+
response <- route.runZIO(request)
135+
body <- response.body.asString
136+
} yield assertTrue(body.contains(expectedResponse))
136137
}
137138
},
138139
test("Error injection") {
139-
check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, _) =>
140-
val routes = Routes(Method.POST / "test" -> handler { (req: Request) =>
141-
req.body.asString.orDie.map(msg => Response.error(Status.InternalServerError, msg))
142-
})
143-
val body = Body.fromString(msg)
144-
val request = Request.post("/test", body).addHeader(Header.Accept(mediaType))
145-
for {
146-
port <- Server.install(routes)
147-
response <- ZIO.scoped {
148-
Client
149-
.batched(request.updateURL(_ => URL.decode(s"http://localhost:$port/test").toOption.get))
150-
}
151-
body <- response.body.asString
152-
} yield assertTrue(body == expectedResponse)
153-
}
140+
val routes = Routes(Method.POST / "test" -> handler { (req: Request) =>
141+
req.body.asString.orDie.map(msg => Response.error(Status.InternalServerError, msg))
142+
})
143+
for {
144+
port <- Server.install(routes)
145+
result <- check(tuples.zip(functions)) { case (mediaType, msg, expectedResponse, _) =>
146+
147+
val body = Body.fromString(msg)
148+
val request = Request.post("/test", body).addHeader(Header.Accept(mediaType))
149+
for {
150+
response <- ZIO.scoped {
151+
Client
152+
.batched(request.updateURL(_ => URL.decode(s"http://localhost:$port/test").toOption.get))
153+
}
154+
body <- response.body.asString
155+
} yield assertTrue(body == expectedResponse)
156+
}
157+
} yield result
154158
},
155159
).provide(
156160
Scope.default,
@@ -160,6 +164,6 @@ object UserDataSpec extends ZIOSpecDefault {
160164
),
161165
ZLayer.succeed(NettyConfig.defaultWithFastShutdown),
162166
Client.default,
163-
)
167+
) @@ TestAspect.sequential
164168

165169
}

0 commit comments

Comments
 (0)