Skip to content

Commit

Permalink
Merge pull request #31070 from DrFaust92/sagemaker-end-args
Browse files Browse the repository at this point in the history
r/sagemaker_endpoint_configuration - add `async_inference_config` and `s3_failure_path`
  • Loading branch information
ewbankkit authored May 3, 2023
2 parents db3340f + 5989cfc commit a2e93c5
Show file tree
Hide file tree
Showing 4 changed files with 188 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .changelog/31070.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_endpoint_configuration: Add `async_inference_config.output_config.notification_config.include_inference_response_in` and `async_inference_config.output_config.s3_failure_path` arguments
```
34 changes: 34 additions & 0 deletions internal/service/sagemaker/endpoint_configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,15 @@ func ResourceEndpointConfiguration() *schema.Resource {
ForceNew: true,
ValidateFunc: verify.ValidARN,
},
"include_inference_response_in": {
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{
Type: schema.TypeString,
ValidateFunc: validation.StringInSlice(sagemaker.AsyncNotificationTopicTypes_Values(), false),
},
},
"success_topic": {
Type: schema.TypeString,
Optional: true,
Expand All @@ -97,6 +106,15 @@ func ResourceEndpointConfiguration() *schema.Resource {
},
},
},
"s3_failure_path": {
Type: schema.TypeString,
Optional: true,
ForceNew: true,
ValidateFunc: validation.All(
validation.StringMatch(regexp.MustCompile(`^(https|s3)://([^/])/?(.*)$`), ""),
validation.StringLenBetween(1, 512),
),
},
"s3_output_path": {
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -848,6 +866,10 @@ func expandEndpointConfigOutputConfig(configured []interface{}) *sagemaker.Async
c.KmsKeyId = aws.String(v)
}

if v, ok := m["s3_failure_path"].(string); ok && v != "" {
c.S3FailurePath = aws.String(v)
}

if v, ok := m["notification_config"].([]interface{}); ok && len(v) > 0 {
c.NotificationConfig = expandEndpointConfigNotificationConfig(v)
}
Expand All @@ -872,6 +894,10 @@ func expandEndpointConfigNotificationConfig(configured []interface{}) *sagemaker
c.SuccessTopic = aws.String(v)
}

if v, ok := m["include_inference_response_in"].(*schema.Set); ok && v.Len() > 0 {
c.IncludeInferenceResponseIn = flex.ExpandStringSet(v)
}

return c
}

Expand Down Expand Up @@ -964,6 +990,10 @@ func flattenEndpointConfigOutputConfig(config *sagemaker.AsyncInferenceOutputCon
cfg["notification_config"] = flattenEndpointConfigNotificationConfig(config.NotificationConfig)
}

if config.S3FailurePath != nil {
cfg["s3_failure_path"] = aws.StringValue(config.S3FailurePath)
}

return []map[string]interface{}{cfg}
}

Expand All @@ -982,6 +1012,10 @@ func flattenEndpointConfigNotificationConfig(config *sagemaker.AsyncInferenceNot
cfg["success_topic"] = aws.StringValue(config.SuccessTopic)
}

if config.IncludeInferenceResponseIn != nil {
cfg["include_inference_response_in"] = flex.FlattenStringSet(config.IncludeInferenceResponseIn)
}

return []map[string]interface{}{cfg}
}

Expand Down
149 changes: 149 additions & 0 deletions internal/service/sagemaker/endpoint_configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,41 @@ func TestAccSageMakerEndpointConfiguration_async(t *testing.T) {
})
}

