Skip to content

Commit

Permalink
Find CORS config by HandlerMethod
Browse files Browse the repository at this point in the history
Before this change AbstractHandlerMethodMapping used a map from Method
to CorsConfiguration. That works for regular @RequestMapping methods.
However frameworks like Spring Boot and Spring Integration may
programmatically register the same Method under multiple mappings,
i.e. adapter/gateway type classes.

This change ensures that CorsConfiguraiton is indexed by HandlerMethod
so that we can store CorsConfiguration for different handler instances
even when the method is the same.

In order for to make this work, HandlerMethod now provides an
additional field called resolvedFromHandlerMethod that returns the
original HandlerMethod (with the String bean name). This makes it
possible to  perform reliable lookups.

Issue: SPR-11541
  • Loading branch information
rstoyanchev committed May 5, 2015
1 parent 4a8baeb commit 8853107
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public class HandlerMethod {

private final MethodParameter[] parameters;

private final HandlerMethod resolvedFromHandlerMethod;


/**
* Create an instance from a bean instance and a method.
Expand All @@ -73,6 +75,7 @@ public HandlerMethod(Object bean, Method method) {
this.method = method;
this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(method);
this.parameters = initMethodParameters();
this.resolvedFromHandlerMethod = null;
}

/**
Expand All @@ -88,6 +91,7 @@ public HandlerMethod(Object bean, String methodName, Class<?>... parameterTypes)
this.method = bean.getClass().getMethod(methodName, parameterTypes);
this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(this.method);
this.parameters = initMethodParameters();
this.resolvedFromHandlerMethod = null;
}

/**
Expand All @@ -105,6 +109,7 @@ public HandlerMethod(String beanName, BeanFactory beanFactory, Method method) {
this.method = method;
this.bridgedMethod = BridgeMethodResolver.findBridgedMethod(method);
this.parameters = initMethodParameters();
this.resolvedFromHandlerMethod = null;
}

/**
Expand All @@ -118,6 +123,7 @@ protected HandlerMethod(HandlerMethod handlerMethod) {
this.method = handlerMethod.method;
this.bridgedMethod = handlerMethod.bridgedMethod;
this.parameters = handlerMethod.parameters;
this.resolvedFromHandlerMethod = handlerMethod.resolvedFromHandlerMethod;
}

/**
Expand All @@ -132,6 +138,7 @@ private HandlerMethod(HandlerMethod handlerMethod, Object handler) {
this.method = handlerMethod.method;
this.bridgedMethod = handlerMethod.bridgedMethod;
this.parameters = handlerMethod.parameters;
this.resolvedFromHandlerMethod = handlerMethod;
}


Expand Down Expand Up @@ -182,6 +189,14 @@ public MethodParameter[] getMethodParameters() {
return this.parameters;
}

/**
* Return the HandlerMethod from which this HandlerMethod instance was
* resolved via {@link #createWithResolvedBean()}.
*/
public HandlerMethod getResolvedFromHandlerMethod() {
return this.resolvedFromHandlerMethod;
}

/**
* Return the HandlerMethod return type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ class MappingRegistry {
private final Map<String, List<HandlerMethod>> nameLookup =
new ConcurrentHashMap<String, List<HandlerMethod>>();

private final Map<Method, CorsConfiguration> corsLookup =
new ConcurrentHashMap<Method, CorsConfiguration>();
private final Map<HandlerMethod, CorsConfiguration> corsLookup =
new ConcurrentHashMap<HandlerMethod, CorsConfiguration>();


private final ReentrantReadWriteLock readWriteLock = new ReentrantReadWriteLock();
Expand Down Expand Up @@ -501,8 +501,8 @@ public List<HandlerMethod> getHandlerMethodsByMappingName(String mappingName) {
* Return CORS configuration. Thread-safe for concurrent use.
*/
public CorsConfiguration getCorsConfiguration(HandlerMethod handlerMethod) {
Method method = handlerMethod.getMethod();
return this.corsLookup.get(method);
HandlerMethod original = handlerMethod.getResolvedFromHandlerMethod();
return this.corsLookup.get(original != null ? original : handlerMethod);
}

/**
Expand Down Expand Up @@ -545,11 +545,10 @@ public void register(T mapping, Object handler, Method method) {

CorsConfiguration corsConfig = initCorsConfiguration(handler, method, mapping);
if (corsConfig != null) {
this.corsLookup.put(method, corsConfig);
this.corsLookup.put(handlerMethod, corsConfig);
}

this.registry.put(mapping,
new MappingRegistration<T>(mapping, handlerMethod, directUrls, name, corsConfig));
this.registry.put(mapping, new MappingRegistration<T>(mapping, handlerMethod, directUrls, name));
}
finally {
this.readWriteLock.writeLock().unlock();
Expand Down Expand Up @@ -582,7 +581,7 @@ private void addMappingName(String name, HandlerMethod handlerMethod) {
this.nameLookup.get(name) : Collections.<HandlerMethod>emptyList();

for (HandlerMethod current : oldList) {
if (handlerMethod.getMethod().equals(current.getMethod())) {
if (handlerMethod.equals(current)) {
return;
}
}
Expand All @@ -597,8 +596,8 @@ private void addMappingName(String name, HandlerMethod handlerMethod) {
this.nameLookup.put(name, newList);

if (newList.size() > 1) {
if (logger.isDebugEnabled()) {
logger.debug("Mapping name clash for handlerMethods=" + newList +
if (logger.isTraceEnabled()) {
logger.trace("Mapping name clash for handlerMethods=" + newList +
". Consider assigning explicit names.");
}
}
Expand Down Expand Up @@ -626,7 +625,7 @@ public void unregister(T mapping) {

removeMappingName(definition);

this.corsLookup.remove(definition.getHandlerMethod().getMethod());
this.corsLookup.remove(definition.getHandlerMethod());
}
finally {
this.readWriteLock.writeLock().unlock();
Expand Down Expand Up @@ -668,11 +667,9 @@ private static class MappingRegistration<T> {

private final String mappingName;

private final CorsConfiguration corsConfiguration;


public MappingRegistration(T mapping, HandlerMethod handlerMethod, List<String> directUrls,
String mappingName, CorsConfiguration corsConfiguration) {
public MappingRegistration(T mapping, HandlerMethod handlerMethod,
List<String> directUrls, String mappingName) {

Assert.notNull(mapping);
Assert.notNull(handlerMethod);
Expand All @@ -681,7 +678,6 @@ public MappingRegistration(T mapping, HandlerMethod handlerMethod, List<String>
this.handlerMethod = handlerMethod;
this.directUrls = (directUrls != null ? directUrls : Collections.<String>emptyList());
this.mappingName = mappingName;
this.corsConfiguration = corsConfiguration;
}


Expand All @@ -700,10 +696,6 @@ public List<String> getDirectUrls() {
public String getMappingName() {
return this.mappingName;
}

public CorsConfiguration getCorsConfiguration() {
return this.corsConfiguration;
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@
import org.junit.Before;
import org.junit.Test;

import org.springframework.beans.factory.support.StaticListableBeanFactory;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.context.support.StaticApplicationContext;
import org.springframework.mock.web.test.MockHttpServletRequest;
import org.springframework.stereotype.Controller;
import org.springframework.util.AntPathMatcher;
import org.springframework.util.PathMatcher;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.context.support.StaticWebApplicationContext;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.method.HandlerMethod;
import org.springframework.web.util.UrlPathHelper;
Expand All @@ -44,6 +47,7 @@
* Test for {@link AbstractHandlerMethodMapping}.
*
* @author Arjen Poutsma
* @author Rossen Stoyanchev
*/
@SuppressWarnings("unused")
public class HandlerMethodMappingTests {
Expand Down Expand Up @@ -153,11 +157,53 @@ public void registerMapping() throws Exception {

CorsConfiguration config = this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod1);
assertNotNull(config);
assertEquals("http://" + name1, config.getAllowedOrigins().get(0));
assertEquals("http://" + handler.hashCode() + name1, config.getAllowedOrigins().get(0));

config = this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod2);
assertNotNull(config);
assertEquals("http://" + name2, config.getAllowedOrigins().get(0));
assertEquals("http://" + handler.hashCode() + name2, config.getAllowedOrigins().get(0));
}

