diff --git a/.changelog/22646.txt b/.changelog/22646.txt new file mode 100644 index 000000000000..39e82e33f589 --- /dev/null +++ b/.changelog/22646.txt @@ -0,0 +1,3 @@ +```release-note:enhancement +resource/aws_dms_endpoint: Add ability to use AWS Secrets Manager with the `sqlserver` engine +``` \ No newline at end of file diff --git a/internal/service/dms/endpoint.go b/internal/service/dms/endpoint.go index 3341452fbde2..7ac6cdc41646 100644 --- a/internal/service/dms/endpoint.go +++ b/internal/service/dms/endpoint.go @@ -12,7 +12,6 @@ import ( dms "github.com/aws/aws-sdk-go/service/databasemigrationservice" "github.com/hashicorp/aws-sdk-go-base/v2/awsv1shim/v2/tfawserr" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" "github.com/hashicorp/terraform-provider-aws/internal/conns" @@ -32,6 +31,11 @@ func ResourceEndpoint() *schema.Resource { State: schema.ImportStatePassthrough, }, + Timeouts: &schema.ResourceTimeout{ + Create: schema.DefaultTimeout(5 * time.Minute), + Delete: schema.DefaultTimeout(5 * time.Minute), + }, + Schema: map[string]*schema.Schema{ "certificate_arn": { Type: schema.TypeString, @@ -594,49 +598,50 @@ func resourceEndpointCreate(d *schema.ResourceData, meta interface{}) error { defaultTagsConfig := meta.(*conns.AWSClient).DefaultTagsConfig tags := defaultTagsConfig.MergeTags(tftags.New(d.Get("tags").(map[string]interface{}))) - request := &dms.CreateEndpointInput{ - EndpointIdentifier: aws.String(d.Get("endpoint_id").(string)), + endpointID := d.Get("endpoint_id").(string) + input := &dms.CreateEndpointInput{ + EndpointIdentifier: aws.String(endpointID), EndpointType: aws.String(d.Get("endpoint_type").(string)), EngineName: aws.String(d.Get("engine_name").(string)), Tags: Tags(tags.IgnoreAWS()), } if v, ok := d.GetOk("certificate_arn"); ok { - request.CertificateArn = aws.String(v.(string)) + input.CertificateArn = aws.String(v.(string)) } // Send ExtraConnectionAttributes in the API request for all resource types // per https://github.com/hashicorp/terraform-provider-aws/issues/8009 if v, ok := d.GetOk("extra_connection_attributes"); ok { - request.ExtraConnectionAttributes = aws.String(v.(string)) + input.ExtraConnectionAttributes = aws.String(v.(string)) } if v, ok := d.GetOk("kms_key_arn"); ok { - request.KmsKeyId = aws.String(v.(string)) + input.KmsKeyId = aws.String(v.(string)) } if v, ok := d.GetOk("ssl_mode"); ok { - request.SslMode = aws.String(v.(string)) + input.SslMode = aws.String(v.(string)) } switch d.Get("engine_name").(string) { case engineNameDynamoDB: - request.DynamoDbSettings = &dms.DynamoDbSettings{ + input.DynamoDbSettings = &dms.DynamoDbSettings{ ServiceAccessRoleArn: aws.String(d.Get("service_access_role").(string)), } case engineNameElasticsearch, engineNameOpenSearch: - request.ElasticsearchSettings = &dms.ElasticsearchSettings{ + input.ElasticsearchSettings = &dms.ElasticsearchSettings{ ServiceAccessRoleArn: aws.String(d.Get("elasticsearch_settings.0.service_access_role_arn").(string)), EndpointUri: aws.String(d.Get("elasticsearch_settings.0.endpoint_uri").(string)), ErrorRetryDuration: aws.Int64(int64(d.Get("elasticsearch_settings.0.error_retry_duration").(int))), FullLoadErrorPercentage: aws.Int64(int64(d.Get("elasticsearch_settings.0.full_load_error_percentage").(int))), } case engineNameKafka: - request.KafkaSettings = expandKafkaSettings(d.Get("kafka_settings").([]interface{})[0].(map[string]interface{})) + input.KafkaSettings = expandKafkaSettings(d.Get("kafka_settings").([]interface{})[0].(map[string]interface{})) case engineNameKinesis: - request.KinesisSettings = expandKinesisSettings(d.Get("kinesis_settings").([]interface{})[0].(map[string]interface{})) + input.KinesisSettings = expandKinesisSettings(d.Get("kinesis_settings").([]interface{})[0].(map[string]interface{})) case engineNameMongodb: - request.MongoDbSettings = &dms.MongoDbSettings{ + input.MongoDbSettings = &dms.MongoDbSettings{ Username: aws.String(d.Get("username").(string)), Password: aws.String(d.Get("password").(string)), ServerName: aws.String(d.Get("server_name").(string)), @@ -653,20 +658,20 @@ func resourceEndpointCreate(d *schema.ResourceData, meta interface{}) error { } // Set connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) case engineNameOracle: if _, ok := d.GetOk("secrets_manager_arn"); ok { - request.OracleSettings = &dms.OracleSettings{ + input.OracleSettings = &dms.OracleSettings{ SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), DatabaseName: aws.String(d.Get("database_name").(string)), } } else { - request.OracleSettings = &dms.OracleSettings{ + input.OracleSettings = &dms.OracleSettings{ Username: aws.String(d.Get("username").(string)), Password: aws.String(d.Get("password").(string)), ServerName: aws.String(d.Get("server_name").(string)), @@ -675,21 +680,21 @@ func resourceEndpointCreate(d *schema.ResourceData, meta interface{}) error { } // Set connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) } case engineNamePostgres: if _, ok := d.GetOk("secrets_manager_arn"); ok { - request.PostgreSQLSettings = &dms.PostgreSQLSettings{ + input.PostgreSQLSettings = &dms.PostgreSQLSettings{ SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), DatabaseName: aws.String(d.Get("database_name").(string)), } } else { - request.PostgreSQLSettings = &dms.PostgreSQLSettings{ + input.PostgreSQLSettings = &dms.PostgreSQLSettings{ Username: aws.String(d.Get("username").(string)), Password: aws.String(d.Get("password").(string)), ServerName: aws.String(d.Get("server_name").(string)), @@ -698,47 +703,61 @@ func resourceEndpointCreate(d *schema.ResourceData, meta interface{}) error { } // Set connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) + } + case engineNameSQLServer: + if _, ok := d.GetOk("secrets_manager_arn"); ok { + input.MicrosoftSQLServerSettings = &dms.MicrosoftSQLServerSettings{ + SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), + SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + DatabaseName: aws.String(d.Get("database_name").(string)), + } + } else { + input.MicrosoftSQLServerSettings = &dms.MicrosoftSQLServerSettings{ + Username: aws.String(d.Get("username").(string)), + Password: aws.String(d.Get("password").(string)), + ServerName: aws.String(d.Get("server_name").(string)), + Port: aws.Int64(int64(d.Get("port").(int))), + DatabaseName: aws.String(d.Get("database_name").(string)), + } + + // Set connection info in top-level namespace as well + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) } case engineNameS3: - request.S3Settings = expandS3Settings(d.Get("s3_settings").([]interface{})[0].(map[string]interface{})) + input.S3Settings = expandS3Settings(d.Get("s3_settings").([]interface{})[0].(map[string]interface{})) default: - request.Password = aws.String(d.Get("password").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Username = aws.String(d.Get("username").(string)) if v, ok := d.GetOk("database_name"); ok { - request.DatabaseName = aws.String(v.(string)) + input.DatabaseName = aws.String(v.(string)) } } - log.Println("[DEBUG] DMS create endpoint:", request) - - err := resource.Retry(5*time.Minute, func() *resource.RetryError { - _, err := conn.CreateEndpoint(request) - if tfawserr.ErrCodeEquals(err, "AccessDeniedFault") { - return resource.RetryableError(err) - } - if err != nil { - return resource.NonRetryableError(err) - } + log.Printf("[DEBUG] Creating DMS Endpoint: %s", input) + _, err := tfresource.RetryWhenAWSErrCodeEquals(d.Timeout(schema.TimeoutCreate), + func() (interface{}, error) { + return conn.CreateEndpoint(input) + }, + dms.ErrCodeAccessDeniedFault) - // Successful delete - return nil - }) - if tfresource.TimedOut(err) { - _, err = conn.CreateEndpoint(request) - } if err != nil { - return fmt.Errorf("Error creating DMS endpoint: %s", err) + return fmt.Errorf("creating DMS Endpoint (%s): %w", endpointID, err) } - d.SetId(d.Get("endpoint_id").(string)) + d.SetId(endpointID) + return resourceEndpointRead(d, meta) } @@ -756,7 +775,7 @@ func resourceEndpointRead(d *schema.ResourceData, meta interface{}) error { } if err != nil { - return fmt.Errorf("error reading DMS Endpoint (%s): %w", d.Id(), err) + return fmt.Errorf("reading DMS Endpoint (%s): %w", d.Id(), err) } err = resourceEndpointSetState(d, endpoint) @@ -768,18 +787,18 @@ func resourceEndpointRead(d *schema.ResourceData, meta interface{}) error { tags, err := ListTags(conn, d.Get("endpoint_arn").(string)) if err != nil { - return fmt.Errorf("error listing tags for DMS Endpoint (%s): %w", d.Get("endpoint_arn").(string), err) + return fmt.Errorf("listing tags for DMS Endpoint (%s): %w", d.Get("endpoint_arn").(string), err) } tags = tags.IgnoreAWS().IgnoreConfig(ignoreTagsConfig) //lintignore:AWSR002 if err := d.Set("tags", tags.RemoveDefaultConfig(defaultTagsConfig).Map()); err != nil { - return fmt.Errorf("error setting tags: %w", err) + return fmt.Errorf("setting tags: %w", err) } if err := d.Set("tags_all", tags.Map()); err != nil { - return fmt.Errorf("error setting tags_all: %w", err) + return fmt.Errorf("setting tags_all: %w", err) } return nil @@ -788,226 +807,226 @@ func resourceEndpointRead(d *schema.ResourceData, meta interface{}) error { func resourceEndpointUpdate(d *schema.ResourceData, meta interface{}) error { conn := meta.(*conns.AWSClient).DMSConn - request := &dms.ModifyEndpointInput{ - EndpointArn: aws.String(d.Get("endpoint_arn").(string)), - } - hasChanges := false - - if d.HasChange("endpoint_type") { - request.EndpointType = aws.String(d.Get("endpoint_type").(string)) - hasChanges = true - } - - if d.HasChange("certificate_arn") { - request.CertificateArn = aws.String(d.Get("certificate_arn").(string)) - hasChanges = true - } - - if d.HasChange("service_access_role") { - request.DynamoDbSettings = &dms.DynamoDbSettings{ - ServiceAccessRoleArn: aws.String(d.Get("service_access_role").(string)), + if d.HasChangesExcept("tags", "tags_all") { + input := &dms.ModifyEndpointInput{ + EndpointArn: aws.String(d.Get("endpoint_arn").(string)), } - hasChanges = true - } - - if d.HasChange("endpoint_type") { - request.EndpointType = aws.String(d.Get("endpoint_type").(string)) - hasChanges = true - } - if d.HasChange("engine_name") { - request.EngineName = aws.String(d.Get("engine_name").(string)) - hasChanges = true - } - - if d.HasChange("extra_connection_attributes") { - request.ExtraConnectionAttributes = aws.String(d.Get("extra_connection_attributes").(string)) - hasChanges = true - } + if d.HasChange("certificate_arn") { + input.CertificateArn = aws.String(d.Get("certificate_arn").(string)) + } - if d.HasChange("ssl_mode") { - request.SslMode = aws.String(d.Get("ssl_mode").(string)) - hasChanges = true - } + if d.HasChange("endpoint_type") { + input.EndpointType = aws.String(d.Get("endpoint_type").(string)) + } - if d.HasChange("tags_all") { - arn := d.Get("endpoint_arn").(string) - o, n := d.GetChange("tags_all") + if d.HasChange("engine_name") { + input.EngineName = aws.String(d.Get("engine_name").(string)) + } - if err := UpdateTags(conn, arn, o, n); err != nil { - return fmt.Errorf("error updating DMS Endpoint (%s) tags: %s", arn, err) + if d.HasChange("extra_connection_attributes") { + input.ExtraConnectionAttributes = aws.String(d.Get("extra_connection_attributes").(string)) } - } - switch engineName := d.Get("engine_name").(string); engineName { - case engineNameDynamoDB: if d.HasChange("service_access_role") { - request.DynamoDbSettings = &dms.DynamoDbSettings{ + input.DynamoDbSettings = &dms.DynamoDbSettings{ ServiceAccessRoleArn: aws.String(d.Get("service_access_role").(string)), } - hasChanges = true - } - case engineNameElasticsearch, engineNameOpenSearch: - if d.HasChanges( - "elasticsearch_settings.0.endpoint_uri", - "elasticsearch_settings.0.error_retry_duration", - "elasticsearch_settings.0.full_load_error_percentage", - "elasticsearch_settings.0.service_access_role_arn") { - request.ElasticsearchSettings = &dms.ElasticsearchSettings{ - ServiceAccessRoleArn: aws.String(d.Get("elasticsearch_settings.0.service_access_role_arn").(string)), - EndpointUri: aws.String(d.Get("elasticsearch_settings.0.endpoint_uri").(string)), - ErrorRetryDuration: aws.Int64(int64(d.Get("elasticsearch_settings.0.error_retry_duration").(int))), - FullLoadErrorPercentage: aws.Int64(int64(d.Get("elasticsearch_settings.0.full_load_error_percentage").(int))), - } - request.EngineName = aws.String(engineName) - hasChanges = true - } - case engineNameKafka: - if d.HasChange("kafka_settings") { - request.KafkaSettings = expandKafkaSettings(d.Get("kafka_settings").([]interface{})[0].(map[string]interface{})) - request.EngineName = aws.String(engineName) - hasChanges = true } - case engineNameKinesis: - if d.HasChanges("kinesis_settings") { - request.KinesisSettings = expandKinesisSettings(d.Get("kinesis_settings").([]interface{})[0].(map[string]interface{})) - request.EngineName = aws.String(engineName) - hasChanges = true - } - case engineNameMongodb: - if d.HasChanges( - "username", "password", "server_name", "port", "database_name", "mongodb_settings.0.auth_type", - "mongodb_settings.0.auth_mechanism", "mongodb_settings.0.nesting_level", "mongodb_settings.0.extract_doc_id", - "mongodb_settings.0.docs_to_investigate", "mongodb_settings.0.auth_source") { - request.MongoDbSettings = &dms.MongoDbSettings{ - Username: aws.String(d.Get("username").(string)), - Password: aws.String(d.Get("password").(string)), - ServerName: aws.String(d.Get("server_name").(string)), - Port: aws.Int64(int64(d.Get("port").(int))), - DatabaseName: aws.String(d.Get("database_name").(string)), - KmsKeyId: aws.String(d.Get("kms_key_arn").(string)), - - AuthType: aws.String(d.Get("mongodb_settings.0.auth_type").(string)), - AuthMechanism: aws.String(d.Get("mongodb_settings.0.auth_mechanism").(string)), - NestingLevel: aws.String(d.Get("mongodb_settings.0.nesting_level").(string)), - ExtractDocId: aws.String(d.Get("mongodb_settings.0.extract_doc_id").(string)), - DocsToInvestigate: aws.String(d.Get("mongodb_settings.0.docs_to_investigate").(string)), - AuthSource: aws.String(d.Get("mongodb_settings.0.auth_source").(string)), - } - request.EngineName = aws.String(engineName) - // Update connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) - - hasChanges = true + if d.HasChange("ssl_mode") { + input.SslMode = aws.String(d.Get("ssl_mode").(string)) } - case engineNameOracle: - if d.HasChanges( - "username", "password", "server_name", "port", "database_name", "secrets_manager_access_role_arn", - "secrets_manager_arn") { - if _, ok := d.GetOk("secrets_manager_arn"); ok { - request.OracleSettings = &dms.OracleSettings{ - DatabaseName: aws.String(d.Get("database_name").(string)), - SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), - SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + + switch engineName := d.Get("engine_name").(string); engineName { + case engineNameDynamoDB: + if d.HasChange("service_access_role") { + input.DynamoDbSettings = &dms.DynamoDbSettings{ + ServiceAccessRoleArn: aws.String(d.Get("service_access_role").(string)), + } + } + case engineNameElasticsearch, engineNameOpenSearch: + if d.HasChanges( + "elasticsearch_settings.0.endpoint_uri", + "elasticsearch_settings.0.error_retry_duration", + "elasticsearch_settings.0.full_load_error_percentage", + "elasticsearch_settings.0.service_access_role_arn") { + input.ElasticsearchSettings = &dms.ElasticsearchSettings{ + ServiceAccessRoleArn: aws.String(d.Get("elasticsearch_settings.0.service_access_role_arn").(string)), + EndpointUri: aws.String(d.Get("elasticsearch_settings.0.endpoint_uri").(string)), + ErrorRetryDuration: aws.Int64(int64(d.Get("elasticsearch_settings.0.error_retry_duration").(int))), + FullLoadErrorPercentage: aws.Int64(int64(d.Get("elasticsearch_settings.0.full_load_error_percentage").(int))), } - } else { - request.OracleSettings = &dms.OracleSettings{ + input.EngineName = aws.String(engineName) + } + case engineNameKafka: + if d.HasChange("kafka_settings") { + input.KafkaSettings = expandKafkaSettings(d.Get("kafka_settings").([]interface{})[0].(map[string]interface{})) + input.EngineName = aws.String(engineName) + } + case engineNameKinesis: + if d.HasChanges("kinesis_settings") { + input.KinesisSettings = expandKinesisSettings(d.Get("kinesis_settings").([]interface{})[0].(map[string]interface{})) + input.EngineName = aws.String(engineName) + } + case engineNameMongodb: + if d.HasChanges( + "username", "password", "server_name", "port", "database_name", "mongodb_settings.0.auth_type", + "mongodb_settings.0.auth_mechanism", "mongodb_settings.0.nesting_level", "mongodb_settings.0.extract_doc_id", + "mongodb_settings.0.docs_to_investigate", "mongodb_settings.0.auth_source") { + input.MongoDbSettings = &dms.MongoDbSettings{ Username: aws.String(d.Get("username").(string)), Password: aws.String(d.Get("password").(string)), ServerName: aws.String(d.Get("server_name").(string)), Port: aws.Int64(int64(d.Get("port").(int))), DatabaseName: aws.String(d.Get("database_name").(string)), + KmsKeyId: aws.String(d.Get("kms_key_arn").(string)), + + AuthType: aws.String(d.Get("mongodb_settings.0.auth_type").(string)), + AuthMechanism: aws.String(d.Get("mongodb_settings.0.auth_mechanism").(string)), + NestingLevel: aws.String(d.Get("mongodb_settings.0.nesting_level").(string)), + ExtractDocId: aws.String(d.Get("mongodb_settings.0.extract_doc_id").(string)), + DocsToInvestigate: aws.String(d.Get("mongodb_settings.0.docs_to_investigate").(string)), + AuthSource: aws.String(d.Get("mongodb_settings.0.auth_source").(string)), } - request.EngineName = aws.String(d.Get("engine_name").(string)) // Must be included (should be 'oracle') + input.EngineName = aws.String(engineName) // Update connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) } - hasChanges = true - } - case engineNamePostgres: - if d.HasChanges( - "username", "password", "server_name", "port", "database_name", "secrets_manager_access_role_arn", - "secrets_manager_arn") { - if _, ok := d.GetOk("secrets_manager_arn"); ok { - request.PostgreSQLSettings = &dms.PostgreSQLSettings{ - DatabaseName: aws.String(d.Get("database_name").(string)), - SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), - SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + case engineNameOracle: + if d.HasChanges( + "username", "password", "server_name", "port", "database_name", "secrets_manager_access_role_arn", + "secrets_manager_arn") { + if _, ok := d.GetOk("secrets_manager_arn"); ok { + input.OracleSettings = &dms.OracleSettings{ + DatabaseName: aws.String(d.Get("database_name").(string)), + SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), + SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + } + } else { + input.OracleSettings = &dms.OracleSettings{ + Username: aws.String(d.Get("username").(string)), + Password: aws.String(d.Get("password").(string)), + ServerName: aws.String(d.Get("server_name").(string)), + Port: aws.Int64(int64(d.Get("port").(int))), + DatabaseName: aws.String(d.Get("database_name").(string)), + } + input.EngineName = aws.String(engineName) // Must be included (should be 'oracle') + + // Update connection info in top-level namespace as well + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) } - } else { - request.PostgreSQLSettings = &dms.PostgreSQLSettings{ - Username: aws.String(d.Get("username").(string)), - Password: aws.String(d.Get("password").(string)), - ServerName: aws.String(d.Get("server_name").(string)), - Port: aws.Int64(int64(d.Get("port").(int))), - DatabaseName: aws.String(d.Get("database_name").(string)), + } + case engineNamePostgres: + if d.HasChanges( + "username", "password", "server_name", "port", "database_name", "secrets_manager_access_role_arn", + "secrets_manager_arn") { + if _, ok := d.GetOk("secrets_manager_arn"); ok { + input.PostgreSQLSettings = &dms.PostgreSQLSettings{ + DatabaseName: aws.String(d.Get("database_name").(string)), + SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), + SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + } + } else { + input.PostgreSQLSettings = &dms.PostgreSQLSettings{ + Username: aws.String(d.Get("username").(string)), + Password: aws.String(d.Get("password").(string)), + ServerName: aws.String(d.Get("server_name").(string)), + Port: aws.Int64(int64(d.Get("port").(int))), + DatabaseName: aws.String(d.Get("database_name").(string)), + } + input.EngineName = aws.String(engineName) // Must be included (should be 'postgres') + + // Update connection info in top-level namespace as well + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) } - request.EngineName = aws.String(d.Get("engine_name").(string)) // Must be included (should be 'postgres') + } + case engineNameSQLServer: + if d.HasChanges( + "username", "password", "server_name", "port", "database_name", "secrets_manager_access_role_arn", + "secrets_manager_arn") { + if _, ok := d.GetOk("secrets_manager_arn"); ok { + input.MicrosoftSQLServerSettings = &dms.MicrosoftSQLServerSettings{ + DatabaseName: aws.String(d.Get("database_name").(string)), + SecretsManagerAccessRoleArn: aws.String(d.Get("secrets_manager_access_role_arn").(string)), + SecretsManagerSecretId: aws.String(d.Get("secrets_manager_arn").(string)), + } + } else { + input.MicrosoftSQLServerSettings = &dms.MicrosoftSQLServerSettings{ + Username: aws.String(d.Get("username").(string)), + Password: aws.String(d.Get("password").(string)), + ServerName: aws.String(d.Get("server_name").(string)), + Port: aws.Int64(int64(d.Get("port").(int))), + DatabaseName: aws.String(d.Get("database_name").(string)), + } + input.EngineName = aws.String(engineName) // Must be included (should be 'postgres') + + // Update connection info in top-level namespace as well + input.Username = aws.String(d.Get("username").(string)) + input.Password = aws.String(d.Get("password").(string)) + input.ServerName = aws.String(d.Get("server_name").(string)) + input.Port = aws.Int64(int64(d.Get("port").(int))) + input.DatabaseName = aws.String(d.Get("database_name").(string)) + } + } + case engineNameS3: + if d.HasChanges("s3_settings") { + input.S3Settings = expandS3Settings(d.Get("s3_settings").([]interface{})[0].(map[string]interface{})) + input.EngineName = aws.String(engineName) + } + default: + if d.HasChange("database_name") { + input.DatabaseName = aws.String(d.Get("database_name").(string)) + } - // Update connection info in top-level namespace as well - request.Username = aws.String(d.Get("username").(string)) - request.Password = aws.String(d.Get("password").(string)) - request.ServerName = aws.String(d.Get("server_name").(string)) - request.Port = aws.Int64(int64(d.Get("port").(int))) - request.DatabaseName = aws.String(d.Get("database_name").(string)) + if d.HasChange("password") { + input.Password = aws.String(d.Get("password").(string)) } - hasChanges = true - } - case engineNameS3: - if d.HasChanges("s3_settings") { - request.S3Settings = expandS3Settings(d.Get("s3_settings").([]interface{})[0].(map[string]interface{})) - request.EngineName = aws.String(engineName) - hasChanges = true - } - default: - if d.HasChange("database_name") { - request.DatabaseName = aws.String(d.Get("database_name").(string)) - hasChanges = true - } - if d.HasChange("password") { - request.Password = aws.String(d.Get("password").(string)) - hasChanges = true - } + if d.HasChange("port") { + input.Port = aws.Int64(int64(d.Get("port").(int))) + } - if d.HasChange("port") { - request.Port = aws.Int64(int64(d.Get("port").(int))) - hasChanges = true - } + if d.HasChange("server_name") { + input.ServerName = aws.String(d.Get("server_name").(string)) + } - if d.HasChange("server_name") { - request.ServerName = aws.String(d.Get("server_name").(string)) - hasChanges = true + if d.HasChange("username") { + input.Username = aws.String(d.Get("username").(string)) + } } - if d.HasChange("username") { - request.Username = aws.String(d.Get("username").(string)) - hasChanges = true + log.Printf("[DEBUG] Modifying DMS Endpoint: %s", input) + _, err := conn.ModifyEndpoint(input) + + if err != nil { + return fmt.Errorf("updating DMS Endpoint (%s): %w", d.Id(), err) } } - if hasChanges { - log.Println("[DEBUG] DMS update endpoint:", request) + if d.HasChange("tags_all") { + arn := d.Get("endpoint_arn").(string) + o, n := d.GetChange("tags_all") - _, err := conn.ModifyEndpoint(request) - if err != nil { - return err + if err := UpdateTags(conn, arn, o, n); err != nil { + return fmt.Errorf("updating DMS Endpoint (%s) tags: %w", arn, err) } - - return resourceEndpointRead(d, meta) } - return nil + return resourceEndpointRead(d, meta) } func resourceEndpointDelete(d *schema.ResourceData, meta interface{}) error { @@ -1023,13 +1042,11 @@ func resourceEndpointDelete(d *schema.ResourceData, meta interface{}) error { } if err != nil { - return fmt.Errorf("error deleting DMS Endpoint (%s): %w", d.Id(), err) + return fmt.Errorf("deleting DMS Endpoint (%s): %w", d.Id(), err) } - _, err = waitEndpointDeleted(conn, d.Id()) - - if err != nil { - return fmt.Errorf("error waiting for DMS Endpoint (%s) delete: %w", d.Id(), err) + if _, err = waitEndpointDeleted(conn, d.Id(), d.Timeout(schema.TimeoutDelete)); err != nil { + return fmt.Errorf("waiting for DMS Endpoint (%s) delete: %w", d.Id(), err) } return err @@ -1143,6 +1160,20 @@ func resourceEndpointSetState(d *schema.ResourceData, endpoint *dms.Endpoint) er d.Set("port", endpoint.Port) d.Set("database_name", endpoint.DatabaseName) } + case engineNameSQLServer: + if endpoint.MicrosoftSQLServerSettings != nil { + d.Set("username", endpoint.MicrosoftSQLServerSettings.Username) + d.Set("server_name", endpoint.MicrosoftSQLServerSettings.ServerName) + d.Set("port", endpoint.MicrosoftSQLServerSettings.Port) + d.Set("database_name", endpoint.MicrosoftSQLServerSettings.DatabaseName) + d.Set("secrets_manager_access_role_arn", endpoint.MicrosoftSQLServerSettings.SecretsManagerAccessRoleArn) + d.Set("secrets_manager_arn", endpoint.MicrosoftSQLServerSettings.SecretsManagerSecretId) + } else { + d.Set("username", endpoint.Username) + d.Set("server_name", endpoint.ServerName) + d.Set("port", endpoint.Port) + d.Set("database_name", endpoint.DatabaseName) + } case engineNameS3: if err := d.Set("s3_settings", flattenS3Settings(endpoint.S3Settings)); err != nil { return fmt.Errorf("Error setting s3_settings for DMS: %s", err) diff --git a/internal/service/dms/endpoint_test.go b/internal/service/dms/endpoint_test.go index 2a4188140cbc..57d0d41091cc 100644 --- a/internal/service/dms/endpoint_test.go +++ b/internal/service/dms/endpoint_test.go @@ -716,6 +716,121 @@ func TestAccDMSEndpoint_PostgreSQL_kmsKey(t *testing.T) { }) } +func TestAccDMSEndpoint_SQLServer_basic(t *testing.T) { + resourceName := "aws_dms_endpoint.test" + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ErrorCheck: acctest.ErrorCheck(t, dms.EndpointsID), + ProviderFactories: acctest.ProviderFactories, + CheckDestroy: testAccCheckEndpointDestroy, + Steps: []resource.TestStep{ + { + Config: testAccEndpointConfig_SQLServer(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointExists(resourceName), + resource.TestCheckResourceAttrSet(resourceName, "endpoint_arn"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"password"}, + }, + }, + }) +} + +func TestAccDMSEndpoint_SQLServer_secretID(t *testing.T) { + resourceName := "aws_dms_endpoint.test" + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ErrorCheck: acctest.ErrorCheck(t, dms.EndpointsID), + ProviderFactories: acctest.ProviderFactories, + CheckDestroy: testAccCheckEndpointDestroy, + Steps: []resource.TestStep{ + { + Config: testAccEndpointConfig_SQLServerSecretID(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointExists(resourceName), + resource.TestCheckResourceAttrSet(resourceName, "endpoint_arn"), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccDMSEndpoint_SQLServer_update(t *testing.T) { + resourceName := "aws_dms_endpoint.test" + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ErrorCheck: acctest.ErrorCheck(t, dms.EndpointsID), + ProviderFactories: acctest.ProviderFactories, + CheckDestroy: testAccCheckEndpointDestroy, + Steps: []resource.TestStep{ + { + Config: testAccEndpointConfig_SQLServer(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointExists(resourceName), + resource.TestCheckResourceAttrSet(resourceName, "endpoint_arn"), + ), + }, + { + Config: testAccEndpointConfig_SQLServerUpdate(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointExists(resourceName), + resource.TestCheckResourceAttr(resourceName, "server_name", "tftest-new-server_name"), + resource.TestCheckResourceAttr(resourceName, "port", "27018"), + resource.TestCheckResourceAttr(resourceName, "username", "tftest-new-username"), + resource.TestCheckResourceAttr(resourceName, "password", "tftest-new-password"), + resource.TestCheckResourceAttr(resourceName, "database_name", "tftest-new-database_name"), + resource.TestCheckResourceAttr(resourceName, "ssl_mode", "require"), + resource.TestMatchResourceAttr(resourceName, "extra_connection_attributes", regexp.MustCompile(`key=value;`)), + ), + }, + { + ResourceName: resourceName, + ImportState: true, + ImportStateVerify: true, + ImportStateVerifyIgnore: []string{"password"}, + }, + }, + }) +} + +// https://github.com/hashicorp/terraform-provider-aws/issues/23143 +func TestAccDMSEndpoint_SQLServer_kmsKey(t *testing.T) { + resourceName := "aws_dms_endpoint.test" + rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { acctest.PreCheck(t) }, + ErrorCheck: acctest.ErrorCheck(t, dms.EndpointsID), + ProviderFactories: acctest.ProviderFactories, + CheckDestroy: testAccCheckEndpointDestroy, + Steps: []resource.TestStep{ + { + Config: testAccEndpointConfig_sqlserverKey(rName), + Check: resource.ComposeTestCheckFunc( + testAccCheckEndpointExists(resourceName), + resource.TestCheckResourceAttrSet(resourceName, "endpoint_arn"), + ), + }, + }, + }) +} + func TestAccDMSEndpoint_docDB(t *testing.T) { resourceName := "aws_dms_endpoint.test" rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix) @@ -1835,6 +1950,122 @@ resource "aws_dms_endpoint" "test" { ssl_mode = "none" extra_connection_attributes = "" + tags = { + Name = "tf-test-dms-endpoint-%[1]s" + Update = "to-update" + Remove = "to-remove" + } +} +`, rName) +} + +func testAccEndpointConfig_SQLServer(rName string) string { + return fmt.Sprintf(` +resource "aws_dms_endpoint" "test" { + endpoint_id = %[1]q + endpoint_type = "source" + engine_name = "sqlserver" + server_name = "tftest" + port = 27017 + username = "tftest" + password = "tftest" + database_name = "tftest" + ssl_mode = "none" + extra_connection_attributes = "" + + tags = { + Name = %[1]q + Update = "to-update" + Remove = "to-remove" + } +} +`, rName) +} + +func testAccEndpointConfig_SQLServerUpdate(rName string) string { + return fmt.Sprintf(` +resource "aws_dms_endpoint" "test" { + endpoint_id = %[1]q + endpoint_type = "source" + engine_name = "sqlserver" + server_name = "tftest-new-server_name" + port = 27018 + username = "tftest-new-username" + password = "tftest-new-password" + database_name = "tftest-new-database_name" + ssl_mode = "require" + extra_connection_attributes = "key=value;" + + tags = { + Name = %[1]q + Update = "updated" + Add = "added" + } +} +`, rName) +} + +func testAccEndpointConfig_SQLServerSecretID(rName string) string { + return fmt.Sprintf(` +data "aws_kms_alias" "dms" { + name = "alias/aws/dms" +} + +data "aws_region" "current" {} +data "aws_partition" "current" {} + +resource "aws_secretsmanager_secret" "test" { + name = %[1]q + recovery_window_in_days = 0 +} + +resource "aws_iam_role" "test" { + name = %[1]q + assume_role_policy = <