Skip to content

Commit

Permalink
Changed the way request uris are generated to address #138. In the fu…
Browse files Browse the repository at this point in the history
…ture, it would be ideal to get the full request uri from API Gateway in the event.
  • Loading branch information
sapessi committed Apr 5, 2018
1 parent 6c90756 commit f1ccaaa
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;


/**
Expand All @@ -18,6 +21,51 @@
public final class SecurityUtils {
private static Logger log = LoggerFactory.getLogger(SecurityUtils.class);

private static Set<String> SCHEMES = new HashSet<String>() {{
add("http");
add("https");
add("HTTP");
add("HTTPS");
}};

private static Set<Integer> PORTS = new HashSet<Integer>() {{
add(443);
add(80);
add(3000); // we allow port 3000 for SAM local
}};

public static boolean isValidPort(String port) {
if (port == null) {
return false;
}
try {
int intPort = Integer.parseInt(port);
return PORTS.contains(intPort);
} catch (NumberFormatException e) {
log.error("Invalid port parameter: " + crlf(port));
return false;
}
}

public static boolean isValidScheme(String scheme) {
return SCHEMES.contains(scheme);
}

public static boolean isValidHost(String host, String apiId, String region) {
if (host == null) {
return false;
}
if (host.endsWith(".amazonaws.com")) {
String defaultHost = new StringBuilder().append(apiId)
.append(".execute-api.")
.append(region)
.append(".amazonaws.com").toString();
return host.equals(defaultHost);
} else {
return LambdaContainerHandler.getContainerConfig().getCustomDomainNames().contains(host);
}
}

/**
* Replaces CRLF characters in a string with empty string ("").
* @param s The string to be cleaned
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ public abstract class AwsHttpServletRequest implements HttpServletRequest {
// We need this to pickup the protocol from the CloudFront header since Lambda doesn't receive this
// information from anywhere else
static final String CF_PROTOCOL_HEADER_NAME = "CloudFront-Forwarded-Proto";
static final String PROTOCOL_HEADER_NAME = "X-Forwarded-Proto";
static final String HOST_HEADER_NAME = "Host";
static final String PORT_HEADER_NAME = "X-Forwarded-Port";


//-------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,22 +501,44 @@ public String getProtocol() {

@Override
public String getScheme() {
String headerValue = getHeaderCaseInsensitive(CF_PROTOCOL_HEADER_NAME);
if (headerValue == null) {
return "https";
String cfScheme = getHeaderCaseInsensitive(CF_PROTOCOL_HEADER_NAME);
if (cfScheme != null && SecurityUtils.isValidScheme(cfScheme)) {
return cfScheme;
}
String gwScheme = getHeaderCaseInsensitive(PROTOCOL_HEADER_NAME);
if (gwScheme != null && SecurityUtils.isValidScheme(gwScheme)) {
return gwScheme;
}
return headerValue;
// https is our default scheme
return "https";
}


@Override
public String getServerName() {
String name = getHeaderCaseInsensitive(HttpHeaders.HOST);
String region = System.getenv("AWS_REGION");
if (region == null) {
// this is not a critical failure, we just put a static region in the URI
region = "us-east-1";
}

if (name == null || name.length() == 0) {
name = "lambda.amazonaws.com";
String hostHeader = getHeaderCaseInsensitive(HOST_HEADER_NAME);
if (hostHeader != null && SecurityUtils.isValidHost(hostHeader, request.getRequestContext().getApiId(), region)) {
return hostHeader;
}

return new StringBuilder().append(request.getRequestContext().getApiId())
.append(".execute-api.")
.append(region)
.append(".amazonaws.com").toString();
}

public int getServerPort() {
String port = getHeaderCaseInsensitive(PORT_HEADER_NAME);
if (SecurityUtils.isValidPort(port)) {
return Integer.parseInt(port);
} else {
return 443; // default port
}
return name;
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ public static ContainerConfig defaultConfig() {
private boolean consolidateSetCookieHeaders;
private boolean useStageAsServletContext;
private List<String> validFilePaths;
private List<String> customDomainNames;

public ContainerConfig() {
validFilePaths = new ArrayList<>();
customDomainNames = new ArrayList<>();
}


Expand Down Expand Up @@ -168,4 +170,31 @@ public void setValidFilePaths(List<String> validFilePaths) {
public void addValidFilePath(String filePath) {
validFilePaths.add(filePath);
}


/**
* Adds a new custom domain name to the list of allowed domains
* @param name The new custom domain name, excluding the scheme ("https") and port
*/
public void addCustomDomain(String name) {
customDomainNames.add(name);
}


/**
* Returns the list of custom domain names enabled for the application
* @return The configured custom domain names
*/
public List<String> getCustomDomainNames() {
return customDomainNames;
}


/**
* Enables localhost custom domain name for testing. This setting should be used only in local
* with SAM local
*/
public void enableLocalhost() {
customDomainNames.add("localhost");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,23 +165,19 @@ private ContainerRequest servletRequestToContainerRequest(ServletRequest request
return requestContext;
}

@SuppressFBWarnings("SERVLET_SERVER_NAME")
private URI getBaseUri(ServletRequest request, String basePath) {
ApiGatewayRequestContext apiGatewayCtx = (ApiGatewayRequestContext) request.getAttribute(API_GATEWAY_CONTEXT_PROPERTY);
String region = System.getenv("AWS_REGION");
if (region == null) {
// this is not a critical failure, we just put a static region in the URI
region = "us-east-1";
String finalBasePath = basePath;
if (!finalBasePath.startsWith("/")) {
finalBasePath = "/" + finalBasePath;
}
StringBuilder uriBuilder = new StringBuilder();
uriBuilder.append("https://") // we assume it's always https
.append(apiGatewayCtx.getApiId())
.append(".execute-api.")
.append(region)
.append(".amazonaws.com")
.append("/");


return UriBuilder.fromUri(uriBuilder.toString()).build();
String uriString = new StringBuilder().append(request.getScheme())
.append("://")
.append(request.getServerName())
.append(":")
.append(request.getServerPort())
.append(finalBasePath).toString();
return UriBuilder.fromUri(uriString).build();
}

//-------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.amazonaws.serverless.proxy.spring;

import com.amazonaws.serverless.proxy.internal.LambdaContainerHandler;
import com.amazonaws.serverless.proxy.model.AwsProxyRequest;
import com.amazonaws.serverless.proxy.model.AwsProxyResponse;
import com.amazonaws.serverless.proxy.internal.servlet.AwsServletContext;
Expand Down Expand Up @@ -335,6 +336,7 @@ public void contextPath_generateLink_returnsCorrectPath() {
.serverName("api.myserver.com")
.stage("prod")
.build();
LambdaContainerHandler.getContainerConfig().addCustomDomain("api.myserver.com");
SpringLambdaContainerHandler.getContainerConfig().setUseStageAsServletContext(true);

AwsProxyResponse output = handler.proxy(request, lambdaContext);
Expand Down

0 comments on commit f1ccaaa

Please sign in to comment.