@Test
public void registerMappingWithSameMethodAndTwoHandlerInstances() throws Exception {

String key1 = "foo";
String key2 = "bar";

MyHandler handler1 = new MyHandler();
MyHandler handler2 = new MyHandler();

HandlerMethod handlerMethod1 = new HandlerMethod(handler1, this.method1);
HandlerMethod handlerMethod2 = new HandlerMethod(handler2, this.method1);

this.mapping.registerMapping(key1, handler1, this.method1);
this.mapping.registerMapping(key2, handler2, this.method1);

// Direct URL lookup

List directUrlMatches = this.mapping.getMappingRegistry().getMappingsByUrl(key1);
assertNotNull(directUrlMatches);
assertEquals(1, directUrlMatches.size());
assertEquals(key1, directUrlMatches.get(0));

// Mapping name lookup

String name = this.method1.getName();
List<HandlerMethod> handlerMethods = this.mapping.getMappingRegistry().getHandlerMethodsByMappingName(name);
assertNotNull(handlerMethods);
assertEquals(2, handlerMethods.size());
assertEquals(handlerMethod1, handlerMethods.get(0));
assertEquals(handlerMethod2, handlerMethods.get(1));

// CORS lookup

CorsConfiguration config = this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod1);
assertNotNull(config);
assertEquals("http://" + handler1.hashCode() + name, config.getAllowedOrigins().get(0));

config = this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod2);
assertNotNull(config);
assertEquals("http://" + handler2.hashCode() + name, config.getAllowedOrigins().get(0));
}

@Test
Expand All @@ -176,6 +222,25 @@ public void unregisterMapping() throws Exception {
assertNull(this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod));
}

@Test
public void getCorsConfigWithBeanNameHandler() throws Exception {

String key = "foo";
String beanName = "handler1";

StaticWebApplicationContext context = new StaticWebApplicationContext();
context.registerSingleton(beanName, MyHandler.class);

this.mapping.setApplicationContext(context);
this.mapping.registerMapping(key, beanName, this.method1);
HandlerMethod handlerMethod = this.mapping.getHandlerInternal(new MockHttpServletRequest("GET", key));

CorsConfiguration config = this.mapping.getMappingRegistry().getCorsConfiguration(handlerMethod);
assertNotNull(config);
assertEquals("http://" + beanName.hashCode() + this.method1.getName(), config.getAllowedOrigins().get(0));
}



private static class MyHandlerMethodMapping extends AbstractHandlerMethodMapping<String> {

Expand Down Expand Up @@ -207,7 +272,7 @@ protected Set<String> getMappingPathPatterns(String key) {
@Override
protected CorsConfiguration initCorsConfiguration(Object handler, Method method, String mapping) {
CorsConfiguration corsConfig = new CorsConfiguration();
corsConfig.setAllowedOrigins(Collections.singletonList("http://" + method.getName()));
corsConfig.setAllowedOrigins(Collections.singletonList("http://" + handler.hashCode() + method.getName()));
return corsConfig;
}

Expand Down

0 comments on commit 8853107

Please sign in to comment.