From 0dd3ffc96e1dd27b7290c204a0f7931174a3114d Mon Sep 17 00:00:00 2001 From: Edgar Asatryan Date: Sun, 4 Sep 2022 01:08:34 +0400 Subject: [PATCH] feat: Add possibility to add default headers to request. --- .../nstdio/http/ext/CachingInterceptor.java | 2 +- .../java/io/github/nstdio/http/ext/Chain.java | 9 ++ .../http/ext/CompressionInterceptor.java | 9 +- .../nstdio/http/ext/ExtendedHttpClient.java | 97 +++++++++++++---- .../http/ext/HeadersAddingInterceptor.java | 64 +++++++++++ .../nstdio/http/ext/HttpHeadersBuilder.java | 10 +- .../github/nstdio/http/ext/HttpRequests.java | 41 +++++++ .../ext/ExtendedHttpClientIntegrationTest.kt | 77 +++++++++++++- .../http/ext/HeadersAddingInterceptorTest.kt | 100 ++++++++++++++++++ 9 files changed, 374 insertions(+), 35 deletions(-) create mode 100644 src/main/java/io/github/nstdio/http/ext/HeadersAddingInterceptor.java create mode 100644 src/main/java/io/github/nstdio/http/ext/HttpRequests.java create mode 100644 src/test/kotlin/io/github/nstdio/http/ext/HeadersAddingInterceptorTest.kt diff --git a/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java b/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java index 601c04c..a9301dc 100644 --- a/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java +++ b/src/main/java/io/github/nstdio/http/ext/CachingInterceptor.java @@ -29,9 +29,9 @@ import java.util.function.Consumer; import java.util.stream.Stream; -import static io.github.nstdio.http.ext.ExtendedHttpClient.toBuilder; import static io.github.nstdio.http.ext.Headers.HEADER_IF_MODIFIED_SINCE; import static io.github.nstdio.http.ext.Headers.HEADER_IF_NONE_MATCH; +import static io.github.nstdio.http.ext.HttpRequests.toBuilder; import static io.github.nstdio.http.ext.Responses.gatewayTimeoutResponse; import static io.github.nstdio.http.ext.Responses.isSafeRequest; import static io.github.nstdio.http.ext.Responses.isSuccessful; diff --git a/src/main/java/io/github/nstdio/http/ext/Chain.java b/src/main/java/io/github/nstdio/http/ext/Chain.java index cac93a9..96d6749 100644 --- a/src/main/java/io/github/nstdio/http/ext/Chain.java +++ b/src/main/java/io/github/nstdio/http/ext/Chain.java @@ -16,6 +16,7 @@ package io.github.nstdio.http.ext; +import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.util.Optional; @@ -52,6 +53,10 @@ Chain withResponse(HttpResponse response) { return of(ctx, futureHandler, Optional.of(response)); } + Chain withRequest(HttpRequest request) { + return of(ctx.withRequest(request), futureHandler, response); + } + RequestContext ctx() { return this.ctx; } @@ -63,4 +68,8 @@ FutureHandler futureHandler() { Optional> response() { return this.response; } + + HttpRequest request() { + return ctx.request(); + } } diff --git a/src/main/java/io/github/nstdio/http/ext/CompressionInterceptor.java b/src/main/java/io/github/nstdio/http/ext/CompressionInterceptor.java index 4ffbc6e..daec0e7 100644 --- a/src/main/java/io/github/nstdio/http/ext/CompressionInterceptor.java +++ b/src/main/java/io/github/nstdio/http/ext/CompressionInterceptor.java @@ -22,9 +22,9 @@ import java.util.List; import java.util.Optional; -import static io.github.nstdio.http.ext.ExtendedHttpClient.toBuilder; import static io.github.nstdio.http.ext.Headers.HEADER_CONTENT_ENCODING; import static io.github.nstdio.http.ext.Headers.HEADER_CONTENT_LENGTH; +import static io.github.nstdio.http.ext.HttpRequests.toBuilder; import static java.util.function.Predicate.not; import static java.util.stream.Collectors.joining; @@ -64,10 +64,9 @@ private HttpRequest preProcessRequest(HttpRequest request) { return request; } - HttpRequest.Builder builder = toBuilder(request); - builder.setHeader("Accept-Encoding", supported); - - return builder.build(); + return toBuilder(request) + .setHeader("Accept-Encoding", supported) + .build(); } private DecompressingBodyHandler decompressingHandler(HttpResponse.BodyHandler bodyHandler) { diff --git a/src/main/java/io/github/nstdio/http/ext/ExtendedHttpClient.java b/src/main/java/io/github/nstdio/http/ext/ExtendedHttpClient.java index b22033a..5954ae9 100644 --- a/src/main/java/io/github/nstdio/http/ext/ExtendedHttpClient.java +++ b/src/main/java/io/github/nstdio/http/ext/ExtendedHttpClient.java @@ -25,37 +25,50 @@ import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; -import java.net.http.HttpRequest.BodyPublishers; import java.net.http.HttpResponse; import java.net.http.HttpResponse.BodyHandler; import java.net.http.HttpResponse.PushPromiseHandler; import java.net.http.WebSocket; import java.time.Clock; import java.time.Duration; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.Executor; import java.util.function.Function; +import java.util.function.Supplier; import static java.util.concurrent.CompletableFuture.completedFuture; public class ExtendedHttpClient extends HttpClient { private final CompressionInterceptor compressionInterceptor; private final CachingInterceptor cachingInterceptor; + private final HeadersAddingInterceptor headersAddingInterceptor; private final HttpClient delegate; private final boolean allowInsecure; ExtendedHttpClient(HttpClient delegate, Cache cache, Clock clock) { - this(delegate, cache, false, true, clock); + this( + null, + cache instanceof NullCache ? null : new CachingInterceptor(cache, clock), + null, + delegate, + true + ); } - ExtendedHttpClient(HttpClient delegate, Cache cache, boolean transparentEncoding, boolean allowInsecure, Clock clock) { + private ExtendedHttpClient(CompressionInterceptor compressionInterceptor, + CachingInterceptor cachingInterceptor, + HeadersAddingInterceptor headersAddingInterceptor, + HttpClient delegate, boolean allowInsecure) { + this.compressionInterceptor = compressionInterceptor; + this.cachingInterceptor = cachingInterceptor; + this.headersAddingInterceptor = headersAddingInterceptor; this.delegate = delegate; - this.cachingInterceptor = cache instanceof NullCache ? null : new CachingInterceptor(cache, clock); - this.compressionInterceptor = transparentEncoding ? new CompressionInterceptor() : null; this.allowInsecure = allowInsecure; } @@ -80,20 +93,6 @@ public static ExtendedHttpClient newHttpClient() { .build(); } - static HttpRequest.Builder toBuilder(HttpRequest r) { - var builder = HttpRequest.newBuilder(); - builder - .uri(r.uri()) - .method(r.method(), r.bodyPublisher().orElseGet(BodyPublishers::noBody)) - .expectContinue(r.expectContinue()); - - r.version().ifPresent(builder::version); - r.timeout().ifPresent(builder::timeout); - r.headers().map().forEach((name, values) -> values.forEach(value -> builder.header(name, value))); - - return builder; - } - // @Override public Optional cookieHandler() { @@ -186,12 +185,17 @@ private void checkInsecureScheme(HttpRequest request) { private Chain buildAndExecute(RequestContext ctx) { Chain chain = Chain.of(ctx); - chain = compressionInterceptor != null ? compressionInterceptor.intercept(chain) : chain; - chain = cachingInterceptor != null ? cachingInterceptor.intercept(chain) : chain; + chain = possiblyApply(compressionInterceptor, chain); + chain = possiblyApply(cachingInterceptor, chain); + chain = possiblyApply(headersAddingInterceptor, chain); return chain; } + private Chain possiblyApply(Interceptor i, Chain c) { + return i != null ? i.intercept(c) : c; + } + /** * The {@code future} DOES NOT represent ongoing computation it's always either completed or failed. */ @@ -239,6 +243,8 @@ public static class Builder implements HttpClient.Builder { private boolean transparentEncoding; private boolean allowInsecure = true; private Cache cache = Cache.noop(); + private Map headers = Map.of(); + private Map> resolvableHeaders = Map.of(); Builder(HttpClient.Builder delegate) { this.delegate = delegate; @@ -376,11 +382,58 @@ public Builder allowInsecure(boolean allowInsecure) { return this; } + /** + * Provided header will be included on each request. + * + * @param name The header name. + * @param value The header value. + * + * @return builder itself. + */ + public Builder defaultHeader(String name, String value) { + Objects.requireNonNull(name); + Objects.requireNonNull(value); + + if (headers.isEmpty()) { + headers = new HashMap<>(1); + } + headers.put(name, value); + + return this; + } + + /** + * Provided header will be included on each request. Note that {@code valueSupplier} will be resolved before each + * request. + * + * @param name The header name. + * @param valueSupplier The header value supplier. + * + * @return builder itself. + */ + public Builder defaultHeader(String name, Supplier valueSupplier) { + Objects.requireNonNull(name); + Objects.requireNonNull(valueSupplier); + + if (resolvableHeaders.isEmpty()) { + resolvableHeaders = new HashMap<>(1); + } + resolvableHeaders.put(name, valueSupplier); + + return this; + } + @Override public ExtendedHttpClient build() { HttpClient client = delegate.build(); - return new ExtendedHttpClient(client, cache, transparentEncoding, allowInsecure, Clock.systemUTC()); + return new ExtendedHttpClient( + transparentEncoding ? new CompressionInterceptor() : null, + cache instanceof NullCache ? null : new CachingInterceptor(cache, Clock.systemUTC()), + new HeadersAddingInterceptor(Map.copyOf(headers), Map.copyOf(resolvableHeaders)), + client, + allowInsecure + ); } } } diff --git a/src/main/java/io/github/nstdio/http/ext/HeadersAddingInterceptor.java b/src/main/java/io/github/nstdio/http/ext/HeadersAddingInterceptor.java new file mode 100644 index 0000000..9cd84d5 --- /dev/null +++ b/src/main/java/io/github/nstdio/http/ext/HeadersAddingInterceptor.java @@ -0,0 +1,64 @@ +/* + * Copyright (C) 2022 Edgar Asatryan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.nstdio.http.ext; + +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +class HeadersAddingInterceptor implements Interceptor { + private final Map headers; + private final Map> resolvableHeaders; + + HeadersAddingInterceptor(Map headers, Map> resolvableHeaders) { + this.headers = headers; + this.resolvableHeaders = resolvableHeaders; + } + + @Override + public Chain intercept(Chain in) { + if (in.response().isPresent() || !hasHeaders()) { + return in; + } + + return in.withRequest(apply(in.request())); + } + + private HttpRequest apply(HttpRequest request) { + var headers = addHeaders(request.headers()); + + var builder = HttpRequests.toBuilderOmitHeaders(request); + headers.forEach((name, values) -> values.forEach(v -> builder.header(name, v))); + + return builder.build(); + } + + private Map> addHeaders(HttpHeaders h) { + var headersBuilder = new HttpHeadersBuilder(h); + + headers.forEach(headersBuilder::add); + resolvableHeaders.forEach((name, valueSupplier) -> headersBuilder.add(name, valueSupplier.get())); + + return headersBuilder.map(); + } + + private boolean hasHeaders() { + return !headers.isEmpty() || !resolvableHeaders.isEmpty(); + } +} diff --git a/src/main/java/io/github/nstdio/http/ext/HttpHeadersBuilder.java b/src/main/java/io/github/nstdio/http/ext/HttpHeadersBuilder.java index b80f213..c609297 100644 --- a/src/main/java/io/github/nstdio/http/ext/HttpHeadersBuilder.java +++ b/src/main/java/io/github/nstdio/http/ext/HttpHeadersBuilder.java @@ -17,13 +17,15 @@ import java.net.http.HttpHeaders; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.function.BiPredicate; +import static io.github.nstdio.http.ext.Headers.ALLOW_ALL; + class HttpHeadersBuilder { - private static final BiPredicate ALWAYS_ALLOW = (s, s2) -> true; private final TreeMap> headersMap; HttpHeadersBuilder() { @@ -92,8 +94,12 @@ HttpHeadersBuilder remove(String name) { return this; } + Map> map() { + return Collections.unmodifiableMap(headersMap); + } + HttpHeaders build() { - return build(ALWAYS_ALLOW); + return build(ALLOW_ALL); } HttpHeaders build(BiPredicate filter) { diff --git a/src/main/java/io/github/nstdio/http/ext/HttpRequests.java b/src/main/java/io/github/nstdio/http/ext/HttpRequests.java new file mode 100644 index 0000000..ad8a744 --- /dev/null +++ b/src/main/java/io/github/nstdio/http/ext/HttpRequests.java @@ -0,0 +1,41 @@ +/* + * Copyright (C) 2022 Edgar Asatryan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.nstdio.http.ext; + +import java.net.http.HttpRequest; + +class HttpRequests { + static HttpRequest.Builder toBuilderOmitHeaders(HttpRequest r) { + var builder = HttpRequest.newBuilder(); + builder + .uri(r.uri()) + .method(r.method(), r.bodyPublisher().orElseGet(HttpRequest.BodyPublishers::noBody)) + .expectContinue(r.expectContinue()); + + r.version().ifPresent(builder::version); + r.timeout().ifPresent(builder::timeout); + + return builder; + } + + static HttpRequest.Builder toBuilder(HttpRequest r) { + var builder = toBuilderOmitHeaders(r); + r.headers().map().forEach((name, values) -> values.forEach(value -> builder.header(name, value))); + + return builder; + } +} diff --git a/src/test/kotlin/io/github/nstdio/http/ext/ExtendedHttpClientIntegrationTest.kt b/src/test/kotlin/io/github/nstdio/http/ext/ExtendedHttpClientIntegrationTest.kt index 1ce019a..1318fbf 100644 --- a/src/test/kotlin/io/github/nstdio/http/ext/ExtendedHttpClientIntegrationTest.kt +++ b/src/test/kotlin/io/github/nstdio/http/ext/ExtendedHttpClientIntegrationTest.kt @@ -19,6 +19,11 @@ import io.github.nstdio.http.ext.Assertions.assertThat import io.github.nstdio.http.ext.Assertions.awaitFor import io.github.nstdio.http.ext.Compression.deflate import io.github.nstdio.http.ext.Compression.gzip +import io.kotest.matchers.collections.shouldContainExactly +import io.kotest.matchers.maps.shouldContain +import io.kotest.matchers.nulls.shouldNotBeNull +import io.kotest.matchers.should +import io.kotest.matchers.shouldBe import io.kotest.property.Arb import io.kotest.property.arbitrary.next import io.kotest.property.arbitrary.string @@ -28,7 +33,10 @@ import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import java.net.http.HttpClient import java.net.http.HttpRequest -import java.net.http.HttpResponse +import java.net.http.HttpResponse.BodyHandlers +import java.net.http.HttpResponse.BodyHandlers.discarding +import java.util.* +import java.util.concurrent.LinkedBlockingDeque @MockWebServerTest internal class ExtendedHttpClientIntegrationTest(private val mockWebServer: MockWebServer) { @@ -62,7 +70,7 @@ internal class ExtendedHttpClientIntegrationTest(private val mockWebServer: Mock .build() //when + then - val r1 = client.send(request1, HttpResponse.BodyHandlers.ofString()) + val r1 = client.send(request1, BodyHandlers.ofString()) assertThat(r1) .isNetwork .hasStatusCode(200) @@ -70,7 +78,7 @@ internal class ExtendedHttpClientIntegrationTest(private val mockWebServer: Mock .hasBody(expectedBody) .hasNoHeader(Headers.HEADER_CONTENT_ENCODING) awaitFor { - val r2 = client.send(request2, HttpResponse.BodyHandlers.ofString()) + val r2 = client.send(request2, BodyHandlers.ofString()) assertThat(r2) .isCached .hasStatusCode(200) @@ -103,8 +111,8 @@ internal class ExtendedHttpClientIntegrationTest(private val mockWebServer: Mock val request = HttpRequest.newBuilder(testUri).build() //when - val r1 = client.send(request, HttpResponse.BodyHandlers.ofString()) - val r2 = client.sendAsync(request, HttpResponse.BodyHandlers.ofString()).join() + val r1 = client.send(request, BodyHandlers.ofString()) + val r2 = client.sendAsync(request, BodyHandlers.ofString()).join() //then assertThat(r1) @@ -119,4 +127,63 @@ internal class ExtendedHttpClientIntegrationTest(private val mockWebServer: Mock .hasNoHeader(Headers.HEADER_CONTENT_ENCODING) } } + + @Nested + internal inner class DefaultHeadersTest { + @Test + fun `Should add default headers`() { + //given + val client: HttpClient = ExtendedHttpClient.newBuilder() + .defaultHeader("X-Testing-Value", "1") + .defaultHeader("X-Testing-Supplier") { "2" } + .build() + + val request = HttpRequest.newBuilder(mockWebServer.url("/test").toUri()).build() + mockWebServer.enqueue(MockResponse().setResponseCode(200)) + + //when + val response = client.send(request, discarding()) + + //then + response.request().headers().map().should { + it.shouldContain("X-Testing-Value", listOf("1")) + it.shouldContain("X-Testing-Supplier", listOf("2")) + } + + val actualRequest = mockWebServer.takeRequest() + actualRequest.headers.should { + it["X-Testing-Value"] + .shouldNotBeNull() + .shouldBe("1") + + it["X-Testing-Supplier"] + .shouldNotBeNull() + .shouldBe("2") + } + } + + @Test + fun `Should resolve supplier by each call`() { + //given + val requestIds = (0..3).map { UUID.randomUUID().toString() } + val queue = LinkedBlockingDeque(requestIds) + + val headerName = "X-Testing-Supplier-Resolved" + val client: HttpClient = ExtendedHttpClient.newBuilder() + .defaultHeader(headerName) { queue.pop() } + .build() + + val request = HttpRequest.newBuilder(mockWebServer.url("/test").toUri()).build() + mockWebServer.enqueue(MockResponse().setResponseCode(200), requestIds.size) + + //when + val headerValues = requestIds + .map { client.send(request, discarding()) } + .map { mockWebServer.takeRequest() } + .map { it.headers[headerName] } + + //then + headerValues.shouldContainExactly(requestIds) + } + } } \ No newline at end of file diff --git a/src/test/kotlin/io/github/nstdio/http/ext/HeadersAddingInterceptorTest.kt b/src/test/kotlin/io/github/nstdio/http/ext/HeadersAddingInterceptorTest.kt new file mode 100644 index 0000000..1e46d07 --- /dev/null +++ b/src/test/kotlin/io/github/nstdio/http/ext/HeadersAddingInterceptorTest.kt @@ -0,0 +1,100 @@ +/* + * Copyright (C) 2022 Edgar Asatryan + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.nstdio.http.ext + +import io.kotest.matchers.maps.shouldContain +import io.kotest.matchers.should +import io.kotest.matchers.types.shouldBeSameInstanceAs +import org.junit.jupiter.api.Test +import java.net.http.HttpRequest +import java.net.http.HttpResponse.BodyHandlers +import java.util.* +import java.util.function.Supplier + +class HeadersAddingInterceptorTest { + private val bodyHandler = BodyHandlers.discarding() + + @Test + fun `Should add headers to the chain`() { + //given + val interceptor = HeadersAddingInterceptor( + mapOf("a" to "4", "d" to "6"), + mapOf("b" to Supplier { "5" }, "e" to Supplier { "7" }), + ) + + val request = HttpRequest.newBuilder("https://example.com".toUri()) + .header("a", "1") + .header("b", "2") + .header("c", "3") + .build() + val chain = Chain.of(RequestContext.of(request, bodyHandler)) + + //when + val newChain = interceptor.intercept(chain) + + //then + newChain.request().headers().map().should { + it.shouldContain("a", listOf("1", "4")) + it.shouldContain("b", listOf("2", "5")) + it.shouldContain("c", listOf("3")) + it.shouldContain("d", listOf("6")) + it.shouldContain("e", listOf("7")) + } + } + + @Test + fun `Should not add headers to the chain when response is present`() { + //given + val interceptor = HeadersAddingInterceptor( + mapOf("a" to "1"), + mapOf("b" to Supplier { "2" }), + ) + + val request = HttpRequest.newBuilder("https://example.com".toUri()) + .header("c", "3") + .build() + val response = StaticHttpResponse.builder() + .request(request) + .build() + val chain = Chain.of(RequestContext.of(request, bodyHandler), { t, _ -> t }, Optional.of(response)) + + //when + val actual = interceptor.intercept(chain) + + //then + actual.shouldBeSameInstanceAs(chain) + actual.request().shouldBeSameInstanceAs(request) + } + + @Test + fun `Should alter chain when no headers to add`() { + //given + val interceptor = HeadersAddingInterceptor(mapOf(), mapOf()) + + val request = HttpRequest.newBuilder("https://example.com".toUri()) + .header("a", "1") + .build() + val chain = Chain.of(RequestContext.of(request, bodyHandler)) + + //when + val actual = interceptor.intercept(chain) + + //then + actual.shouldBeSameInstanceAs(chain) + actual.request().shouldBeSameInstanceAs(request) + } +} \ No newline at end of file