Skip to content

Commit

Permalink
Add support for IAM role based credentials in Kinesis Plugin (#9071)
Browse files Browse the repository at this point in the history
* Add support for IAM roles

* Add support for externalId

* Provide proper credentials to STS client

* Add default session id

* Add javadoc

Co-authored-by: Kartik Khare <[email protected]>
  • Loading branch information
KKcorps and Kartik Khare authored Jul 27, 2022
1 parent 9cf7d81 commit 49c0e24
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
*/
package org.apache.pinot.plugin.stream.kinesis;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import java.util.Map;
import java.util.UUID;
import org.apache.pinot.spi.stream.StreamConfig;
import software.amazon.awssdk.services.kinesis.model.ShardIteratorType;

Expand All @@ -36,9 +38,32 @@ public class KinesisConfig {
public static final String MAX_RECORDS_TO_FETCH = "maxRecordsToFetch";
public static final String ENDPOINT = "endpoint";

// IAM role configs
/**
* Enable Role based access to AWS.
* iamRoleBasedAccessEnabled - Set it to `true` to enable role based access, default: false
* roleArn - Required. specify the ARN of the role the client should assume.
* roleSessionName - session name to be used when creating a role based session. default: pinot-kineis-uuid
* externalId - string external id value required by role's policy. default: null
* sessionDurationSeconds - The duration, in seconds, of the role session. Default: 900
* asyncSessionUpdateEnabled -
* Configure whether the provider should fetch credentials asynchronously in the background.
* If this is true, threads are less likely to block when credentials are loaded,
* but additional resources are used to maintain the provider. Default - `true`
*/
public static final String IAM_ROLE_BASED_ACCESS_ENABLED = "iamRoleBasedAccessEnabled";
public static final String ROLE_ARN = "roleArn";
public static final String ROLE_SESSION_NAME = "roleSessionName";
public static final String EXTERNAL_ID = "externalId";
public static final String SESSION_DURATION_SECONDS = "sessionDurationSeconds";
public static final String ASYNC_SESSION_UPDATED_ENABLED = "asyncSessionUpdateEnabled";

// TODO: this is a starting point, until a better default is figured out
public static final String DEFAULT_MAX_RECORDS = "20";
public static final String DEFAULT_SHARD_ITERATOR_TYPE = ShardIteratorType.LATEST.toString();
public static final String DEFAULT_IAM_ROLE_BASED_ACCESS_ENABLED = "false";
public static final String DEFAULT_SESSION_DURATION_SECONDS = "900";
public static final String DEFAULT_ASYNC_SESSION_UPDATED_ENABLED = "true";

private final String _streamTopicName;
private final String _awsRegion;
Expand All @@ -48,6 +73,14 @@ public class KinesisConfig {
private final String _secretKey;
private final String _endpoint;

// IAM Role values
private boolean _iamRoleBasedAccess;
private String _roleArn;
private String _roleSessionName;
private String _externalId;
private int _sessionDurationSeconds;
private boolean _asyncSessionUpdateEnabled;

public KinesisConfig(StreamConfig streamConfig) {
Map<String, String> props = streamConfig.getStreamConfigsMap();
_streamTopicName = streamConfig.getTopicName();
Expand All @@ -60,23 +93,23 @@ public KinesisConfig(StreamConfig streamConfig) {
_accessKey = props.get(ACCESS_KEY);
_secretKey = props.get(SECRET_KEY);
_endpoint = props.get(ENDPOINT);
}

public KinesisConfig(String streamTopicName, String awsRegion, ShardIteratorType shardIteratorType, String accessKey,
String secretKey, String endpoint) {
this(streamTopicName, awsRegion, shardIteratorType, accessKey, secretKey, Integer.parseInt(DEFAULT_MAX_RECORDS),
endpoint);
}
_iamRoleBasedAccess =
Boolean.parseBoolean(props.getOrDefault(IAM_ROLE_BASED_ACCESS_ENABLED, DEFAULT_IAM_ROLE_BASED_ACCESS_ENABLED));
_roleArn = props.get(ROLE_ARN);
_roleSessionName =
props.getOrDefault(ROLE_SESSION_NAME, Joiner.on("-").join("pinot", "kinesis", UUID.randomUUID()));
_externalId = props.get(EXTERNAL_ID);
_sessionDurationSeconds =
Integer.parseInt(props.getOrDefault(SESSION_DURATION_SECONDS, DEFAULT_SESSION_DURATION_SECONDS));
_asyncSessionUpdateEnabled =
Boolean.parseBoolean(props.getOrDefault(ASYNC_SESSION_UPDATED_ENABLED, DEFAULT_ASYNC_SESSION_UPDATED_ENABLED));

public KinesisConfig(String streamTopicName, String awsRegion, ShardIteratorType shardIteratorType, String accessKey,
String secretKey, int maxRecords, String endpoint) {
_streamTopicName = streamTopicName;
_awsRegion = awsRegion;
_shardIteratorType = shardIteratorType;
_accessKey = accessKey;
_secretKey = secretKey;
_numMaxRecordsToFetch = maxRecords;
_endpoint = endpoint;
if (_iamRoleBasedAccess) {
Preconditions.checkNotNull(_roleArn,
"Must provide 'roleArn' in stream config for table %s if iamRoleBasedAccess is enabled",
streamConfig.getTableNameWithType());
}
}

public String getStreamTopicName() {
Expand Down Expand Up @@ -106,4 +139,28 @@ public String getSecretKey() {
public String getEndpoint() {
return _endpoint;
}

public boolean isIamRoleBasedAccess() {
return _iamRoleBasedAccess;
}

public String getRoleArn() {
return _roleArn;
}

public String getRoleSessionName() {
return _roleSessionName;
}

public String getExternalId() {
return _externalId;
}

public int getSessionDurationSeconds() {
return _sessionDurationSeconds;
}

public boolean isAsyncSessionUpdateEnabled() {
return _asyncSessionUpdateEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.List;
import org.apache.commons.lang3.StringUtils;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.apache.ApacheSdkHttpService;
Expand All @@ -33,6 +34,9 @@
import software.amazon.awssdk.services.kinesis.model.ListShardsRequest;
import software.amazon.awssdk.services.kinesis.model.ListShardsResponse;
import software.amazon.awssdk.services.kinesis.model.Shard;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;


/**
Expand All @@ -45,13 +49,15 @@ public class KinesisConnectionHandler {
private final String _accessKey;
private final String _secretKey;
private final String _endpoint;
private final KinesisConfig _kinesisConfig;

public KinesisConnectionHandler(KinesisConfig kinesisConfig) {
_stream = kinesisConfig.getStreamTopicName();
_region = kinesisConfig.getAwsRegion();
_accessKey = kinesisConfig.getAccessKey();
_secretKey = kinesisConfig.getSecretKey();
_endpoint = kinesisConfig.getEndpoint();
_kinesisConfig = kinesisConfig;
createConnection();
}

Expand All @@ -62,6 +68,7 @@ public KinesisConnectionHandler(KinesisConfig kinesisConfig, KinesisClient kines
_accessKey = kinesisConfig.getAccessKey();
_secretKey = kinesisConfig.getSecretKey();
_endpoint = kinesisConfig.getEndpoint();
_kinesisConfig = kinesisConfig;
_kinesisClient = kinesisClient;
}

Expand All @@ -80,17 +87,51 @@ public List<Shard> getShards() {
public void createConnection() {
if (_kinesisClient == null) {
KinesisClientBuilder kinesisClientBuilder;

AwsCredentialsProvider awsCredentialsProvider;
if (StringUtils.isNotBlank(_accessKey) && StringUtils.isNotBlank(_secretKey)) {
AwsBasicCredentials awsBasicCredentials = AwsBasicCredentials.create(_accessKey, _secretKey);
kinesisClientBuilder = KinesisClient.builder().region(Region.of(_region))
.credentialsProvider(StaticCredentialsProvider.create(awsBasicCredentials))
.httpClientBuilder(new ApacheSdkHttpService().createHttpClientBuilder());
awsCredentialsProvider = StaticCredentialsProvider.create(awsBasicCredentials);
} else {
kinesisClientBuilder =
KinesisClient.builder().region(Region.of(_region)).credentialsProvider(DefaultCredentialsProvider.create())
.httpClientBuilder(new ApacheSdkHttpService().createHttpClientBuilder());
awsCredentialsProvider = DefaultCredentialsProvider.create();
}

if (_kinesisConfig.isIamRoleBasedAccess()) {
AssumeRoleRequest.Builder assumeRoleRequestBuilder =
AssumeRoleRequest.builder()
.roleArn(_kinesisConfig.getRoleArn())
.roleSessionName(_kinesisConfig.getRoleSessionName())
.durationSeconds(_kinesisConfig.getSessionDurationSeconds());

AssumeRoleRequest assumeRoleRequest;
if (StringUtils.isNotEmpty(_kinesisConfig.getExternalId())) {
assumeRoleRequest = assumeRoleRequestBuilder
.externalId(_kinesisConfig.getExternalId())
.build();
} else {
assumeRoleRequest = assumeRoleRequestBuilder.build();
}

StsClient stsClient =
StsClient.builder()
.region(Region.of(_region))
.credentialsProvider(awsCredentialsProvider)
.build();

awsCredentialsProvider =
StsAssumeRoleCredentialsProvider.builder()
.stsClient(stsClient)
.refreshRequest(assumeRoleRequest)
.asyncCredentialUpdateEnabled(_kinesisConfig.isAsyncSessionUpdateEnabled())
.build();
}

kinesisClientBuilder =
KinesisClient.builder()
.region(Region.of(_region))
.credentialsProvider(awsCredentialsProvider)
.httpClientBuilder(new ApacheSdkHttpService().createHttpClientBuilder());

if (StringUtils.isNotBlank(_endpoint)) {
try {
kinesisClientBuilder = kinesisClientBuilder.endpointOverride(new URI(_endpoint));
Expand Down

0 comments on commit 49c0e24

Please sign in to comment.