From cfaa898de1810779469e07bca79d891fde6b1aa2 Mon Sep 17 00:00:00 2001 From: Brian Pandola Date: Thu, 16 Jan 2025 13:36:30 -0800 Subject: [PATCH] Ensure compatibility with latest `botocore` S3 client customizations (#8495) New "default integrity protections" introduced in the AWS SDK for Python v1.36.0 necessitated the following changes: * Minor tweaks to S3 request processing to handle the new encoding format. * Minor tweaks to request processing during `proxy` and `recording` mode. * Test suite updates as a result of changes to the default request behavior for some S3 endpoints. Backward compatibility with earlier AWS SDK for Python versions has been maintained. --- moto/moto_api/_internal/recorder/models.py | 6 +++++- moto/moto_proxy/proxy3.py | 8 +++---- moto/s3/responses.py | 10 +++++++++ tests/test_s3/__init__.py | 9 ++++++++ tests/test_s3/test_s3.py | 8 +++---- tests/test_s3/test_s3_copyobject.py | 8 ++++--- tests/test_s3/test_s3_lock.py | 25 +++++++++++++++++++--- tests/test_s3/test_s3_object_attributes.py | 3 --- 8 files changed, 59 insertions(+), 18 deletions(-) diff --git a/moto/moto_api/_internal/recorder/models.py b/moto/moto_api/_internal/recorder/models.py index 20a40ea5c885..9d0855315544 100644 --- a/moto/moto_api/_internal/recorder/models.py +++ b/moto/moto_api/_internal/recorder/models.py @@ -33,7 +33,11 @@ def _record_request(self, request: Any, body: Optional[bytes] = None) -> None: if body is None: if isinstance(request, AWSPreparedRequest): - body_str, body_encoded = self._encode_body(body=request.body) + if hasattr(request.body, "read"): + body = request.body.read() # type: ignore + else: + body = request.body # type: ignore + body_str, body_encoded = self._encode_body(body) else: try: request_body = None diff --git a/moto/moto_proxy/proxy3.py b/moto/moto_proxy/proxy3.py index 2db0409e5118..da151203815d 100644 --- a/moto/moto_proxy/proxy3.py +++ b/moto/moto_proxy/proxy3.py @@ -155,12 +155,12 @@ def do_GET(self) -> None: return req_body = b"" - if "Content-Length" in req.headers: - content_length = int(req.headers["Content-Length"]) - req_body = self.rfile.read(content_length) - elif "chunked" in self.headers.get("Transfer-Encoding", ""): + if "chunked" in self.headers.get("Transfer-Encoding", ""): # https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Transfer-Encoding req_body = self.read_chunked_body(self.rfile) + elif "Content-Length" in req.headers: + content_length = int(req.headers["Content-Length"]) + req_body = self.rfile.read(content_length) if self.headers.get("Content-Type", "").startswith("multipart/form-data"): boundary = self.headers["Content-Type"].split("boundary=")[-1] req_body, form_data = get_body_from_form_data(req_body, boundary) # type: ignore diff --git a/moto/s3/responses.py b/moto/s3/responses.py index d8b88e6859a0..a88c39851e20 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -183,6 +183,12 @@ def setup_class(self, request: Any, full_url: str, headers: Any) -> None: # typ ) self.bucket_name = self.parse_bucket_name_from_url(request, full_url) self.request = request + if ( + not self.body + and request.headers.get("Content-Encoding", "") == "aws-chunked" + and hasattr(request, "input_stream") + ): + self.body = request.input_stream.getvalue() if ( self.request.headers.get("x-amz-content-sha256") == "STREAMING-UNSIGNED-PAYLOAD-TRAILER" @@ -1347,6 +1353,8 @@ def _handle_v4_chunk_signatures(self, body: bytes, content_length: int) -> bytes def _handle_encoded_body(self, body: bytes) -> bytes: decoded_body = b"" + if not body: + return decoded_body body_io = io.BytesIO(body) # first line should equal '{content_length}\r\n' while the content_length is a hex number content_length = int(body_io.readline().strip(), 16) @@ -1769,6 +1777,8 @@ def _get_checksum( checksum_value = compute_checksum( self.raw_body, algorithm=checksum_algorithm ) + if isinstance(checksum_value, bytes): + checksum_value = checksum_value.decode("utf-8") response_headers.update({checksum_header: checksum_value}) return checksum_algorithm, checksum_value diff --git a/tests/test_s3/__init__.py b/tests/test_s3/__init__.py index 032dd0f60ad0..fa2bbde27044 100644 --- a/tests/test_s3/__init__.py +++ b/tests/test_s3/__init__.py @@ -81,3 +81,12 @@ def empty_bucket(client, bucket_name): client.delete_object( Bucket=bucket_name, Key=key["Key"], VersionId=key.get("VersionId"), **kwargs ) + + +def generate_content_md5(content: bytes) -> str: + import base64 + import hashlib + + md = hashlib.md5(content).digest() + content_md5 = base64.b64encode(md).decode("utf-8") + return content_md5 diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 742c562ad044..a8424c6a7a10 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -1076,7 +1076,7 @@ def test_setting_content_encoding(): bucket.put_object(Body=b"abcdef", ContentEncoding="gzip", Key="keyname") key = s3_resource.Object("mybucket", "keyname") - assert key.content_encoding == "gzip" + assert "gzip" in key.content_encoding @mock_aws @@ -1626,8 +1626,8 @@ def test_list_objects_v2_checksum_algo(): s3_client = boto3.client("s3", region_name=DEFAULT_REGION_NAME) s3_client.create_bucket(Bucket="mybucket") resp = s3_client.put_object(Bucket="mybucket", Key="0", Body="a") - assert "ChecksumCRC32" not in resp - assert "x-amz-sdk-checksum-algorithm" not in resp["ResponseMetadata"]["HTTPHeaders"] + # Default checksum behavior varies by boto3 version and will not be asserted here. + assert resp resp = s3_client.put_object( Bucket="mybucket", Key="1", Body="a", ChecksumAlgorithm="CRC32" ) @@ -1646,7 +1646,7 @@ def test_list_objects_v2_checksum_algo(): ) resp = s3_client.list_objects_v2(Bucket="mybucket")["Contents"] - assert "ChecksumAlgorithm" not in resp[0] + assert "ChecksumAlgorithm" in resp[0] assert resp[1]["ChecksumAlgorithm"] == ["CRC32"] assert resp[2]["ChecksumAlgorithm"] == ["SHA256"] diff --git a/tests/test_s3/test_s3_copyobject.py b/tests/test_s3/test_s3_copyobject.py index a75f18cc4827..136d7bb47b68 100644 --- a/tests/test_s3/test_s3_copyobject.py +++ b/tests/test_s3/test_s3_copyobject.py @@ -7,6 +7,7 @@ from moto import mock_aws from moto.s3.responses import DEFAULT_REGION_NAME +from tests.test_s3 import generate_content_md5 from tests.test_s3.test_s3 import enable_versioning from . import s3_aws_verified @@ -656,6 +657,7 @@ def test_copy_objet_legal_hold(): Key=source_key, Body=b"somedata", ObjectLockLegalHoldStatus="ON", + ContentMD5=generate_content_md5(b"somedata"), ) head_object = client.head_object(Bucket=bucket_name, Key=source_key) @@ -692,9 +694,10 @@ def test_s3_copy_object_lock(): client.put_object( Bucket=bucket_name, Key=source_key, - Body="test", + Body=b"test", ObjectLockMode="GOVERNANCE", ObjectLockRetainUntilDate=retain_until, + ContentMD5=generate_content_md5(b"test"), ) head_object = client.head_object(Bucket=bucket_name, Key=source_key) @@ -909,12 +912,11 @@ def test_copy_object_calculates_checksum(algorithm, checksum): checksum_key = f"Checksum{algorithm}" - resp = client.put_object( + client.put_object( Bucket=bucket, Key=source_key, Body=body, ) - assert checksum_key not in resp resp = client.copy_object( Bucket=bucket, diff --git a/tests/test_s3/test_s3_lock.py b/tests/test_s3/test_s3_lock.py index 242ba9eae649..69c95fa19e0e 100644 --- a/tests/test_s3/test_s3_lock.py +++ b/tests/test_s3/test_s3_lock.py @@ -10,7 +10,7 @@ from moto.core.utils import utcnow from moto.s3.responses import DEFAULT_REGION_NAME from tests import allow_aws_request -from tests.test_s3 import s3_aws_verified +from tests.test_s3 import generate_content_md5, s3_aws_verified from tests.test_s3.test_s3 import enable_versioning @@ -86,6 +86,7 @@ def test_locked_object_governance_mode(bypass_governance_retention, bucket_name= Key=key_name, ObjectLockMode="GOVERNANCE", ObjectLockRetainUntilDate=until, + ContentMD5=generate_content_md5(b"test"), ) versions_response = s3_client.list_object_versions(Bucket=bucket_name) @@ -196,6 +197,7 @@ def test_locked_object_compliance_mode(bypass_governance_retention, bucket_name= Key=key_name, ObjectLockMode="COMPLIANCE", ObjectLockRetainUntilDate=until, + ContentMD5=generate_content_md5(b"test"), ) versions_response = s3_client.list_object_versions(Bucket=bucket_name) @@ -256,6 +258,7 @@ def test_fail_locked_object(): Key=key_name, ObjectLockMode="COMPLIANCE", ObjectLockRetainUntilDate=until, + ContentMD5=generate_content_md5(b"test"), ) except ClientError as exc: assert exc.response["Error"]["Code"] == "InvalidRequest" @@ -321,7 +324,12 @@ def test_put_object_legal_hold(bucket_name=None): }, ) - s3_client.put_object(Bucket=bucket_name, Body=b"test", Key=key_name) + s3_client.put_object( + Bucket=bucket_name, + Body=b"test", + Key=key_name, + ContentMD5=generate_content_md5(b"test"), + ) versions_response = s3_client.list_object_versions(Bucket=bucket_name) version_id = versions_response["Versions"][0]["VersionId"] @@ -331,6 +339,9 @@ def test_put_object_legal_hold(bucket_name=None): Key=key_name, VersionId=version_id, LegalHold={"Status": "ON"}, + ContentMD5=generate_content_md5( + b'ON' + ), ) with pytest.raises(ClientError) as exc: @@ -349,6 +360,9 @@ def test_put_object_legal_hold(bucket_name=None): Key=key_name, VersionId=version_id, LegalHold={"Status": "OFF"}, + ContentMD5=generate_content_md5( + b'OFF' + ), ) s3_client.delete_object( Bucket=bucket_name, @@ -379,7 +393,12 @@ def test_put_default_lock(): }, ) - s3_client.put_object(Bucket=bucket_name, Body=b"test", Key=key_name) + s3_client.put_object( + Bucket=bucket_name, + Body=b"test", + Key=key_name, + ContentMD5=generate_content_md5(b"test"), + ) deleted = False versions_response = s3_client.list_object_versions(Bucket=bucket_name) diff --git a/tests/test_s3/test_s3_object_attributes.py b/tests/test_s3/test_s3_object_attributes.py index 30a279d1ae94..a4c62d3f188d 100644 --- a/tests/test_s3/test_s3_object_attributes.py +++ b/tests/test_s3/test_s3_object_attributes.py @@ -69,9 +69,6 @@ def test_get_attributes_checksum(self, algo_val): ) resp.pop("ResponseMetadata") - # Checksum is not returned, because it's not set - assert set(resp.keys()) == {"LastModified"} - # Retrieve checksum from key that was created with CRC32 resp = self.client.get_object_attributes( Bucket=self.bucket_name, Key="cs", ObjectAttributes=["Checksum"]