diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java index 0863499d8670..3bba7da3d88f 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/ExtendedWebExchangeDataBinder.java @@ -24,6 +24,7 @@ import org.springframework.http.HttpHeaders; import org.springframework.lang.Nullable; import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; import org.springframework.web.bind.support.WebExchangeDataBinder; import org.springframework.web.reactive.HandlerMapping; import org.springframework.web.server.ServerWebExchange; @@ -57,7 +58,11 @@ public Mono> getValuesToBind(ServerWebExchange exchange) { for (Map.Entry> entry : headers.entrySet()) { List values = entry.getValue(); if (!CollectionUtils.isEmpty(values)) { - String name = entry.getKey().replace("-", ""); + // For constructor args with @BindParam mapped to the actual header name + String name = entry.getKey(); + addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); + // Also adapt to Java conventions for setters + name = StringUtils.uncapitalize(entry.getKey().replace("-", "")); addValueIfNotPresent(map, "Header", name, (values.size() == 1 ? values.get(0) : values)); } } diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java index b12f755ec6a9..159c687d572c 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/result/method/annotation/InitBinderBindingContextTests.java @@ -27,10 +27,12 @@ import org.springframework.beans.testfixture.beans.TestBean; import org.springframework.core.DefaultParameterNameDiscoverer; import org.springframework.core.ReactiveAdapterRegistry; +import org.springframework.core.ResolvableType; import org.springframework.core.convert.ConversionService; import org.springframework.format.support.DefaultFormattingConversionService; import org.springframework.http.MediaType; import org.springframework.web.bind.WebDataBinder; +import org.springframework.web.bind.annotation.BindParam; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.bind.support.ConfigurableWebBindingInitializer; @@ -129,7 +131,7 @@ void createBinderTypeConversion() throws Exception { } @Test - void bindUriVariablesAndHeaders() throws Exception { + void bindUriVariablesAndHeadersViaSetters() throws Exception { MockServerHttpRequest request = MockServerHttpRequest.get("/path") .header("Some-Int-Array", "1") @@ -153,6 +155,31 @@ void bindUriVariablesAndHeaders() throws Exception { assertThat(target.getSomeIntArray()).containsExactly(1, 2); } + @Test + void bindUriVariablesAndHeadersViaConstructor() throws Exception { + + MockServerHttpRequest request = MockServerHttpRequest.get("/path") + .header("Some-Int-Array", "1") + .header("Some-Int-Array", "2") + .build(); + + MockServerWebExchange exchange = MockServerWebExchange.from(request); + exchange.getAttributes().put( + HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Map.of("name", "John", "age", "25")); + + BindingContext context = createBindingContext("initBinderWithAttributeName", WebDataBinder.class); + WebExchangeDataBinder binder = context.createDataBinder(exchange, null, "dataBean", null); + binder.setTargetType(ResolvableType.forClass(DataBean.class)); + binder.construct(exchange).block(); + + DataBean bean = (DataBean) binder.getTarget(); + + assertThat(bean.name()).isEqualTo("John"); + assertThat(bean.age()).isEqualTo(25); + assertThat(bean.someIntArray()).containsExactly(1, 2); + } + @Test void bindUriVarsAndHeadersAddedConditionally() throws Exception { @@ -212,4 +239,8 @@ public void initBinderTypeConversion(WebDataBinder dataBinder, @RequestParam int } } + + private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { + } + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java index 4d4e26a131a4..c74b37fab8f3 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinder.java @@ -156,6 +156,9 @@ protected Object getRequestParameter(String name, Class type) { if (uriVars != null) { value = uriVars.get(name); } + if (value == null && getRequest() instanceof HttpServletRequest httpServletRequest) { + value = getHeaderValue(httpServletRequest, name); + } } return value; } @@ -167,6 +170,13 @@ protected Set initParameterNames(ServletRequest request) { if (uriVars != null) { set.addAll(uriVars.keySet()); } + if (request instanceof HttpServletRequest httpServletRequest) { + Enumeration enumeration = httpServletRequest.getHeaderNames(); + while (enumeration.hasMoreElements()) { + String headerName = enumeration.nextElement(); + set.add(headerName.replaceAll("-", "")); + } + } return set; } } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java index 83f64ca1b8c3..36fd05508cd8 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ExtendedServletRequestDataBinderTests.java @@ -22,7 +22,10 @@ import org.junit.jupiter.api.Test; import org.springframework.beans.testfixture.beans.TestBean; +import org.springframework.core.ResolvableType; import org.springframework.web.bind.ServletRequestDataBinder; +import org.springframework.web.bind.annotation.BindParam; +import org.springframework.web.bind.support.BindParamNameResolver; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.testfixture.servlet.MockHttpServletRequest; @@ -45,7 +48,7 @@ void setup() { @Test - void createBinder() { + void createBinderViaSetters() { request.setAttribute( HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, Map.of("name", "John", "age", "25")); @@ -62,6 +65,27 @@ void createBinder() { assertThat(target.getSomeIntArray()).containsExactly(1, 2); } + @Test + void createBinderViaConstructor() { + request.setAttribute( + HandlerMapping.URI_TEMPLATE_VARIABLES_ATTRIBUTE, + Map.of("name", "John", "age", "25")); + + request.addHeader("Some-Int-Array", "1"); + request.addHeader("Some-Int-Array", "2"); + + ServletRequestDataBinder binder = new ExtendedServletRequestDataBinder(null); + binder.setTargetType(ResolvableType.forClass(DataBean.class)); + binder.setNameResolver(new BindParamNameResolver()); + binder.construct(request); + + DataBean bean = (DataBean) binder.getTarget(); + + assertThat(bean.name()).isEqualTo("John"); + assertThat(bean.age()).isEqualTo(25); + assertThat(bean.someIntArray()).containsExactly(1, 2); + } + @Test void uriVarsAndHeadersAddedConditionally() { request.addParameter("name", "John"); @@ -88,4 +112,8 @@ void noUriTemplateVars() { assertThat(target.getAge()).isEqualTo(0); } + + private record DataBean(String name, int age, @BindParam("Some-Int-Array") Integer[] someIntArray) { + } + }