diff --git a/moto/rds/models.py b/moto/rds/models.py index 55b3f66628b8..89abaf740088 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -12,8 +12,9 @@ from moto.core.base_backend import BackendDict, BaseBackend from moto.core.common_models import BaseModel, CloudFormationModel -from moto.core.utils import iso_8601_datetime_with_milliseconds +from moto.core.utils import iso_8601_datetime_with_milliseconds, utcnow from moto.ec2.models import ec2_backends +from moto.kms.models import KmsBackend, kms_backends from moto.moto_api._internal import mock_random as random from moto.utilities.utils import ARN_PARTITION_REGEX, load_resource @@ -295,15 +296,35 @@ class DBCluster(RDSBaseModel): resource_type = "cluster" - def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: Any): + def __init__( + self, + backend: RDSBackend, + db_cluster_identifier: str, + engine: str, + engine_version: Optional[str] = None, + master_username: Optional[str] = None, + master_user_password: Optional[str] = None, + backup_retention_period: Optional[int] = 1, + character_set_name: Optional[str] = None, + copy_tags_to_snapshot: Optional[bool] = False, + database_name: Optional[str] = None, + db_cluster_parameter_group_name: Optional[str] = None, + db_subnet_group_name: Optional[str] = None, + port: Optional[int] = None, + preferred_backup_window: Optional[str] = "01:37-02:07", + preferred_maintenance_window: Optional[str] = "wed:02:40-wed:03:10", + storage_encrypted: Optional[bool] = False, + tags: Optional[List[Dict[str, str]]] = None, + vpc_security_group_ids: Optional[List[str]] = None, + deletion_protection: Optional[bool] = False, + **kwargs: Any, + ): super().__init__(backend) - self.database_name = kwargs.get("database_name") + self.database_name = database_name self.db_cluster_identifier = db_cluster_identifier self.db_cluster_instance_class = kwargs.get("db_cluster_instance_class") - self.deletion_protection = kwargs.get("deletion_protection") - if self.deletion_protection is None: - self.deletion_protection = False - self.engine = kwargs.get("engine") + self.deletion_protection = deletion_protection + self.engine = engine if self.engine not in ClusterEngine.list_cluster_engines(): raise InvalidParameterValue( ( @@ -315,18 +336,16 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An valid_engines=ClusterEngine.list_cluster_engines(), ) ) - self.engine_version = kwargs.get( - "engine_version" - ) or DBCluster.default_engine_version(self.engine) + self.engine_version = engine_version or DBCluster.default_engine_version( + self.engine + ) self.engine_mode = kwargs.get("engine_mode") or "provisioned" self.iops = kwargs.get("iops") self.kms_key_id = kwargs.get("kms_key_id") self.network_type = kwargs.get("network_type") or "IPV4" self._status = "creating" self.cluster_create_time = iso_8601_datetime_with_milliseconds() - self.copy_tags_to_snapshot = kwargs.get("copy_tags_to_snapshot") - if self.copy_tags_to_snapshot is None: - self.copy_tags_to_snapshot = False + self.copy_tags_to_snapshot = copy_tags_to_snapshot self.storage_type = kwargs.get("storage_type") if self.storage_type is None: self.storage_type = DBCluster.default_storage_type(iops=self.iops) @@ -335,7 +354,8 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An self.allocated_storage = DBCluster.default_allocated_storage( engine=self.engine, storage_type=self.storage_type ) - self.master_username = kwargs.get("master_username") + self.master_username = master_username + self.character_set_name = character_set_name self.global_cluster_identifier = kwargs.get("global_cluster_identifier") if ( not self.master_username @@ -348,7 +368,7 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An "The parameter MasterUsername must be provided and must not be blank." ) else: - self.master_user_password = kwargs.get("master_user_password") # type: ignore + self.master_user_password = master_user_password or "" self.master_user_secret_kms_key_id = kwargs.get("master_user_secret_kms_key_id") self.manage_master_user_password = kwargs.get( @@ -368,33 +388,25 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An default_pg = ( "default.neptune1.3" if self.engine == "neptune" else "default.aurora8.0" ) - self.parameter_group = ( - kwargs.get("db_cluster_parameter_group_name") or default_pg - ) - self.subnet_group = kwargs.get("db_subnet_group_name") or "default" + self.parameter_group = db_cluster_parameter_group_name or default_pg + self.db_subnet_group = db_subnet_group_name or "default" self.url_identifier = "".join( random.choice(string.ascii_lowercase + string.digits) for _ in range(12) ) self.endpoint = f"{self.db_cluster_identifier}.cluster-{self.url_identifier}.{self.region}.rds.amazonaws.com" self.reader_endpoint = f"{self.db_cluster_identifier}.cluster-ro-{self.url_identifier}.{self.region}.rds.amazonaws.com" - self.port: int = kwargs.get("port") # type: ignore - if self.port is None: - self.port = DBCluster.default_port(self.engine) - self.preferred_backup_window = ( - kwargs.get("preferred_backup_window") or "01:37-02:07" - ) - self.preferred_maintenance_window = "wed:02:40-wed:03:10" + self.port = port or DBCluster.default_port(self.engine) + self.preferred_backup_window = preferred_backup_window or "01:37-02:07" + self.preferred_maintenance_window = preferred_maintenance_window # This should default to the default security group - self._vpc_security_group_ids: List[str] = kwargs.get( - "vpc_security_group_ids", [] - ) + self._vpc_security_group_ids = vpc_security_group_ids or [] self.hosted_zone_id = "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(14) ) self.resource_id = "cluster-" + "".join( random.choice(string.ascii_uppercase + string.digits) for _ in range(26) ) - self.tags = kwargs.get("tags", []) + self.tags = tags or [] self.enabled_cloudwatch_logs_exports = ( kwargs.get("enable_cloudwatch_logs_exports") or [] ) @@ -418,9 +430,7 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An self.replication_source_identifier = kwargs.get("replication_source_identifier") self.read_replica_identifiers: List[str] = list() self.is_writer: bool = False - self.storage_encrypted = kwargs.get("storage_encrypted", False) - if self.storage_encrypted is None: - self.storage_encrypted = False + self.storage_encrypted = storage_encrypted if self.storage_encrypted: self.kms_key_id = kwargs.get("kms_key_id", "default_kms_key_id") else: @@ -429,7 +439,7 @@ def __init__(self, backend: RDSBackend, db_cluster_identifier: str, **kwargs: An self._global_write_forwarding_requested = kwargs.get( "enable_global_write_forwarding" ) - self.backup_retention_period = kwargs.get("backup_retention_period") or 1 + self.backup_retention_period = backup_retention_period if backtrack := kwargs.get("backtrack_window"): if self.engine == "aurora-mysql": @@ -492,10 +502,6 @@ def master_user_password(self, val: str) -> None: ) self._master_user_password = val - @property - def db_subnet_group(self) -> str: - return self.subnet_group - @property def enable_http_endpoint(self) -> bool: return self._enable_http_endpoint @@ -548,7 +554,7 @@ def master_user_secret(self) -> Optional[Dict[str, Any]]: # type: ignore[misc] @property def db_cluster_parameter_group(self) -> str: - return self.cluster.parameter_group + return self.parameter_group @property def status(self) -> str: @@ -657,9 +663,9 @@ def default_storage_type(iops: Any) -> str: # type: ignore[misc] @staticmethod def default_allocated_storage(engine: str, storage_type: str) -> int: return { - "aurora": {"gp2": 0, "io1": 0, "standard": 0}, - "aurora-mysql": {"gp2": 20, "io1": 100, "standard": 10}, - "aurora-postgresql": {"gp2": 20, "io1": 100, "standard": 10}, + "aurora": {"gp2": 1, "io1": 1, "standard": 1}, + "aurora-mysql": {"gp2": 1, "io1": 1, "standard": 1}, + "aurora-postgresql": {"gp2": 1, "io1": 1, "standard": 1}, "mysql": {"gp2": 20, "io1": 100, "standard": 5}, "neptune": {"gp2": 0, "io1": 0, "standard": 0}, "postgres": {"gp2": 20, "io1": 100, "standard": 5}, @@ -755,8 +761,10 @@ def __init__( db_instance_identifier: str, db_instance_class: str, engine: str, + engine_version: Optional[str] = None, port: Optional[int] = None, allocated_storage: Optional[int] = None, + max_allocated_storage: Optional[int] = None, backup_retention_period: int = 1, character_set_name: Optional[str] = None, auto_minor_version_upgrade: bool = True, @@ -793,20 +801,7 @@ def __init__( raise InvalidParameterValue( f"Value {self.engine} for parameter Engine is invalid. Reason: engine {self.engine} not supported" ) - self.engine_version = kwargs.get("engine_version", None) - if not self.engine_version and self.engine in self.default_engine_versions: - self.engine_version = self.default_engine_versions[self.engine] self.iops = iops - self.storage_encrypted = storage_encrypted - if self.storage_encrypted: - self.kms_key_id = kwargs.get("kms_key_id", "default_kms_key_id") - else: - self.kms_key_id = kwargs.get("kms_key_id") - self.storage_type = storage_type - if self.storage_type is None: - self.storage_type = DBInstance.default_storage_type(iops=self.iops) - self.master_username = master_username - self.master_user_password = master_user_password self.master_user_secret_kms_key_id = kwargs.get("master_user_secret_kms_key_id") self.master_user_secret_status = kwargs.get( "master_user_secret_status", "active" @@ -815,12 +810,6 @@ def __init__( "manage_master_user_password", False ) self.auto_minor_version_upgrade = auto_minor_version_upgrade - self.allocated_storage = allocated_storage - if self.allocated_storage is None: - self.allocated_storage = DBInstance.default_allocated_storage( - engine=self.engine, storage_type=self.storage_type - ) - self.db_cluster_identifier: Optional[str] = db_cluster_identifier self.db_instance_identifier = db_instance_identifier self.source_db_identifier = source_db_instance_identifier self.db_instance_class = db_instance_class @@ -831,7 +820,6 @@ def __init__( self.instance_create_time = iso_8601_datetime_with_milliseconds() self.publicly_accessible = publicly_accessible self.copy_tags_to_snapshot = copy_tags_to_snapshot - self.backup_retention_period = backup_retention_period self.availability_zone = kwargs.get("availability_zone") if not self.availability_zone: self.availability_zone = f"{self.region}a" @@ -850,14 +838,6 @@ def __init__( default_sg = ec2_backend.get_default_security_group(default_vpc.id) self.vpc_security_group_ids.append(default_sg.id) # type: ignore self.preferred_maintenance_window = preferred_maintenance_window.lower() - self.preferred_backup_window = preferred_backup_window - msg = valid_preferred_maintenance_window( - self.preferred_maintenance_window, - self.preferred_backup_window, - ) - if msg: - raise RDSClientError("InvalidParameterValue", msg) - self.db_parameter_group_name = db_parameter_group_name if ( self.db_parameter_group_name @@ -866,7 +846,6 @@ def __init__( not in rds_backends[self.account_id][self.region].db_parameter_groups ): raise DBParameterGroupNotFoundError(self.db_parameter_group_name) - self.license_model = license_model self.option_group_name = option_group_name self.option_group_supplied = self.option_group_name is not None @@ -883,7 +862,6 @@ def __init__( } if not self.option_group_name and self.engine in self.default_option_groups: self.option_group_name = self.default_option_groups[self.engine] - self.character_set_name = character_set_name self.enable_iam_database_authentication = kwargs.get( "enable_iam_database_authentication", False ) @@ -892,9 +870,57 @@ def __init__( self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" self.tags = tags or [] self.deletion_protection = deletion_protection - self.enabled_cloudwatch_logs_exports = enable_cloudwatch_logs_exports or [] + self.db_cluster_identifier = db_cluster_identifier + if self.db_cluster_identifier is None: + self.storage_type = storage_type or DBInstance.default_storage_type( + iops=self.iops + ) + self.allocated_storage = ( + allocated_storage + or DBInstance.default_allocated_storage( + engine=self.engine, storage_type=self.storage_type + ) + ) + self.max_allocated_storage = max_allocated_storage or self.allocated_storage + self.storage_encrypted = storage_encrypted + if self.storage_encrypted: + self.kms_key_id = kwargs.get("kms_key_id", "default_kms_key_id") + else: + self.kms_key_id = kwargs.get("kms_key_id") + self.backup_retention_period = backup_retention_period + self.character_set_name = character_set_name + self.engine_version = engine_version + if not self.engine_version and self.engine in self.default_engine_versions: + self.engine_version = self.default_engine_versions[self.engine] + self.master_username = master_username + self.master_user_password = master_user_password + self.preferred_backup_window = preferred_backup_window + msg = valid_preferred_maintenance_window( + self.preferred_maintenance_window, + self.preferred_backup_window, + ) + if msg: + raise RDSClientError("InvalidParameterValue", msg) + else: + # TODO: Refactor this into a DBClusterInstance subclass + self.cluster = self.backend.clusters[self.db_cluster_identifier] + self.allocated_storage = self.cluster.allocated_storage or 1 + self.max_allocated_storage = ( + self.cluster.allocated_storage or self.allocated_storage + ) + self.storage_encrypted = self.cluster.storage_encrypted or True + self.kms_key_id = self.cluster.kms_key_id + self.preferred_backup_window = self.cluster.preferred_backup_window + self.backup_retention_period = self.cluster.backup_retention_period or 1 + self.character_set_name = self.cluster.character_set_name + self.engine_version = self.cluster.engine_version + self.master_username = self.cluster.master_username + self.master_user_password = self.cluster.master_user_password + if self.db_name is None: + self.db_name = self.cluster.database_name + @property def name(self) -> str: return self.db_instance_identifier @@ -955,6 +981,19 @@ def master_user_secret(self) -> Dict[str, Any] | None: # type: ignore[misc] } return secret_info if self.manage_master_user_password else None + @property + def max_allocated_storage(self) -> Optional[int]: + value: int = self._max_allocated_storage or 0 # type: ignore[has-type] + return value if value != self.allocated_storage else None + + @max_allocated_storage.setter + def max_allocated_storage(self, value: int) -> None: + if value < self.allocated_storage: + raise InvalidParameterCombination( + "Max storage size must be greater than storage size" + ) + self._max_allocated_storage = value + @property def address(self) -> str: return ( @@ -1181,6 +1220,11 @@ def delete(self, account_id: str, region_name: str) -> None: backend = rds_backends[account_id][region_name] backend.delete_db_instance(self.db_instance_identifier) # type: ignore[arg-type] + def save_automated_backup(self) -> None: + time_stamp = utcnow().strftime("%Y-%m-%d-%H-%M") + snapshot_id = f"rds:{self.db_instance_identifier}-{time_stamp}" + self.backend.create_auto_snapshot(self.db_instance_identifier, snapshot_id) + class DBSnapshot(RDSBaseModel): resource_type = "snapshot" @@ -1201,43 +1245,46 @@ def __init__( database: DBInstance, snapshot_id: str, snapshot_type: str, - tags: List[Dict[str, str]], + tags: Optional[List[Dict[str, str]]] = None, original_created_at: Optional[str] = None, + kms_key_id: Optional[str] = None, ): super().__init__(backend) - self.database = database + self.database = copy.copy(database) # TODO: Refactor this out. self.snapshot_id = snapshot_id self.snapshot_type = snapshot_type - self.tags = tags + self.tags = tags or [] self.status = "available" self.created_at = iso_8601_datetime_with_milliseconds() self.original_created_at = original_created_at or self.created_at self.attributes: List[Dict[str, Any]] = [] + # Database attributes are captured at the time the snapshot is taken. + self.allocated_storage = database.allocated_storage + self.dbi_resource_id = database.dbi_resource_id + self.db_instance_identifier = database.db_instance_identifier + self.engine = database.engine + self.engine_version = database.engine_version + if kms_key_id is not None: + self.kms_key_id = kms_key_id + self.encrypted = self.database.storage_encrypted = True + else: + self.kms_key_id = database.kms_key_id + self.encrypted = database.storage_encrypted + self.iam_database_authentication_enabled = ( + database.enable_iam_database_authentication + ) + self.instance_create_time = database.created + self.master_username = database.master_username + self.port = database.port @property def name(self) -> str: return self.snapshot_id - @property - def dbi_resource_id(self) -> str: - return self.database.dbi_resource_id - - @property - def engine(self) -> str: - return self.database.engine - @property def db_snapshot_identifier(self) -> str: return self.snapshot_id - @property - def db_instance_identifier(self) -> str: - return self.database.db_instance_identifier - - @property - def iam_database_authentication_enabled(self) -> bool: - return self.database.enable_iam_database_authentication - @property def snapshot_create_time(self) -> str: return self.created_at @@ -1567,6 +1614,25 @@ def name(self) -> str: return self.unique_id +class DBInstanceAutomatedBackup(XFormedAttributeAccessMixin): + def __init__( + self, + backend: RDSBackend, + db_instance_identifier: str, + automated_snapshots: List[DBSnapshot], + ) -> None: + self.backend = backend + self.db_instance_identifier = db_instance_identifier + self.automated_snapshots = automated_snapshots + + @property + def status(self) -> str: + status = "active" + if self.db_instance_identifier not in self.backend.databases: + status = "retained" + return status + + class RDSBackend(BaseBackend): def __init__(self, region_name: str, account_id: str): super().__init__(region_name, account_id) @@ -1604,6 +1670,10 @@ def __init__(self, region_name: str, account_id: str): OptionGroup: self.option_groups, } + @property + def kms(self) -> KmsBackend: + return kms_backends[self.account_id][self.region_name] + @lru_cache() def db_cluster_options(self, engine) -> List[Dict[str, Any]]: # type: ignore from moto.rds.utils import decode_orderable_db_instance @@ -1640,6 +1710,7 @@ def create_db_instance(self, db_kwargs: Dict[str, Any]) -> DBInstance: ) cluster.cluster_members.append(database_id) self.databases[database_id] = database + database.save_automated_backup() return database def create_auto_snapshot( @@ -1658,6 +1729,7 @@ def create_db_snapshot( snapshot_type: str = "manual", tags: Optional[List[Dict[str, str]]] = None, original_created_at: Optional[str] = None, + kms_key_id: Optional[str] = None, ) -> DBSnapshot: if isinstance(db_instance, str): database = self.databases.get(db_instance) @@ -1683,25 +1755,30 @@ def create_db_snapshot( snapshot_type, tags, original_created_at, + kms_key_id, ) self.database_snapshots[db_snapshot_identifier] = snapshot return snapshot def copy_db_snapshot( self, - source_snapshot_identifier: str, - target_snapshot_identifier: str, + source_db_snapshot_identifier: str, + target_db_snapshot_identifier: str, tags: Optional[List[Dict[str, str]]] = None, - copy_tags: bool = False, + copy_tags: Optional[bool] = False, + kms_key_id: Optional[str] = None, ) -> DBSnapshot: - if source_snapshot_identifier.startswith("arn:aws:rds:"): - source_snapshot_identifier = self.extract_snapshot_name_from_arn( - source_snapshot_identifier + if source_db_snapshot_identifier.startswith("arn:aws:rds:"): + source_db_snapshot_identifier = self.extract_snapshot_name_from_arn( + source_db_snapshot_identifier ) - if source_snapshot_identifier not in self.database_snapshots: - raise DBSnapshotNotFoundError(source_snapshot_identifier) - - source_snapshot = self.database_snapshots[source_snapshot_identifier] + if source_db_snapshot_identifier not in self.database_snapshots: + raise DBSnapshotNotFoundError(source_db_snapshot_identifier) + if kms_key_id is not None: + key = self.kms.describe_key(kms_key_id) + # We do this in case an alias was passed in. + kms_key_id = key.id + source_snapshot = self.database_snapshots[source_db_snapshot_identifier] # When tags are passed, AWS does NOT copy/merge tags of the # source snapshot, even when copy_tags=True is given. @@ -1711,9 +1788,10 @@ def copy_db_snapshot( return self.create_db_snapshot( db_instance=source_snapshot.database, - db_snapshot_identifier=target_snapshot_identifier, + db_snapshot_identifier=target_db_snapshot_identifier, tags=tags, original_created_at=source_snapshot.original_created_at, + kms_key_id=kms_key_id, ) def delete_db_snapshot(self, db_snapshot_identifier: str) -> DBSnapshot: @@ -1763,7 +1841,8 @@ def describe_db_instances( def describe_db_snapshots( self, db_instance_identifier: Optional[str], - db_snapshot_identifier: str, + db_snapshot_identifier: Optional[str] = None, + snapshot_type: Optional[str] = None, filters: Optional[Dict[str, Any]] = None, ) -> List[DBSnapshot]: snapshots = self.database_snapshots @@ -1775,6 +1854,18 @@ def describe_db_snapshots( filters = merge_filters( filters, {"db-snapshot-id": [db_snapshot_identifier]} ) + snapshot_types = ( + ["automated", "manual"] + if ( + snapshot_type is None + and (filters is not None and "snapshot-type" not in filters) + ) + else [snapshot_type] + if snapshot_type is not None + else [] + ) + if snapshot_types: + filters = merge_filters(filters, {"snapshot-type": snapshot_types}) if filters: snapshots = self._filter_resources(snapshots, filters, DBSnapshot) if db_snapshot_identifier and not snapshots and not db_instance_identifier: @@ -1852,13 +1943,21 @@ def restore_db_instance_from_db_snapshot( new_instance_props = {} for key, value in original_database.__dict__.items(): - if key != "backend": - new_instance_props[key] = copy.deepcopy(value) + if key not in [ + "backend", + "db_parameter_group_name", + "vpc_security_group_ids", + ]: + new_instance_props[key] = copy.copy(value) if not original_database.option_group_supplied: # If the option group is not supplied originally, the 'option_group_name' will receive a default value # Force this reconstruction, and prevent any validation on the default value del new_instance_props["option_group_name"] - + if "allocated_storage" in overrides: + if overrides["allocated_storage"] < snapshot.allocated_storage: + raise InvalidParameterValue( + "The allocated storage size can't be less than the source snapshot or backup size." + ) for key, value in overrides.items(): if value: new_instance_props[key] = value @@ -1887,7 +1986,11 @@ def restore_db_instance_to_point_in_time( # If the option group is not supplied originally, the 'option_group_name' will receive a default value # Force this reconstruction, and prevent any validation on the default value del new_instance_props["option_group_name"] - + if "allocated_storage" in overrides: + if overrides["allocated_storage"] < db_instance.allocated_storage: + raise InvalidParameterValue( + "Allocated storage size can't be less than the source instance size." + ) for key, value in overrides.items(): if value: new_instance_props[key] = value @@ -1938,16 +2041,24 @@ def find_db_from_id(self, db_id: str) -> DBInstance: return backend.describe_db_instances(db_name)[0] def delete_db_instance( - self, db_instance_identifier: str, db_snapshot_name: Optional[str] = None + self, + db_instance_identifier: str, + final_db_snapshot_identifier: Optional[str] = None, + skip_final_snapshot: Optional[bool] = False, + delete_automated_backups: Optional[bool] = True, ) -> DBInstance: self._validate_db_identifier(db_instance_identifier) if db_instance_identifier in self.databases: if self.databases[db_instance_identifier].deletion_protection: - raise InvalidParameterValue( - "Can't delete Instance with protection enabled" + raise InvalidParameterCombination( + "Cannot delete protected DB Instance, please disable deletion protection and try again." + ) + if final_db_snapshot_identifier and not skip_final_snapshot: + self.create_db_snapshot( + db_instance_identifier, + final_db_snapshot_identifier, + snapshot_type="manual", ) - if db_snapshot_name: - self.create_auto_snapshot(db_instance_identifier, db_snapshot_name) database = self.databases.pop(db_instance_identifier) if database.is_replica: primary = self.find_db_from_id(database.source_db_instance_identifier) # type: ignore @@ -1956,6 +2067,14 @@ def delete_db_instance( self.clusters[database.db_cluster_identifier].cluster_members.remove( db_instance_identifier ) + automated_snapshots = self.describe_db_snapshots( + db_instance_identifier, + db_snapshot_identifier=None, + snapshot_type="automated", + ) + if delete_automated_backups: + for snapshot in automated_snapshots: + self.delete_db_snapshot(snapshot.db_snapshot_identifier) database.status = "deleting" return database else: @@ -2282,9 +2401,7 @@ def create_db_cluster(self, kwargs: Dict[str, Any]) -> DBCluster: original_cluster = find_cluster(cluster_identifier) original_cluster.read_replica_identifiers.append(cluster.db_cluster_arn) - initial_state = copy.deepcopy(cluster) # Return status=creating - cluster.status = "available" # Already set the final status in the background - return initial_state + return cluster def modify_db_cluster(self, kwargs: Dict[str, Any]) -> DBCluster: cluster_id = kwargs["db_cluster_identifier"] @@ -2589,40 +2706,37 @@ def _find_resource(self, resource_type: str, resource_name: str) -> Any: if resource.arn.endswith(resource_name): return resource - def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: + def _get_resource_for_tagging(self, arn: str) -> Any: if self.arn_regex.match(arn): arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] + # FIXME: HACK for automated snapshots + if resource_type == "rds": + resource_type = arn_breakdown[-3] + resource_name = arn_breakdown[-2] + ":" + arn_breakdown[-1] resource = self._find_resource(resource_type, resource_name) - if resource: - return resource.get_tags() - return [] + return resource raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") + def list_tags_for_resource(self, arn: str) -> List[Dict[str, str]]: + resource = self._get_resource_for_tagging(arn) + if resource: + return resource.get_tags() + return [] + def remove_tags_from_resource(self, arn: str, tag_keys: List[str]) -> None: - if self.arn_regex.match(arn): - arn_breakdown = arn.split(":") - resource_type = arn_breakdown[len(arn_breakdown) - 2] - resource_name = arn_breakdown[len(arn_breakdown) - 1] - resource = self._find_resource(resource_type, resource_name) - if resource: - resource.remove_tags(tag_keys) - return - raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") + resource = self._get_resource_for_tagging(arn) + if resource: + resource.remove_tags(tag_keys) def add_tags_to_resource( # type: ignore[return] self, arn: str, tags: List[Dict[str, str]] ) -> List[Dict[str, str]]: - if self.arn_regex.match(arn): - arn_breakdown = arn.split(":") - resource_type = arn_breakdown[-2] - resource_name = arn_breakdown[-1] - resource = self._find_resource(resource_type, resource_name) - if resource: - return resource.add_tags(tags) - return [] - raise RDSClientError("InvalidParameterValue", f"Invalid resource name: {arn}") + resource = self._get_resource_for_tagging(arn) + if resource: + return resource.add_tags(tags) + return [] @staticmethod def _filter_resources(resources: Any, filters: Any, resource_class: Any) -> Any: # type: ignore[misc] @@ -3001,6 +3115,26 @@ def modify_db_proxy_target_group( target_group.session_pinning_filters = config["SessionPinningFilters"] return target_group + def describe_db_instance_automated_backups( + self, + db_instance_identifier: Optional[str] = None, + **_: Any, + ) -> List[DBInstanceAutomatedBackup]: + snapshots = list(self.database_snapshots.values()) + if db_instance_identifier is not None: + snapshots = [ + snap + for snap in self.database_snapshots.values() + if snap.db_instance_identifier == db_instance_identifier + ] + snapshots_grouped = defaultdict(list) + for snapshot in snapshots: + if snapshot.snapshot_type == "automated": + snapshots_grouped[snapshot.db_instance_identifier].append(snapshot) + return [ + DBInstanceAutomatedBackup(self, k, v) for k, v in snapshots_grouped.items() + ] + class OptionGroup(RDSBaseModel): resource_type = "og" @@ -3012,6 +3146,7 @@ def __init__( engine_name: str, major_engine_version: str, option_group_description: Optional[str] = None, + tags: Optional[List[Dict[str, str]]] = None, ): super().__init__(backend) self.engine_name = engine_name @@ -3021,7 +3156,7 @@ def __init__( self.vpc_and_non_vpc_instance_memberships = False self._options: Dict[str, Any] = {} self.vpcId = "null" - self.tags: List[Dict[str, str]] = [] + self.tags = tags or [] @property def name(self) -> str: diff --git a/moto/rds/responses.py b/moto/rds/responses.py index 1a346260b19e..a1d90a9b903d 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -121,16 +121,12 @@ def modify_db_instance(self) -> TYPE_RESPONSE: return self.serialize(result) def delete_db_instance(self) -> TYPE_RESPONSE: - db_instance_identifier = self.parameters.get("DBInstanceIdentifier") db_snapshot_name = self.parameters.get("FinalDBSnapshotIdentifier") if db_snapshot_name is not None: self.backend.validate_db_snapshot_identifier( db_snapshot_name, parameter_name="FinalDBSnapshotIdentifier" ) - - database = self.backend.delete_db_instance( - db_instance_identifier, db_snapshot_name - ) + database = self.backend.delete_db_instance(**self.parameters) result = {"DBInstance": database} return self.serialize(result) @@ -156,27 +152,22 @@ def create_db_snapshot(self) -> TYPE_RESPONSE: return self.serialize(result) def copy_db_snapshot(self) -> TYPE_RESPONSE: - source_snapshot_identifier = self.parameters.get("SourceDBSnapshotIdentifier") target_snapshot_identifier = self.parameters.get("TargetDBSnapshotIdentifier") - tags = self.parameters.get("Tags", []) - copy_tags = self.parameters.get("CopyTags") self.backend.validate_db_snapshot_identifier( target_snapshot_identifier, parameter_name="TargetDBSnapshotIdentifier" ) - - snapshot = self.backend.copy_db_snapshot( - source_snapshot_identifier, target_snapshot_identifier, tags, copy_tags - ) + snapshot = self.backend.copy_db_snapshot(**self.parameters) result = {"DBSnapshot": snapshot} return self.serialize(result) def describe_db_snapshots(self) -> TYPE_RESPONSE: db_instance_identifier = self.parameters.get("DBInstanceIdentifier") db_snapshot_identifier = self.parameters.get("DBSnapshotIdentifier") + snapshot_type = self.parameters.get("SnapshotType") filters = self.parameters.get("Filters", []) filter_dict = {f["Name"]: f["Values"] for f in filters} snapshots = self.backend.describe_db_snapshots( - db_instance_identifier, db_snapshot_identifier, filter_dict + db_instance_identifier, db_snapshot_identifier, snapshot_type, filter_dict ) result = {"DBSnapshots": snapshots} return self.serialize(result) @@ -779,6 +770,13 @@ def modify_db_proxy_target_group(self) -> TYPE_RESPONSE: result = {"DBProxyTargetGroup": group} return self.serialize(result) + def describe_db_instance_automated_backups(self) -> TYPE_RESPONSE: + automated_backups = self.backend.describe_db_instance_automated_backups( + **self.parameters + ) + result = {"DBInstanceAutomatedBackups": automated_backups} + return self.serialize(result) + def _paginate(self, resources: List[Any]) -> Tuple[List[Any], Optional[str]]: from moto.rds.exceptions import InvalidParameterValue diff --git a/tests/test_rds/test_filters.py b/tests/test_rds/test_filters.py index 3cbde201bd6a..1bae4face150 100644 --- a/tests/test_rds/test_filters.py +++ b/tests/test_rds/test_filters.py @@ -15,7 +15,13 @@ def setup_class(cls): for i in range(10): instance_identifier = f"db-instance-{i}" cluster_identifier = f"db-cluster-{i}" - engine = "postgres" if (i % 3) else "mysql" + engine = "aurora-postgresql" if (i % 3) else "aurora-mysql" + client.create_db_cluster( + DBClusterIdentifier=cluster_identifier, + Engine=engine, + MasterUsername="root", + MasterUserPassword="password", + ) client.create_db_instance( DBInstanceIdentifier=instance_identifier, DBClusterIdentifier=cluster_identifier, @@ -108,7 +114,7 @@ def test_multiple_filters(self): "Name": "db-instance-id", "Values": ["db-instance-0", "db-instance-1", "db-instance-3"], }, - {"Name": "engine", "Values": ["mysql", "oracle"]}, + {"Name": "engine", "Values": ["aurora-mysql", "oracle"]}, ] ) returned_identifiers = [ @@ -148,7 +154,7 @@ def test_valid_db_instance_identifier_with_exclusive_filter(self): DBInstanceIdentifier="db-instance-0", Filters=[ {"Name": "db-instance-id", "Values": ["db-instance-1"]}, - {"Name": "engine", "Values": ["postgres"]}, + {"Name": "engine", "Values": ["aurora-postgresql"]}, ], ) returned_identifiers = [ @@ -164,7 +170,7 @@ def test_valid_db_instance_identifier_with_inclusive_filter(self): DBInstanceIdentifier="db-instance-0", Filters=[ {"Name": "db-instance-id", "Values": ["db-instance-1"]}, - {"Name": "engine", "Values": ["mysql", "postgres"]}, + {"Name": "engine", "Values": ["aurora-mysql", "aurora-postgresql"]}, ], ) returned_identifiers = [ @@ -296,7 +302,7 @@ def test_snapshot_type_filter(self): snapshots = self.client.describe_db_snapshots( Filters=[{"Name": "snapshot-type", "Values": ["automated"]}] )["DBSnapshots"] - assert len(snapshots) == 0 + assert len(snapshots) == 2 def test_multiple_filters(self): snapshots = self.client.describe_db_snapshots( diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index 29b151cecac0..5035cc90da01 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -1,4 +1,5 @@ import datetime +import time from uuid import uuid4 import boto3 @@ -251,8 +252,8 @@ def test_stop_database(client): DBInstanceIdentifier=mydb["DBInstanceIdentifier"], DBSnapshotIdentifier="rocky4570-rds-snap", ) - response = client.describe_db_snapshots() - assert response["DBSnapshots"] == [] + with pytest.raises(ClientError): + client.describe_db_snapshots(DBSnapshotIdentifier="rocky4570-rds-snap") @mock_aws @@ -273,7 +274,7 @@ def test_start_database(client): ) assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 assert response["DBInstance"]["DBInstanceStatus"] == "stopped" - response = client.describe_db_snapshots() + response = client.describe_db_snapshots(DBSnapshotIdentifier="rocky4570-rds-snap") assert response["DBSnapshots"][0]["DBSnapshotIdentifier"] == "rocky4570-rds-snap" response = client.start_db_instance( DBInstanceIdentifier=mydb["DBInstanceIdentifier"] @@ -281,7 +282,7 @@ def test_start_database(client): assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 assert response["DBInstance"]["DBInstanceStatus"] == "available" # starting database should not remove snapshot - response = client.describe_db_snapshots() + response = client.describe_db_snapshots(DBSnapshotIdentifier="rocky4570-rds-snap") assert response["DBSnapshots"][0]["DBSnapshotIdentifier"] == "rocky4570-rds-snap" # test stopping database, create snapshot with existing snapshot already # created should throw error @@ -743,11 +744,52 @@ def test_delete_database(client): assert len(instances["DBInstances"]) == 0 # Saved the snapshot - snapshot = client.describe_db_snapshots(DBInstanceIdentifier="db-1")["DBSnapshots"][ - 0 - ] + snapshot = client.describe_db_snapshots( + DBInstanceIdentifier="db-1", DBSnapshotIdentifier="primary-1-snapshot" + )["DBSnapshots"][0] assert snapshot["Engine"] == "postgres" - assert snapshot["SnapshotType"] == "automated" + assert snapshot["SnapshotType"] == "manual" + + +@mock_aws +def test_max_allocated_storage(client): + # MaxAllocatedStorage is not set or included in details by default. + details = create_db_instance() + assert "MaxAllocatedStorage" not in details + # Can't set to less than AllocatedStorage. + with pytest.raises(ClientError) as excinfo: + create_db_instance( + DBInstanceIdentifier="less-than-allocated-storage", + AllocatedStorage=50, + MaxAllocatedStorage=25, + ) + error_info = excinfo.value.response["Error"] + assert error_info["Code"] == "InvalidParameterCombination" + assert error_info["Message"] == "Max storage size must be greater than storage size" + # Set at creation time. + details = create_db_instance( + DBInstanceIdentifier="test-max-allocated-storage", MaxAllocatedStorage=500 + ) + assert details["MaxAllocatedStorage"] == 500 + # Set to higher limit. + details = client.modify_db_instance( + DBInstanceIdentifier=details["DBInstanceIdentifier"], MaxAllocatedStorage=1000 + )["DBInstance"] + assert details["MaxAllocatedStorage"] == 1000 + # Disable by setting equal to AllocatedStorage. + details = client.modify_db_instance( + DBInstanceIdentifier=details["DBInstanceIdentifier"], + MaxAllocatedStorage=details["AllocatedStorage"], + )["DBInstance"] + assert "MaxAllocatedStorage" not in details + # Can't set to less than AllocatedStorage. + with pytest.raises(ClientError) as excinfo: + client.modify_db_instance( + DBInstanceIdentifier=details["DBInstanceIdentifier"], MaxAllocatedStorage=5 + ) + error_info = excinfo.value.response["Error"] + assert error_info["Code"] == "InvalidParameterCombination" + assert error_info["Message"] == "Max storage size must be greater than storage size" @mock_aws @@ -789,6 +831,15 @@ def test_create_db_snapshots_copy_tags(client): Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], ) + snapshot = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", SnapshotType="automated" + )["DBSnapshots"][0] + result = client.list_tags_for_resource(ResourceName=snapshot["DBSnapshotArn"]) + assert result["TagList"] == [ + {"Value": "bar", "Key": "foo"}, + {"Value": "bar1", "Key": "foo1"}, + ] + snapshot = client.create_db_snapshot( DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="g-1" )["DBSnapshot"] @@ -813,9 +864,9 @@ def test_create_db_snapshots_with_tags(client): Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], ) - snapshots = client.describe_db_snapshots(DBInstanceIdentifier="db-primary-1")[ - "DBSnapshots" - ] + snapshots = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", SnapshotType="manual" + )["DBSnapshots"] assert snapshots[0]["DBSnapshotIdentifier"] == "g-1" assert snapshots[0]["TagList"] == [ {"Value": "bar", "Key": "foo"}, @@ -872,13 +923,16 @@ def test_copy_db_snapshots_snapshot_type_is_always_manual(client): db_instance_identifier = create_db_instance()["DBInstanceIdentifier"] client.delete_db_instance( DBInstanceIdentifier=db_instance_identifier, - FinalDBSnapshotIdentifier="final-snapshot", + SkipFinalSnapshot=True, + DeleteAutomatedBackups=False, ) - snapshot1 = client.describe_db_snapshots()["DBSnapshots"][0] + snapshot1 = client.describe_db_snapshots( + DBInstanceIdentifier=db_instance_identifier, SnapshotType="automated" + )["DBSnapshots"][0] assert snapshot1["SnapshotType"] == "automated" snapshot2 = client.copy_db_snapshot( - SourceDBSnapshotIdentifier="final-snapshot", + SourceDBSnapshotIdentifier=snapshot1["DBSnapshotIdentifier"], TargetDBSnapshotIdentifier="snapshot-2", )["DBSnapshot"] assert snapshot2["SnapshotType"] == "manual" @@ -954,7 +1008,9 @@ def test_describe_db_snapshots(client): assert created["Engine"] == "postgres" assert created["SnapshotType"] == "manual" - by_database_id = client.describe_db_snapshots(DBInstanceIdentifier="db-primary-1") + by_database_id = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", SnapshotType="manual" + ) by_snapshot_id = client.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1") assert by_snapshot_id["DBSnapshots"] == by_database_id["DBSnapshots"] @@ -964,9 +1020,9 @@ def test_describe_db_snapshots(client): client.create_db_snapshot( DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-2" ) - snapshots = client.describe_db_snapshots(DBInstanceIdentifier="db-primary-1")[ - "DBSnapshots" - ] + snapshots = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", SnapshotType="manual" + )["DBSnapshots"] assert len(snapshots) == 2 @@ -1438,6 +1494,7 @@ def test_modify_non_existent_option_group(client): ) +@pytest.mark.aws_verified @mock_aws def test_delete_database_with_protection(client): create_db_instance(DBInstanceIdentifier="db-primary-1", DeletionProtection=True) @@ -1445,7 +1502,11 @@ def test_delete_database_with_protection(client): with pytest.raises(ClientError) as exc: client.delete_db_instance(DBInstanceIdentifier="db-primary-1") err = exc.value.response["Error"] - assert err["Message"] == "Can't delete Instance with protection enabled" + assert err["Code"] == "InvalidParameterCombination" + assert ( + err["Message"] + == "Cannot delete protected DB Instance, please disable deletion protection and try again." + ) @mock_aws @@ -2628,6 +2689,210 @@ def test_modify_db_snapshot_attribute(client): assert snapshot_attributes[0]["AttributeValues"] == ["Test2", "Test3"] +@mock_aws +@pytest.mark.parametrize("skip_final_snapshot", [False, True]) +def test_delete_db_instance_with_skip_final_snapshot_param(client, skip_final_snapshot): + create_db_instance(DBInstanceIdentifier="db-primary-1") + + deletion_kwargs = dict( + DBInstanceIdentifier="db-primary-1", SkipFinalSnapshot=skip_final_snapshot + ) + if not skip_final_snapshot: + deletion_kwargs["FinalDBSnapshotIdentifier"] = "final-snapshot" + client.delete_db_instance(**deletion_kwargs) + + with pytest.raises(ClientError): + client.describe_db_instances(DBInstanceIdentifier="db-primary-1") + + resp = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="final-snapshot", + SnapshotType="manual", + ) + snapshot_count = len(resp["DBSnapshots"]) + valid_conditions = [ + (skip_final_snapshot and snapshot_count == 0), + (snapshot_count == 1 and not skip_final_snapshot), + ] + assert any(valid_conditions) + if not skip_final_snapshot: + assert resp["DBSnapshots"][0]["DBSnapshotIdentifier"] == "final-snapshot" + + +@mock_aws +@pytest.mark.parametrize("delete_automated_backups", [False, True]) +def test_delete_db_instance_with_delete_automated_backups_param( + client, + delete_automated_backups, +): + create_db_instance(DBInstanceIdentifier="db-primary-1") + + client.delete_db_instance( + DBInstanceIdentifier="db-primary-1", + SkipFinalSnapshot=True, + DeleteAutomatedBackups=delete_automated_backups, + ) + + with pytest.raises(ClientError): + client.describe_db_instances(DBInstanceIdentifier="db-primary-1") + + resp = client.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1", + SnapshotType="automated", + ) + automated_snapshot_count = len(resp["DBSnapshots"]) + valid_conditions = [ + (delete_automated_backups and automated_snapshot_count == 0), + (automated_snapshot_count >= 1 and not delete_automated_backups), + ] + assert any(valid_conditions) + + +@mock_aws +def test_describe_db_instance_automated_backups_lifecycle(client): + instance_id = "test-instance" + create_db_instance(DBInstanceIdentifier=instance_id) + resp = client.describe_db_instance_automated_backups( + DBInstanceIdentifier=instance_id, + ) + automated_backups = resp["DBInstanceAutomatedBackups"] + assert len(automated_backups) == 1 + automated_backup = automated_backups[0] + assert automated_backup["DBInstanceIdentifier"] == instance_id + assert automated_backup["Status"] == "active" + + client.delete_db_instance( + DBInstanceIdentifier=instance_id, + SkipFinalSnapshot=True, + DeleteAutomatedBackups=False, + ) + + resp = client.describe_db_instance_automated_backups( + DBInstanceIdentifier=instance_id, + ) + automated_backups = resp["DBInstanceAutomatedBackups"] + assert len(automated_backups) == 1 + automated_backup = automated_backups[0] + assert automated_backup["DBInstanceIdentifier"] == instance_id + assert automated_backup["Status"] == "retained" + + +@mock_aws +def test_delete_automated_backups_by_default(client): + instance_id = "test-instance" + create_db_instance(DBInstanceIdentifier=instance_id) + resp = client.describe_db_instance_automated_backups( + DBInstanceIdentifier=instance_id, + ) + automated_backups = resp["DBInstanceAutomatedBackups"] + assert len(automated_backups) == 1 + automated_backup = automated_backups[0] + assert automated_backup["DBInstanceIdentifier"] == instance_id + assert automated_backup["Status"] == "active" + + client.delete_db_instance(DBInstanceIdentifier=instance_id, SkipFinalSnapshot=True) + + resp = client.describe_db_instance_automated_backups( + DBInstanceIdentifier=instance_id, + ) + automated_backups = resp["DBInstanceAutomatedBackups"] + assert len(automated_backups) == 0 + + +@mock_aws +def test_restore_db_instance_from_db_snapshot_with_allocated_storage(client): + instance_id = "db-primary-1" + allocated_storage = 20 + create_db_instance( + DBInstanceIdentifier=instance_id, + AllocatedStorage=allocated_storage, + Engine="postgres", + ) + snapshot = client.create_db_snapshot( + DBInstanceIdentifier=instance_id, DBSnapshotIdentifier="snap" + ).get("DBSnapshot") + snapshot_id = snapshot["DBSnapshotIdentifier"] + # Default + restored = client.restore_db_instance_from_db_snapshot( + DBInstanceIdentifier="restored-default", + DBSnapshotIdentifier=snapshot_id, + ).get("DBInstance") + assert restored["AllocatedStorage"] == allocated_storage + # More than snapshot allocated storage + restored = client.restore_db_instance_from_db_snapshot( + DBInstanceIdentifier="restored-with-allocated-storage", + DBSnapshotIdentifier=snapshot_id, + AllocatedStorage=allocated_storage * 2, + ).get("DBInstance") + assert restored["AllocatedStorage"] == allocated_storage * 2 + # Less than snapshot allocated storage + with pytest.raises(ClientError, match=r"allocated storage") as excinfo: + client.restore_db_instance_from_db_snapshot( + DBInstanceIdentifier="restored-with-too-little-storage", + DBSnapshotIdentifier=snapshot_id, + AllocatedStorage=int(allocated_storage / 2), + ) + exc = excinfo.value + assert exc.response["Error"]["Code"] == "InvalidParameterValue" + + +@mock_aws +def test_restore_db_instance_to_point_in_time_with_allocated_storage(client): + allocated_storage = 20 + details_source = create_db_instance(AllocatedStorage=allocated_storage) + source_identifier = details_source["DBInstanceIdentifier"] + restore_time = datetime.datetime.fromtimestamp( + time.time() - 600, datetime.timezone.utc + ).strftime("%Y-%m-%dT%H:%M:%SZ") + # Default + restored = client.restore_db_instance_to_point_in_time( + SourceDBInstanceIdentifier=source_identifier, + TargetDBInstanceIdentifier="pit-default", + RestoreTime=restore_time, + ).get("DBInstance") + assert restored["AllocatedStorage"] == allocated_storage + # More than source allocated storage + restored = client.restore_db_instance_to_point_in_time( + SourceDBInstanceIdentifier=source_identifier, + TargetDBInstanceIdentifier="pit-with-allocated-storage", + RestoreTime=restore_time, + AllocatedStorage=allocated_storage * 2, + ).get("DBInstance") + assert restored["AllocatedStorage"] == allocated_storage * 2 + # Less than source allocated storage + with pytest.raises(ClientError, match=r"Allocated storage") as excinfo: + client.restore_db_instance_to_point_in_time( + SourceDBInstanceIdentifier=source_identifier, + TargetDBInstanceIdentifier="pit-with-too-little-storage", + RestoreTime=restore_time, + AllocatedStorage=int(allocated_storage / 2), + ) + exc = excinfo.value + assert exc.response["Error"]["Code"] == "InvalidParameterValue" + + +@mock_aws +def test_copy_unencrypted_db_snapshot_to_encrypted_db_snapshot(client): + instance_identifier = "unencrypted-db-instance" + create_db_instance(DBInstanceIdentifier=instance_identifier, StorageEncrypted=False) + snapshot = client.create_db_snapshot( + DBInstanceIdentifier=instance_identifier, + DBSnapshotIdentifier="unencrypted-db-snapshot", + ).get("DBSnapshot") + assert snapshot["Encrypted"] is False + + client.copy_db_snapshot( + SourceDBSnapshotIdentifier="unencrypted-db-snapshot", + TargetDBSnapshotIdentifier="encrypted-db-snapshot", + KmsKeyId="alias/aws/rds", + ) + snapshot = client.describe_db_snapshots( + DBSnapshotIdentifier="encrypted-db-snapshot" + ).get("DBSnapshots")[0] + assert snapshot["DBSnapshotIdentifier"] == "encrypted-db-snapshot" + assert snapshot["Encrypted"] is True + + def validation_helper(exc): err = exc.value.response["Error"] assert err["Code"] == "InvalidParameterValue" diff --git a/tests/test_rds/test_rds_clusters.py b/tests/test_rds/test_rds_clusters.py index e0e795efe000..00896f99c570 100644 --- a/tests/test_rds/test_rds_clusters.py +++ b/tests/test_rds/test_rds_clusters.py @@ -355,6 +355,43 @@ def test_create_db_cluster_additional_parameters(client): assert cluster["IAMDatabaseAuthenticationEnabled"] is True +@mock_aws +def test_modify_db_cluster_serverless_v2_scaling_configuration(client): + resp = client.create_db_cluster( + DBClusterIdentifier="cluster-id", + Engine="aurora-postgresql", + EngineMode="serverless", + MasterUsername="root", + MasterUserPassword="hunter2_", + ServerlessV2ScalingConfiguration={ + "MinCapacity": 2, + "MaxCapacity": 4, + }, + ) + cluster = resp["DBCluster"] + assert cluster["Engine"] == "aurora-postgresql" + assert cluster["EngineMode"] == "serverless" + assert cluster["ServerlessV2ScalingConfiguration"] == { + "MaxCapacity": 4.0, + "MinCapacity": 2.0, + } + client.modify_db_cluster( + DBClusterIdentifier="cluster-id", + ServerlessV2ScalingConfiguration={ + "MinCapacity": 4, + "MaxCapacity": 8, + }, + ) + resp = client.describe_db_clusters(DBClusterIdentifier="cluster-id") + cluster_modified = resp["DBClusters"][0] + assert cluster_modified["Engine"] == "aurora-postgresql" + assert cluster_modified["EngineMode"] == "serverless" + assert cluster_modified["ServerlessV2ScalingConfiguration"] == { + "MaxCapacity": 8.0, + "MinCapacity": 4.0, + } + + @mock_aws def test_describe_db_cluster_after_creation(client): client.create_db_cluster( diff --git a/tests/test_resourcegroupstaggingapi/test_resourcegroupstagging_rds.py b/tests/test_resourcegroupstaggingapi/test_resourcegroupstagging_rds.py index fc54c0f63999..00aa3c1abbc2 100644 --- a/tests/test_resourcegroupstaggingapi/test_resourcegroupstagging_rds.py +++ b/tests/test_resourcegroupstaggingapi/test_resourcegroupstagging_rds.py @@ -27,6 +27,17 @@ def setUp(self) -> None: group = self.resources_tagged if i else self.resources_untagged group.append(database["DBInstanceArn"]) group.append(snapshot["DBSnapshotArn"]) + automated_snapshots = self.rds.describe_db_snapshots( + Filters=[ + { + "Name": "db-instance-id", + "Values": ["db-instance-1", "db-instance-2"], + }, + {"Name": "snapshot-type", "Values": ["automated"]}, + ], + )["DBSnapshots"] + for snapshot in automated_snapshots: + self.resources_tagged.append(snapshot["DBSnapshotArn"]) def test_get_resources_rds(self): def assert_response(response, expected_count, resource_type=None): @@ -40,15 +51,15 @@ def assert_response(response, expected_count, resource_type=None): assert f":{resource_type}:" in arn resp = self.rtapi.get_resources(ResourceTypeFilters=["rds"]) - assert_response(resp, 4) + assert_response(resp, 6) resp = self.rtapi.get_resources(ResourceTypeFilters=["rds:db"]) assert_response(resp, 2, resource_type="db") resp = self.rtapi.get_resources(ResourceTypeFilters=["rds:snapshot"]) - assert_response(resp, 2, resource_type="snapshot") + assert_response(resp, 4, resource_type="snapshot") resp = self.rtapi.get_resources( TagFilters=[{"Key": "test", "Values": ["value-1"]}] ) - assert_response(resp, 2) + assert_response(resp, 3) def test_tag_resources_rds(self): # WHEN