Skip to content

Commit

Permalink
Support authentication on global region for AWS IAM
Browse files Browse the repository at this point in the history
  • Loading branch information
Amuerte committed Mar 4, 2023
1 parent 2d9463e commit 7585758
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.AwsRegionProvider;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;

Expand Down Expand Up @@ -78,15 +79,24 @@ public class AwsIamAuthenticationOptions {
*/
private final URI endpointUri;

/**
* This parameter enables to sign the AWS request with the global region (us-east-1)
* in case the Vault server making the proxy request is configured to use STS global
* endpoint, and your application is deployed in another region.
*/
private final boolean useGlobalEndpoint;

private AwsIamAuthenticationOptions(String path, AwsCredentialsProvider credentialsProvider,
AwsRegionProvider regionProvider, @Nullable String role, @Nullable String serverId, URI endpointUri) {
AwsRegionProvider regionProvider, @Nullable String role, @Nullable String serverId, URI endpointUri,
boolean useGlobalEndpoint) {

this.path = path;
this.credentialsProvider = credentialsProvider;
this.regionProvider = regionProvider;
this.role = role;
this.serverId = serverId;
this.endpointUri = endpointUri;
this.useGlobalEndpoint = useGlobalEndpoint;
}

/**
Expand Down Expand Up @@ -163,6 +173,8 @@ public static class AwsIamAuthenticationOptionsBuilder {
@Nullable
private String serverId;

private boolean useGlobalEndpoint;

private URI endpointUri = URI.create("https://sts.amazonaws.com/");

AwsIamAuthenticationOptionsBuilder() {
Expand Down Expand Up @@ -282,16 +294,27 @@ public AwsIamAuthenticationOptionsBuilder endpointUri(URI endpointUri) {
return this;
}

public AwsIamAuthenticationOptionsBuilder useGlobalEndpoint(Boolean useGlobalEndpoint) {

Assert.notNull(useGlobalEndpoint, "Flag useGlobalEndpoint must not be null");

this.useGlobalEndpoint = useGlobalEndpoint;
return this;
}

/**
* Build a new {@link AwsIamAuthenticationOptions} instance.
* @return a new {@link AwsIamAuthenticationOptions}.
*/
public AwsIamAuthenticationOptions build() {

Assert.state(this.credentialsProvider != null, "Credentials or CredentialProvider must not be null");
if (useGlobalEndpoint) {
regionProvider(() -> Region.US_EAST_1);
}

return new AwsIamAuthenticationOptions(this.path, this.credentialsProvider, this.regionProvider, this.role,
this.serverId, this.endpointUri);
this.serverId, this.endpointUri, this.useGlobalEndpoint);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
package org.springframework.vault.authentication;

import java.time.Duration;
import java.util.Base64;

import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.hamcrest.Matcher;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Nested;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.regions.Region;
Expand Down Expand Up @@ -102,4 +109,25 @@ void shouldUsingAuthenticationSteps() {
assertThat(((LoginToken) login).isRenewable()).isTrue();
}

@Nested
@DisplayName("Unit Tests AwsIamAuthenticationOptions")
class AwsIamAuthenticationOptionsUnitTests {

@Test
void shouldSignRequestOnGlobalRegion() {
AwsIamAuthenticationOptions options = AwsIamAuthenticationOptions.builder().role("foo-role")
.regionProvider(() -> Region.US_WEST_1).credentials(AwsBasicCredentials.create("foo", "bar"))
.useGlobalEndpoint(true).build();

assertThat(options.getRegionProvider().getRegion()).isEqualTo(Region.US_EAST_1);
}

@Test
void shouldThrowExceptionWhenUseGlobalRegionIsNull() {
assertThatThrownBy(() -> AwsIamAuthenticationOptions.builder().useGlobalEndpoint(null))
.isInstanceOf(IllegalArgumentException.class).hasMessageContaining("useGlobalEndpoint");
}

}

}

0 comments on commit 7585758

Please sign in to comment.