diff --git a/http-client-jdk/src/main/java/io/micronaut/http/client/jdk/AbstractJdkHttpClient.java b/http-client-jdk/src/main/java/io/micronaut/http/client/jdk/AbstractJdkHttpClient.java index e28f6d44aee..d416d7d065b 100644 --- a/http-client-jdk/src/main/java/io/micronaut/http/client/jdk/AbstractJdkHttpClient.java +++ b/http-client-jdk/src/main/java/io/micronaut/http/client/jdk/AbstractJdkHttpClient.java @@ -26,6 +26,7 @@ import io.micronaut.core.propagation.PropagatedContext; import io.micronaut.core.type.Argument; import io.micronaut.core.util.StringUtils; +import io.micronaut.http.HttpAttributes; import io.micronaut.http.HttpResponse; import io.micronaut.http.HttpStatus; import io.micronaut.http.MutableHttpRequest; @@ -411,6 +412,10 @@ protected Publisher> responsePublisher( @NonNull io.micronaut.http.HttpRequest request, @Nullable Argument bodyType ) { + if (clientId != null && request.getAttribute(HttpAttributes.SERVICE_ID).isEmpty()) { + request.setAttribute(HttpAttributes.SERVICE_ID, clientId); + } + return Flux.defer(() -> mapToHttpRequest(request, bodyType)) // defered so any client filter changes are used .map(httpRequest -> { if (log.isDebugEnabled()) { diff --git a/http-client-jdk/src/test/groovy/io/micronaut/http/client/jdk/ServiceIdSpec.groovy b/http-client-jdk/src/test/groovy/io/micronaut/http/client/jdk/ServiceIdSpec.groovy new file mode 100644 index 00000000000..8000b082b1e --- /dev/null +++ b/http-client-jdk/src/test/groovy/io/micronaut/http/client/jdk/ServiceIdSpec.groovy @@ -0,0 +1,87 @@ +package io.micronaut.http.client.jdk + +import io.micronaut.context.ApplicationContext +import io.micronaut.context.annotation.Requires +import io.micronaut.http.HttpAttributes +import io.micronaut.http.HttpRequest +import io.micronaut.http.HttpResponse +import io.micronaut.http.HttpVersion +import io.micronaut.http.MutableHttpRequest +import io.micronaut.http.annotation.Controller +import io.micronaut.http.annotation.Filter +import io.micronaut.http.annotation.Get +import io.micronaut.http.client.HttpClientRegistry +import io.micronaut.http.client.annotation.Client +import io.micronaut.http.filter.ClientFilterChain +import io.micronaut.http.filter.HttpClientFilter +import io.micronaut.runtime.server.EmbeddedServer +import jakarta.inject.Singleton +import org.reactivestreams.Publisher +import spock.lang.AutoCleanup +import spock.lang.Specification + +class ServiceIdSpec extends Specification { + + @AutoCleanup + EmbeddedServer server = ApplicationContext.run(EmbeddedServer, [ + 'spec.name': 'ServiceIdSpec', + ]) + + @AutoCleanup + ApplicationContext clientCtx = ApplicationContext.run([ + 'spec.name': 'ServiceIdSpec', + 'micronaut.http.services.my-client-id.url': server.URI, + ]) + + def 'service id set by declarative client'() { + given: + def client = clientCtx.getBean(DeclarativeClient) + def filter = clientCtx.getBean(ServiceIdFilter) + + expect: + filter.serviceId == null + client.index() == "foo" + filter.serviceId == "my-client-id" + } + + def 'service id set by normal client'() { + given: + def client = clientCtx.getBean(HttpClientRegistry).getClient(HttpVersion.HTTP_1_1, "my-client-id", null) + def filter = clientCtx.getBean(ServiceIdFilter) + + expect: + filter.serviceId == null + client.toBlocking().exchange("/service-id", String).body() == "foo" + filter.serviceId == "my-client-id" + } + + @Client(id = "my-client-id") + @Requires(property = "spec.name", value = "ServiceIdSpec") + static interface DeclarativeClient { + @Get("/service-id") + String index() + } + + @Singleton + @Requires(property = "spec.name", value = "ServiceIdSpec") + @Controller("/service-id") + static class ServiceIdController { + @Get + def index(HttpRequest request) { + return "foo" + } + } + + @Singleton + @Requires(property = "spec.name", value = "ServiceIdSpec") + @Filter(Filter.MATCH_ALL_PATTERN) + static class ServiceIdFilter implements HttpClientFilter { + String serviceId + + @Override + Publisher> doFilter(MutableHttpRequest request, ClientFilterChain chain) { + serviceId = request.getAttribute(HttpAttributes.SERVICE_ID).orElse(null) + return chain.proceed(request) + } + } +}