func TestAccSageMakerEndpointConfiguration_async_includeInference(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckEndpointConfigurationDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccEndpointConfigurationConfig_asyncNotifInferenceIn(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckEndpointConfigurationExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "name", rName),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.#", "1"),
resource.TestCheckResourceAttrPair(resourceName, "async_inference_config.0.output_config.0.notification_config.0.error_topic", "aws_sns_topic.test", "arn"),
resource.TestCheckResourceAttrPair(resourceName, "async_inference_config.0.output_config.0.notification_config.0.success_topic", "aws_sns_topic.test", "arn"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.0.include_inference_response_in.#", "2"),
resource.TestCheckTypeSetElemAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.0.include_inference_response_in.*", "SUCCESS_NOTIFICATION_TOPIC"),
resource.TestCheckTypeSetElemAttr(resourceName, "async_inference_config.0.output_config.0.notification_config.0.include_inference_response_in.*", "ERROR_NOTIFICATION_TOPIC"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerEndpointConfiguration_async_kms(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -525,6 +560,39 @@ func TestAccSageMakerEndpointConfiguration_Async_client(t *testing.T) {
})
}

func TestAccSageMakerEndpointConfiguration_Async_client_failurePath(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_endpoint_configuration.test"

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(ctx, t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckEndpointConfigurationDestroy(ctx),
Steps: []resource.TestStep{
{
Config: testAccEndpointConfigurationConfig_asyncClientFailure(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckEndpointConfigurationExists(ctx, resourceName),
resource.TestCheckResourceAttr(resourceName, "name", rName),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.client_config.0.max_concurrent_invocations_per_instance", "1"),
resource.TestCheckResourceAttr(resourceName, "async_inference_config.0.output_config.#", "1"),
resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_output_path"),
resource.TestCheckResourceAttrSet(resourceName, "async_inference_config.0.output_config.0.s3_failure_path"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
},
})
}

func TestAccSageMakerEndpointConfiguration_upgradeToEnableSSMAccess(t *testing.T) {
ctx := acctest.Context(t)
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -977,6 +1045,49 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName))
}

func testAccEndpointConfigurationConfig_asyncNotifInferenceIn(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
bucket = %[1]q
force_destroy = true
}
resource "aws_sns_topic" "test" {
name = %[1]q
}
resource "aws_kms_key" "test" {
description = %[1]q
deletion_window_in_days = 7
}
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q
production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name
initial_instance_count = 2
instance_type = "ml.t2.medium"
initial_variant_weight = 1
}
async_inference_config {
output_config {
s3_output_path = "s3://${aws_s3_bucket.test.bucket}/"
kms_key_id = aws_kms_key.test.arn
notification_config {
error_topic = aws_sns_topic.test.arn
include_inference_response_in = ["SUCCESS_NOTIFICATION_TOPIC", "ERROR_NOTIFICATION_TOPIC"]
success_topic = aws_sns_topic.test.arn
}
}
}
}
`, rName))
}

func testAccEndpointConfigurationConfig_asyncClient(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
Expand Down Expand Up @@ -1014,6 +1125,44 @@ resource "aws_sagemaker_endpoint_configuration" "test" {
`, rName))
}

func testAccEndpointConfigurationConfig_asyncClientFailure(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_s3_bucket" "test" {
bucket = %[1]q
force_destroy = true
}
resource "aws_kms_key" "test" {
description = %[1]q
deletion_window_in_days = 7
}
resource "aws_sagemaker_endpoint_configuration" "test" {
name = %[1]q
production_variants {
variant_name = "variant-1"
model_name = aws_sagemaker_model.test.name
initial_instance_count = 2
instance_type = "ml.t2.medium"
initial_variant_weight = 1
}
async_inference_config {
client_config {
max_concurrent_invocations_per_instance = 1
}
output_config {
s3_output_path = "s3://${aws_s3_bucket.test.bucket}/"
s3_failure_path = "s3://${aws_s3_bucket.test.bucket}/"
kms_key_id = aws_kms_key.test.arn
}
}
}
`, rName))
}

func testAccEndpointConfigurationConfig_serverless(rName string) string {
return acctest.ConfigCompose(testAccEndpointConfigurationConfig_base(rName), fmt.Sprintf(`
resource "aws_sagemaker_endpoint_configuration" "test" {
Expand Down
2 changes: 2 additions & 0 deletions website/docs/r/sagemaker_endpoint_configuration.html.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,13 @@ The following arguments are supported:
#### output_config

* `s3_output_path` - (Required) The Amazon S3 location to upload inference responses to.
* `s3_failure_path` - (Optional) The Amazon S3 location to upload failure inference responses to.
* `kms_key_id` - (Optional) The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the asynchronous inference output in Amazon S3.
* `notification_config` - (Optional) Specifies the configuration for notifications of inference results for asynchronous inference.

##### notification_config

* `include_inference_response_in` - (Optional) The Amazon SNS topics where you want the inference response to be included. Valid values are `SUCCESS_NOTIFICATION_TOPIC` and `ERROR_NOTIFICATION_TOPIC`.
* `error_topic` - (Optional) Amazon SNS topic to post a notification to when inference fails. If no topic is provided, no notification is sent on failure.
* `success_topic` - (Optional) Amazon SNS topic to post a notification to when inference completes successfully. If no topic is provided, no notification is sent on success.

Expand Down

0 comments on commit a2e93c5

Please sign in to comment.