diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 3f188510..e06ca9e1 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -3,6 +3,7 @@ """Top-level S3 Encryption Client v4 for Python package.""" import io +import os import threading from attrs import define, field @@ -19,10 +20,19 @@ ) from .materials.keyring import AbstractKeyring from .materials.materials import AlgorithmSuite, CommitmentPolicy -from .pipelines import GetEncryptedObjectPipeline, PutEncryptedObjectPipeline +from .pipelines import ( + GetEncryptedObjectPipeline, + MultipartUploadPipeline, + PutEncryptedObjectPipeline, +) S3_METADATA_PREFIX = "x-amz-meta-" +# Default multipart threshold and chunk size (same as boto3 defaults) +_DEFAULT_MULTIPART_THRESHOLD = 8 * 1024 * 1024 # 8 MB +_DEFAULT_MULTIPART_CHUNKSIZE = 8 * 1024 * 1024 # 8 MB +_MIN_MULTIPART_PART_SIZE = 5 * 1024 * 1024 # 5 MB — S3 minimum for non-final parts + # Thread-local context attribute names _CTX_ENCRYPTION_CONTEXT = "encryption_context" _CTX_BUCKET = "bucket" @@ -341,6 +351,10 @@ class S3EncryptionClient: wrapped_s3_client = field() config: S3EncryptionClientConfig = field() _plugin: S3EncryptionClientPlugin = field(init=False) + # Each upload gets its own pipeline with independent cipher state, keyed by UploadId. + # Access is protected by a lock for thread safety across all Python runtimes. + _multipart_uploads: dict = field(init=False, factory=dict) + _multipart_lock: threading.Lock = field(init=False, factory=threading.Lock) def __attrs_post_init__(self): """Install the encryption plugin on the wrapped client using boto3 events.""" @@ -563,3 +577,290 @@ def get_object(self, **kwargs): for attr in _GET_OBJECT_CLEANUP_ATTRS: if hasattr(self._plugin._context, attr): delattr(self._plugin._context, attr) + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% CreateMultipartUpload MAY be implemented by the S3EC. + def create_multipart_upload(self, **kwargs): + """Initiate an encrypted multipart upload. + + Obtains encryption materials, initializes the cipher, and calls + the underlying S3 CreateMultipartUpload. Encryption metadata is + set on the object at this point. + + Args: + **kwargs: Arguments for S3 create_multipart_upload. + May include EncryptionContext. + + Returns: + The response from S3 create_multipart_upload. + """ + encryption_context = kwargs.pop("EncryptionContext", None) + _validate_encryption_context(encryption_context) + + pipeline = MultipartUploadPipeline( + cmm=self.config.cmm, + encryption_algorithm=self.config.encryption_algorithm, + encryption_context=encryption_context or {}, + ) + + # Merge encryption metadata into user-provided Metadata + user_metadata = dict(kwargs.get("Metadata", {})) + user_metadata.update(pipeline.metadata) + kwargs["Metadata"] = user_metadata + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% If implemented, CreateMultipartUpload MUST initiate a multipart upload. + try: + response = self.wrapped_s3_client.create_multipart_upload(**kwargs) + except Exception as e: + raise S3EncryptionClientError(f"Failed to create multipart upload: {e}") from e + + upload_id = response["UploadId"] + with self._multipart_lock: + self._multipart_uploads[upload_id] = pipeline + return response + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% UploadPart MAY be implemented by the S3EC. + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% UploadPart MUST encrypt each part. + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% Each part MUST be encrypted in sequence. + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% Each part MUST be encrypted using the same cipher instance for each part. + def upload_part(self, **kwargs): + """Encrypt and upload a single part of a multipart upload. + + Parts must be uploaded in sequential order (1, 2, 3, ...). + The caller MUST set ``IsLastPart=True`` on the final part so the + GCM authentication tag is appended to the ciphertext. + + Args: + **kwargs: Arguments for S3 upload_part. Must include UploadId, + PartNumber, and Body. Set IsLastPart=True on the + final part. + + Returns: + The response from S3 upload_part (includes ETag). + """ + upload_id = kwargs.get("UploadId") + with self._multipart_lock: + pipeline = self._multipart_uploads.get(upload_id) + if pipeline is None: + raise S3EncryptionClientError( + f"No multipart upload found for UploadId: {upload_id}. " + "Call create_multipart_upload first." + ) + + part_number = kwargs["PartNumber"] + is_last = kwargs.pop("IsLastPart", False) + body = kwargs.get("Body", b"") + if isinstance(body, str): + body = body.encode("utf-8") + elif hasattr(body, "read"): + body = body.read() + + try: + ciphertext = pipeline.encrypt_part(part_number, body, is_last=is_last) + except S3EncryptionClientError: + raise + except Exception as e: + raise S3EncryptionClientError(f"Failed to encrypt part {part_number}: {e}") from e + + kwargs["Body"] = ciphertext + return self.wrapped_s3_client.upload_part(**kwargs) + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% CompleteMultipartUpload MAY be implemented by the S3EC. + ##% CompleteMultipartUpload MUST complete the multipart upload. + def complete_multipart_upload(self, **kwargs): + """Complete the multipart upload. + + The final part must have been uploaded with ``IsLastPart=True`` + before calling this method. + + Args: + **kwargs: Arguments for S3 complete_multipart_upload. + MultipartUpload.Parts must include PartNumber and ETag + for each part. + + Returns: + The response from S3 complete_multipart_upload. + """ + upload_id = kwargs.get("UploadId") + with self._multipart_lock: + pipeline = self._multipart_uploads.get(upload_id) + if pipeline is None: + raise S3EncryptionClientError(f"No multipart upload found for UploadId: {upload_id}.") + + if not pipeline.has_final_part_been_seen: + raise S3EncryptionClientError( + "Cannot complete multipart upload: the final part has not been uploaded. " + "Set IsLastPart=True on the last upload_part call." + ) + + try: + response = self.wrapped_s3_client.complete_multipart_upload(**kwargs) + except S3EncryptionClientError: + raise + except Exception as e: + raise S3EncryptionClientError(f"Failed to complete multipart upload: {e}") from e + else: + with self._multipart_lock: + self._multipart_uploads.pop(upload_id, None) + return response + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=implementation + ##% AbortMultipartUpload MAY be implemented by the S3EC. + ##% AbortMultipartUpload MUST abort the multipart upload. + def abort_multipart_upload(self, **kwargs): + """Abort a multipart upload and clean up cipher state. + + Args: + **kwargs: Arguments for S3 abort_multipart_upload. + + Returns: + The response from S3 abort_multipart_upload. + """ + upload_id = kwargs.get("UploadId") + with self._multipart_lock: + self._multipart_uploads.pop(upload_id, None) + return self.wrapped_s3_client.abort_multipart_upload(**kwargs) + + def upload_file( + self, filename, bucket, key, multipart_threshold=None, multipart_chunksize=None, **kwargs + ): + """Encrypt and upload a file to S3. + + If the file is smaller than the threshold, uses put_object. + Otherwise, performs an encrypted multipart upload. + + Args: + filename: Path to the file to upload. + bucket: Target S3 bucket. + key: Target S3 object key. + multipart_threshold: File size threshold for multipart (default 8MB). + multipart_chunksize: Size of each part (default 8MB). + **kwargs: Additional arguments (e.g. EncryptionContext, Metadata). + """ + threshold = ( + _DEFAULT_MULTIPART_THRESHOLD if multipart_threshold is None else multipart_threshold + ) + chunksize = ( + _DEFAULT_MULTIPART_CHUNKSIZE if multipart_chunksize is None else multipart_chunksize + ) + if threshold <= 0: + raise S3EncryptionClientError("multipart_threshold must be a positive integer.") + if chunksize <= 0: + raise S3EncryptionClientError("multipart_chunksize must be a positive integer.") + if chunksize < _MIN_MULTIPART_PART_SIZE: + raise S3EncryptionClientError( + f"multipart_chunksize must be at least {_MIN_MULTIPART_PART_SIZE} bytes (5 MB). " + f"S3 requires all non-final parts to be at least 5 MB." + ) + file_size = os.path.getsize(filename) + + if file_size < threshold: + with open(filename, "rb") as f: + kwargs["Bucket"] = bucket + kwargs["Key"] = key + kwargs["Body"] = f.read() + return self.put_object(**kwargs) + + return self._multipart_upload_from_readable( + open(filename, "rb"), bucket, key, chunksize, owns_readable=True, **kwargs + ) + + def upload_fileobj(self, fileobj, bucket, key, multipart_chunksize=None, **kwargs): + """Encrypt and upload a file-like object to S3 via multipart upload. + + The caller retains ownership of fileobj — it will not be closed + by this method. + + Args: + fileobj: A file-like object with a read() method. + bucket: Target S3 bucket. + key: Target S3 object key. + multipart_chunksize: Size of each part (default 8MB). + **kwargs: Additional arguments (e.g. EncryptionContext, Metadata). + """ + chunksize = ( + _DEFAULT_MULTIPART_CHUNKSIZE if multipart_chunksize is None else multipart_chunksize + ) + if chunksize <= 0: + raise S3EncryptionClientError("multipart_chunksize must be a positive integer.") + if chunksize < _MIN_MULTIPART_PART_SIZE: + raise S3EncryptionClientError( + f"multipart_chunksize must be at least {_MIN_MULTIPART_PART_SIZE} bytes (5 MB). " + f"S3 requires all non-final parts to be at least 5 MB." + ) + return self._multipart_upload_from_readable( + fileobj, bucket, key, chunksize, owns_readable=False, **kwargs + ) + + def _multipart_upload_from_readable( + self, readable, bucket, key, chunksize, *, owns_readable=False, **kwargs + ): + """Perform an encrypted multipart upload from a readable source. + + Args: + readable: File-like object to read from. + bucket: Target S3 bucket. + key: Target S3 object key. + chunksize: Size of each part in bytes. + owns_readable: If True, close readable when done. If False, + the caller is responsible for closing it. + **kwargs: Additional S3 parameters forwarded to create_multipart_upload. + """ + # EncryptionContext is consumed by our pipeline, not S3 + create_kwargs = {"Bucket": bucket, "Key": key} + if "EncryptionContext" in kwargs: + create_kwargs["EncryptionContext"] = kwargs.pop("EncryptionContext") + if "Metadata" in kwargs: + create_kwargs["Metadata"] = kwargs.pop("Metadata") + # Forward remaining kwargs (ACL, ContentType, Tagging, etc.) to create_multipart_upload + create_kwargs.update(kwargs) + + create_resp = self.create_multipart_upload(**create_kwargs) + upload_id = create_resp["UploadId"] + + try: + parts = [] + part_number = 0 + # Read ahead so we can detect the last chunk + current_chunk = readable.read(chunksize) + while current_chunk: + next_chunk = readable.read(chunksize) + part_number += 1 + is_last = not next_chunk + resp = self.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=part_number, + Body=current_chunk, + IsLastPart=is_last, + ) + parts.append({"PartNumber": part_number, "ETag": resp["ETag"]}) + current_chunk = next_chunk + + return self.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + except Exception: + self.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + finally: + if owns_readable and hasattr(readable, "close"): + readable.close() diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 23a169ae..ca200a7f 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -9,6 +9,7 @@ import base64 import json import os +import threading from attrs import define, field from botocore.response import StreamingBody @@ -171,6 +172,142 @@ def _encrypt_kc_gcm(self, plaintext, enc_mats, edk_bytes): return encrypted_data, metadata.to_dict() +##= specification/s3-encryption/client.md#optional-api-operations +##= type=implementation +##% UploadPart MUST encrypt each part. +##= specification/s3-encryption/client.md#optional-api-operations +##= type=implementation +##% Each part MUST be encrypted in sequence. +##= specification/s3-encryption/client.md#optional-api-operations +##= type=implementation +##% Each part MUST be encrypted using the same cipher instance for each part. +@define +class MultipartUploadPipeline: + """Pipeline for encrypting multipart uploads. + + Manages a single AES-GCM cipher instance shared across all parts. + Parts MUST be uploaded in sequence (1, 2, 3, ...). + """ + + cmm: AbstractCryptoMaterialsManager = field() + encryption_algorithm: AlgorithmSuite = field() + encryption_context: dict = field(factory=dict) + _encryptor: object = field(init=False, default=None) + _metadata: dict = field(init=False, factory=dict) + _next_part: int = field(init=False, default=1) + _has_final_part_been_seen: bool = field(init=False, default=False) + _lock: threading.Lock = field(init=False, factory=threading.Lock) + # Cached ciphertext for the most recently encrypted part, enabling retries + # if the S3 upload_part call fails after encryption has already advanced. + _last_encrypted_part: int = field(init=False, default=0) + _last_encrypted_ciphertext: bytes | None = field(init=False, default=None) + + def __attrs_post_init__(self): + """Obtain encryption materials and initialize the streaming cipher.""" + enc_mats_request = EncryptionMaterials( + encryption_algorithm=self.encryption_algorithm, + encryption_context=self.encryption_context.copy(), + ) + enc_mats = self.cmm.get_encryption_materials(enc_mats_request) + if enc_mats.plaintext_data_key is None: + raise S3EncryptionClientError("No plaintext data key found!") + if enc_mats.encrypted_data_key is None: + raise S3EncryptionClientError("No encrypted data key found!") + + edk_bytes = enc_mats.encrypted_data_key.encrypted_data_key + + if self.encryption_algorithm == AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY: + self._init_kc_gcm(enc_mats, edk_bytes) + else: + self._init_gcm(enc_mats, edk_bytes) + + def _init_gcm(self, enc_mats, edk_bytes): + iv = os.urandom(enc_mats.encryption_algorithm.cipher_iv_length_bytes) + cipher = Cipher(algorithms.AES(enc_mats.plaintext_data_key), modes.GCM(iv)) + self._encryptor = cipher.encryptor() + self._metadata = ObjectMetadata( + encrypted_data_key_v2=base64.b64encode(edk_bytes).decode("utf-8"), + encrypted_data_key_algorithm="kms+context", + content_iv=base64.b64encode(iv).decode("utf-8"), + content_cipher="AES/GCM/NoPadding", + encrypted_data_key_context=enc_mats.encryption_context, + ).to_dict() + + def _init_kc_gcm(self, enc_mats, edk_bytes): + algorithm_suite = enc_mats.encryption_algorithm + message_id = os.urandom(algorithm_suite.commitment_nonce_length_bytes) + derived_encryption_key, commit_key = derive_keys( + enc_mats.plaintext_data_key, message_id, algorithm_suite + ) + cipher = Cipher( + algorithms.AES(derived_encryption_key), modes.GCM(algorithm_suite.kc_gcm_iv) + ) + self._encryptor = cipher.encryptor() + self._encryptor.authenticate_additional_data(algorithm_suite.suite_id_bytes) + self._metadata = ObjectMetadata( + content_cipher_v3=str(algorithm_suite.suite_id), + encrypted_data_key_algorithm_v3="12", + encrypted_data_key_v3=base64.b64encode(edk_bytes).decode("utf-8"), + message_id_v3=base64.b64encode(message_id).decode("utf-8"), + key_commitment_v3=base64.b64encode(commit_key).decode("utf-8"), + encryption_context_v3=( + enc_mats.encryption_context if enc_mats.encryption_context else None + ), + ).to_dict() + + @property + def metadata(self): + """Return the encryption metadata dict for the multipart upload.""" + return self._metadata + + @property + def has_final_part_been_seen(self): + """Return whether the final part has been encrypted.""" + return self._has_final_part_been_seen + + def encrypt_part(self, part_number, data, is_last=False): + """Encrypt a single part. Parts must be sequential starting from 1. + + If called with the same part_number as the most recently encrypted part, + returns the cached ciphertext (enabling retries after upload failures). + + Args: + part_number: The 1-based part number. + data: The plaintext bytes for this part. + is_last: If True, finalizes the cipher and appends the GCM auth tag. + + Returns: + The encrypted ciphertext bytes for this part. + """ + with self._lock: + # Allow retry of the last encrypted part + if part_number == self._last_encrypted_part: + return self._last_encrypted_ciphertext + + if self._has_final_part_been_seen: + raise S3EncryptionClientError("Cannot encrypt more parts after the final part.") + if part_number != self._next_part: + raise S3EncryptionClientError( + f"Parts must be uploaded in sequence. Expected part {self._next_part}, " + f"got {part_number}." + ) + if isinstance(data, str): + data = data.encode("utf-8") + self._next_part += 1 + + ciphertext = self._encryptor.update(data) + + if is_last: + self._encryptor.finalize() + ciphertext += self._encryptor.tag + self._has_final_part_been_seen = True + + self._last_encrypted_part = part_number + self._last_encrypted_ciphertext = ciphertext + + return ciphertext + + @define class GetEncryptedObjectPipeline: """Pipeline for decrypting objects after they are retrieved from S3. diff --git a/test/integration/test_i_s3_encryption_multipart.py b/test/integration/test_i_s3_encryption_multipart.py new file mode 100644 index 00000000..1aefef31 --- /dev/null +++ b/test/integration/test_i_s3_encryption_multipart.py @@ -0,0 +1,968 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Integration tests for encrypted multipart upload. + +These tests verify that the S3 Encryption Client correctly encrypts +objects via multipart upload and that they can be decrypted via get_object. +Tests cover the low-level multipart API (create/upload_part/complete/abort) +and the high-level upload_file / upload_fileobj convenience methods. +""" + +import os +import threading +from datetime import datetime + +import boto3 +import pytest + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.materials.kms_keyring import KmsKeyring +from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + +ALGORITHM_CONFIGS = [ + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + id="AES_GCM", + ), + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + id="KC_GCM", + ), +] + +# Minimum part size for S3 multipart upload is 5 MB (except last part). +FIVE_MB = 5 * 1024 * 1024 + + +def _make_client(algorithm_suite, commitment_policy, **extra_config): + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + encryption_algorithm=algorithm_suite, + commitment_policy=commitment_policy, + **extra_config, + ) + return S3EncryptionClient(wrapped_client, config) + + +def _unique_key(prefix): + return prefix + datetime.now().strftime("%Y-%m-%d-%H:%M:%S-%f") + + +# --------------------------------------------------------------------------- +# Low-level multipart API: create → upload_part → complete +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_two_parts_roundtrip(algorithm_suite, commitment_policy): + """Encrypt two 5 MB parts via multipart upload, then decrypt with get_object.""" + key = _unique_key("mpu-2part-") + part1_data = os.urandom(FIVE_MB) + part2_data = os.urandom(1024) # last part can be smaller + expected = part1_data + part2_data + + s3ec = _make_client(algorithm_suite, commitment_policy) + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% CreateMultipartUpload MAY be implemented by the S3EC. + ##% If implemented, CreateMultipartUpload MUST initiate a multipart upload. + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% UploadPart MAY be implemented by the S3EC. + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=part1_data + ) + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% Each part MUST be encrypted in sequence. + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=part2_data, + IsLastPart=True, + ) + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% CompleteMultipartUpload MUST complete the multipart upload. + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% Each part MUST be encrypted using the same cipher instance for each part. + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == expected + + ##= specification/s3-encryption/client.md#optional-api-operations + ##= type=test + ##% UploadPart MUST encrypt each part. + plain_s3 = boto3.client("s3") + raw_response = plain_s3.get_object(Bucket=bucket, Key=key) + raw_content = raw_response["Body"].read() + assert raw_content != expected + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_single_part(algorithm_suite, commitment_policy): + """A multipart upload with a single part should still round-trip correctly.""" + key = _unique_key("mpu-1part-") + data = os.urandom(FIVE_MB) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=1, + Body=data, + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": resp["ETag"]}]}, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_three_parts(algorithm_suite, commitment_policy): + """Three-part multipart upload: 5MB + 5MB + small last part.""" + key = _unique_key("mpu-3part-") + parts_data = [os.urandom(FIVE_MB), os.urandom(FIVE_MB), os.urandom(2048)] + expected = b"".join(parts_data) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + parts = [] + for i, part_data in enumerate(parts_data, start=1): + is_last = i == len(parts_data) + resp = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=i, + Body=part_data, + IsLastPart=is_last, + ) + parts.append({"PartNumber": i, "ETag": resp["ETag"]}) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == expected + + +# --------------------------------------------------------------------------- +# Abort +# --------------------------------------------------------------------------- + + +##= specification/s3-encryption/client.md#optional-api-operations +##= type=test +##% AbortMultipartUpload MAY be implemented by the S3EC. +##% AbortMultipartUpload MUST abort the multipart upload. +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_abort_multipart_upload(algorithm_suite, commitment_policy): + """Aborting a multipart upload should clean up without leaving an object.""" + key = _unique_key("mpu-abort-") + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + # Upload one part then abort + s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=os.urandom(FIVE_MB) + ) + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + + # Object should not exist + plain_s3 = boto3.client("s3") + with pytest.raises(plain_s3.exceptions.NoSuchKey): + plain_s3.get_object(Bucket=bucket, Key=key) + + +# --------------------------------------------------------------------------- +# Encryption context with multipart upload +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_with_encryption_context(algorithm_suite, commitment_policy): + """Multipart upload with encryption context should be usable on decrypt.""" + key = _unique_key("mpu-ec-") + data = os.urandom(FIVE_MB + 1024) + encryption_context = {"project": "s3ec-python", "test": "multipart"} + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload( + Bucket=bucket, Key=key, EncryptionContext=encryption_context + ) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=data[:FIVE_MB] + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + # Decrypt with matching encryption context + response = s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=encryption_context) + assert response["Body"].read() == data + + # Decrypt with wrong encryption context should fail + with pytest.raises(S3EncryptionClientError): + s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext={"wrong": "context"}) + + +# --------------------------------------------------------------------------- +# Streaming decryption of multipart-uploaded objects +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_decrypt_with_delayed_auth(algorithm_suite, commitment_policy): + """Objects uploaded via multipart should be decryptable in delayed-auth mode.""" + key = _unique_key("mpu-delayed-auth-") + data = os.urandom(FIVE_MB + 2048) + + # Encrypt with default (buffered) client + writer = _make_client(algorithm_suite, commitment_policy) + create_resp = writer.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp1 = writer.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=data[:FIVE_MB] + ) + resp2 = writer.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + writer.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + writer.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + # Decrypt with delayed-auth streaming + reader = _make_client(algorithm_suite, commitment_policy, enable_delayed_authentication=True) + response = reader.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(65536): + result += chunk + assert result == data + + +# --------------------------------------------------------------------------- +# Metadata verification +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_metadata_present(algorithm_suite, commitment_policy): + """Multipart-uploaded objects should have encryption metadata set.""" + key = _unique_key("mpu-metadata-") + data = os.urandom(FIVE_MB + 512) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=data[:FIVE_MB] + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + # Verify encryption metadata is present on the object + plain_s3 = boto3.client("s3") + head = plain_s3.head_object(Bucket=bucket, Key=key) + metadata = head.get("Metadata", {}) + + if algorithm_suite == AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: + assert "x-amz-key-v2" in metadata + assert "x-amz-iv" in metadata + assert "x-amz-cek-alg" in metadata + assert "x-amz-wrap-alg" in metadata + else: + assert "x-amz-3" in metadata + assert "x-amz-c" in metadata + assert "x-amz-d" in metadata + assert "x-amz-i" in metadata + assert "x-amz-w" in metadata + + +# --------------------------------------------------------------------------- +# Error cases +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_part_out_of_order_fails(algorithm_suite, commitment_policy): + """Uploading parts out of sequence order must fail (serial cipher requirement).""" + key = _unique_key("mpu-ooo-") + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + # Skip part 1, try to upload part 2 first + with pytest.raises(S3EncryptionClientError): + s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=2, Body=os.urandom(FIVE_MB) + ) + finally: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_part_invalid_upload_id_fails(algorithm_suite, commitment_policy): + """upload_part with an unknown upload ID must fail.""" + key = _unique_key("mpu-bad-id-") + + s3ec = _make_client(algorithm_suite, commitment_policy) + + with pytest.raises(S3EncryptionClientError): + s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId="nonexistent-upload-id", + PartNumber=1, + Body=os.urandom(1024), + ) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_complete_without_parts_fails(algorithm_suite, commitment_policy): + """Completing a multipart upload without marking a final part must fail.""" + key = _unique_key("mpu-no-parts-") + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + with pytest.raises(S3EncryptionClientError, match="final part has not been uploaded"): + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": []}, + ) + finally: + # Clean up in case complete didn't actually fail at the S3 level + try: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# User metadata preservation with multipart +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_user_metadata_preserved(algorithm_suite, commitment_policy): + """User-provided metadata on create_multipart_upload should be preserved.""" + key = _unique_key("mpu-user-meta-") + user_metadata = {"author": "test-user", "version": "2.0"} + data = os.urandom(FIVE_MB + 512) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key, Metadata=user_metadata) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=data[:FIVE_MB] + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + + returned_metadata = response.get("Metadata", {}) + for k, v in user_metadata.items(): + assert returned_metadata.get(k) == v + + +# --------------------------------------------------------------------------- +# Upload part after final part +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_part_after_final_part_fails(algorithm_suite, commitment_policy): + """Uploading a part after IsLastPart=True must fail.""" + key = _unique_key("mpu-after-final-") + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=1, + Body=os.urandom(FIVE_MB), + IsLastPart=True, + ) + + with pytest.raises(S3EncryptionClientError): + s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=os.urandom(1024), + ) + finally: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + + +# --------------------------------------------------------------------------- +# Empty body multipart +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_empty_final_part(algorithm_suite, commitment_policy): + """A multipart upload where the last part has an empty body should still work.""" + key = _unique_key("mpu-empty-last-") + part1_data = os.urandom(FIVE_MB) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=part1_data + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=b"", + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == part1_data + + +# --------------------------------------------------------------------------- +# Many parts (stress sequential cipher) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_many_parts(algorithm_suite, commitment_policy): + """Multipart upload with 10+ parts to stress the sequential cipher.""" + key = _unique_key("mpu-many-parts-") + num_parts = 12 + parts_data = [os.urandom(FIVE_MB) for _ in range(num_parts - 1)] + parts_data.append(os.urandom(1024)) # small last part + expected = b"".join(parts_data) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + parts = [] + for i, part_data in enumerate(parts_data, start=1): + is_last = i == num_parts + resp = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=i, + Body=part_data, + IsLastPart=is_last, + ) + parts.append({"PartNumber": i, "ETag": resp["ETag"]}) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == expected + + +# --------------------------------------------------------------------------- +# Non-ASCII encryption context rejected on multipart +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_non_ascii_encryption_context_rejected(algorithm_suite, commitment_policy): + """Non-ASCII encryption context must be rejected on create_multipart_upload.""" + key = _unique_key("mpu-non-ascii-ec-") + non_ascii_contexts = [ + {"department": "ingeniería"}, + {"部門": "engineering"}, + {"emoji": "🔑"}, + ] + + s3ec = _make_client(algorithm_suite, commitment_policy) + + for ec in non_ascii_contexts: + with pytest.raises(S3EncryptionClientError, match="US-ASCII"): + s3ec.create_multipart_upload(Bucket=bucket, Key=key, EncryptionContext=ec) + + +# --------------------------------------------------------------------------- +# Caller metadata dict not mutated +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_multipart_caller_metadata_not_mutated(algorithm_suite, commitment_policy): + """create_multipart_upload must not mutate the caller's Metadata dict.""" + key = _unique_key("mpu-no-mutate-") + caller_metadata = {"author": "test"} + original_keys = set(caller_metadata.keys()) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key, Metadata=caller_metadata) + upload_id = create_resp["UploadId"] + + # Clean up + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + + assert set(caller_metadata.keys()) == original_keys + + +# --------------------------------------------------------------------------- +# Per-upload lock does not block independent uploads +# --------------------------------------------------------------------------- + + +def test_per_upload_lock_independent_uploads(): + """Per-upload locks must not block concurrent uploads to different objects.""" + s3ec = _make_client( + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + ) + + barrier = threading.Barrier(2) + results = {} + errors = [] + + def do_upload(thread_id): + try: + key = _unique_key(f"mpu-lock-{thread_id}-") + data = os.urandom(FIVE_MB + 512) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + # Sync so both threads call upload_part simultaneously + barrier.wait(timeout=30) + + resp1 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=1, + Body=data[:FIVE_MB], + ) + + barrier.wait(timeout=30) + + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + results[thread_id] = True + + except Exception as e: + errors.append(f"Thread {thread_id}: {e}") + + t1 = threading.Thread(target=do_upload, args=(0,)) + t2 = threading.Thread(target=do_upload, args=(1,)) + t1.start() + t2.start() + t1.join(timeout=120) + t2.join(timeout=120) + + if errors: + raise AssertionError( + "Per-upload lock test failed:\n" + "\n".join(f" - {e}" for e in errors) + ) + assert len(results) == 2 + + +# --------------------------------------------------------------------------- +# Extra kwargs forwarded through upload_part +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_part_forwards_expected_bucket_owner(algorithm_suite, commitment_policy): + """upload_part must forward ExpectedBucketOwner to S3 without error.""" + key = _unique_key("mpu-fwd-kwargs-") + data = os.urandom(FIVE_MB + 512) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + # Get the account ID that owns the bucket (same account we're authed as) + sts = boto3.client("sts") + account_id = sts.get_caller_identity()["Account"] + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=1, + Body=data[:FIVE_MB], + ExpectedBucketOwner=account_id, + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ExpectedBucketOwner=account_id, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + + +# --------------------------------------------------------------------------- +# Complete failure preserves state for retry +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_complete_retryable_after_failure(algorithm_suite, commitment_policy): + """If complete_multipart_upload fails, the upload can be retried.""" + key = _unique_key("mpu-retry-complete-") + data = os.urandom(FIVE_MB + 512) + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + resp1 = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=data[:FIVE_MB] + ) + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=data[FIVE_MB:], + IsLastPart=True, + ) + + parts = [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + + # First attempt: deliberately pass bad parts to trigger S3 error + try: + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": [{"PartNumber": 99, "ETag": '"bogus"'}]}, + ) + except S3EncryptionClientError: + pass # Expected failure + + # Retry with correct parts should succeed (state preserved) + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": parts}, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + + +# --------------------------------------------------------------------------- +# Retry upload_part with same part number +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_part_retry_same_part_number(algorithm_suite, commitment_policy): + """Calling upload_part twice with the same part number returns cached ciphertext and decrypts.""" + key = _unique_key("mpu-retry-part-") + part1_data = os.urandom(FIVE_MB) + part2_data = os.urandom(1024) + expected = part1_data + part2_data + + s3ec = _make_client(algorithm_suite, commitment_policy) + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + # Upload part 1 twice (simulating a retry after transient failure) + resp1_first = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=part1_data + ) + resp1_retry = s3ec.upload_part( + Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=1, Body=part1_data + ) + # Both should produce the same ETag (same ciphertext uploaded) + assert resp1_first["ETag"] == resp1_retry["ETag"] + + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=part2_data, + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1_retry["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == expected diff --git a/test/integration/test_i_s3_encryption_multithreaded.py b/test/integration/test_i_s3_encryption_multithreaded.py index e71a17df..8f713ac5 100644 --- a/test/integration/test_i_s3_encryption_multithreaded.py +++ b/test/integration/test_i_s3_encryption_multithreaded.py @@ -8,6 +8,7 @@ """ import os +import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime @@ -314,3 +315,108 @@ def worker_without_context(thread_id): print("Success! Mixed threads (with and without encryption context) completed successfully.") print("Thread-local storage properly isolated context between threads.") + + +def test_concurrent_multipart_uploads(): + """Test that multiple multipart uploads can run concurrently on the same client. + + Uses a barrier to ensure upload_part calls for different objects are + interleaved, exercising the per-upload cipher isolation under contention. + """ + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + ) + s3ec = S3EncryptionClient(wrapped_client, config) + + num_uploads = 5 + five_mb = 5 * 1024 * 1024 + errors = [] + + # Barrier ensures all threads hit upload_part at roughly the same time + barrier = threading.Barrier(num_uploads) + + def multipart_worker(thread_id): + """Create upload, sync at barrier, then upload parts interleaved with other threads.""" + try: + key = f"concurrent-mpu-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + part1_data = os.urandom(five_mb) + part2_data = os.urandom(1024) + expected = part1_data + part2_data + + create_resp = s3ec.create_multipart_upload(Bucket=bucket, Key=key) + upload_id = create_resp["UploadId"] + + try: + # All threads wait here, then upload_part calls interleave + barrier.wait(timeout=30) + + resp1 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=1, + Body=part1_data, + ) + + # Second barrier to interleave part 2 as well + barrier.wait(timeout=30) + + resp2 = s3ec.upload_part( + Bucket=bucket, + Key=key, + UploadId=upload_id, + PartNumber=2, + Body=part2_data, + IsLastPart=True, + ) + + s3ec.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"PartNumber": 1, "ETag": resp1["ETag"]}, + {"PartNumber": 2, "ETag": resp2["ETag"]}, + ] + }, + ) + except Exception: + s3ec.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise + + response = s3ec.get_object(Bucket=bucket, Key=key) + decrypted = response["Body"].read() + + if decrypted != expected: + return { + "thread_id": thread_id, + "success": False, + "error": f"Data mismatch: expected {len(expected)} bytes, got {len(decrypted)}", + } + + return {"thread_id": thread_id, "success": True} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + with ThreadPoolExecutor(max_workers=num_uploads) as executor: + futures = [executor.submit(multipart_worker, i) for i in range(num_uploads)] + + for future in as_completed(futures): + result = future.result() + if not result["success"]: + errors.append( + f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}" + ) + + if errors: + raise RuntimeError( + f"{len(errors)} concurrent multipart upload(s) failed:\n" + + "\n".join(f" - {e}" for e in errors) + ) diff --git a/test/integration/test_i_s3_encryption_transfer_manager.py b/test/integration/test_i_s3_encryption_transfer_manager.py new file mode 100644 index 00000000..54ebc10b --- /dev/null +++ b/test/integration/test_i_s3_encryption_transfer_manager.py @@ -0,0 +1,396 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Integration tests for S3EncryptionClient with boto3's S3Transfer / upload_file. + +These tests verify that the S3EncryptionClient's upload_file and upload_fileobj +methods correctly handle the multipart threshold boundary, produce objects +decryptable by get_object, and behave correctly with various TransferConfig-like +parameters. + +boto3's native upload_file (via s3transfer) calls create_multipart_upload, +upload_part, and complete_multipart_upload directly on the client it wraps. +Since those calls would bypass encryption if made on the raw S3 client, +the S3EncryptionClient provides its own upload_file / upload_fileobj that +route through the encrypted multipart pipeline. +""" + +import io +import os +import tempfile +from datetime import datetime + +import boto3 +import pytest + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.materials.kms_keyring import KmsKeyring +from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + +ALGORITHM_CONFIGS = [ + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + id="AES_GCM", + ), + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + id="KC_GCM", + ), +] + +ONE_MB = 1024 * 1024 + + +def _make_client(algorithm_suite, commitment_policy, **extra_config): + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + encryption_algorithm=algorithm_suite, + commitment_policy=commitment_policy, + **extra_config, + ) + return S3EncryptionClient(wrapped_client, config) + + +def _unique_key(prefix): + return prefix + datetime.now().strftime("%Y-%m-%d-%H:%M:%S-%f") + + +def _write_temp_file(data): + """Write data to a temp file and return the path.""" + f = tempfile.NamedTemporaryFile(delete=False) + f.write(data) + f.close() + return f.name + + +# --------------------------------------------------------------------------- +# upload_file: below threshold → put_object path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_below_threshold(algorithm_suite, commitment_policy): + """Files smaller than the threshold should use put_object internally.""" + key = _unique_key("tm-below-") + data = os.urandom(1024) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_file: above threshold → multipart path +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_above_default_threshold(algorithm_suite, commitment_policy): + """Files larger than the default 8 MB threshold trigger multipart upload.""" + key = _unique_key("tm-above-default-") + data = os.urandom(9 * ONE_MB) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_file: custom threshold +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_custom_threshold(algorithm_suite, commitment_policy): + """A custom multipart_threshold forces multipart for smaller files.""" + key = _unique_key("tm-custom-thresh-") + # 6 MB file with a 5 MB threshold → multipart + data = os.urandom(6 * ONE_MB) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key, multipart_threshold=5 * ONE_MB) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_file: custom chunksize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_custom_chunksize(algorithm_suite, commitment_policy): + """A custom multipart_chunksize controls part size (more parts).""" + key = _unique_key("tm-custom-chunk-") + # 11 MB file with 5 MB chunks → 3 parts (5 + 5 + 1) + data = os.urandom(11 * ONE_MB) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file( + tmp, + bucket, + key, + multipart_threshold=5 * ONE_MB, + multipart_chunksize=5 * ONE_MB, + ) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_file: exactly at threshold boundary +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_exactly_at_threshold(algorithm_suite, commitment_policy): + """A file exactly equal to the threshold should use put_object (< not <=).""" + key = _unique_key("tm-exact-thresh-") + threshold = 5 * ONE_MB + data = os.urandom(threshold) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key, multipart_threshold=threshold) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_fileobj: basic round-trip +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_fileobj_roundtrip(algorithm_suite, commitment_policy): + """upload_fileobj encrypts a BytesIO via multipart and decrypts correctly.""" + key = _unique_key("tm-fileobj-") + data = os.urandom(9 * ONE_MB) + + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_fileobj(io.BytesIO(data), bucket, key) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + + +# --------------------------------------------------------------------------- +# upload_fileobj: small object (single part) +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_fileobj_small(algorithm_suite, commitment_policy): + """upload_fileobj with a small object still works (single multipart part).""" + key = _unique_key("tm-fileobj-small-") + data = os.urandom(1024) + + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_fileobj(io.BytesIO(data), bucket, key) + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + + +# --------------------------------------------------------------------------- +# upload_file with encryption context +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_with_encryption_context(algorithm_suite, commitment_policy): + """upload_file passes EncryptionContext through to the multipart pipeline.""" + key = _unique_key("tm-ec-") + data = os.urandom(9 * ONE_MB) + ec = {"purpose": "transfer-manager-test"} + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key, EncryptionContext=ec) + assert s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=ec)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_file with user metadata +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_with_user_metadata(algorithm_suite, commitment_policy): + """User-provided Metadata is preserved through upload_file multipart path.""" + key = _unique_key("tm-meta-") + data = os.urandom(9 * ONE_MB) + user_meta = {"author": "test", "version": "3"} + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key, Metadata=user_meta) + + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data + returned = response.get("Metadata", {}) + for k, v in user_meta.items(): + assert returned.get(k) == v + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# Delayed-auth decryption of transfer-manager-uploaded objects +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_decrypt_delayed_auth(algorithm_suite, commitment_policy): + """Objects uploaded via upload_file are decryptable in delayed-auth mode.""" + key = _unique_key("tm-delayed-") + data = os.urandom(9 * ONE_MB) + tmp = _write_temp_file(data) + + try: + writer = _make_client(algorithm_suite, commitment_policy) + writer.upload_file(tmp, bucket, key) + + reader = _make_client( + algorithm_suite, commitment_policy, enable_delayed_authentication=True + ) + response = reader.get_object(Bucket=bucket, Key=key) + result = b"" + while chunk := response["Body"].read(65536): + result += chunk + assert result == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# Parameter validation: zero/negative threshold and chunksize +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_zero_threshold_raises(algorithm_suite, commitment_policy, tmp_path): + """upload_file with multipart_threshold=0 must raise.""" + s3ec = _make_client(algorithm_suite, commitment_policy) + f = tmp_path / "test.bin" + f.write_bytes(os.urandom(1024)) + + with pytest.raises(S3EncryptionClientError, match="multipart_threshold must be a positive"): + s3ec.upload_file(str(f), bucket, "unused-key", multipart_threshold=0) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_zero_chunksize_raises(algorithm_suite, commitment_policy, tmp_path): + """upload_file with multipart_chunksize=0 must raise.""" + s3ec = _make_client(algorithm_suite, commitment_policy) + f = tmp_path / "test.bin" + f.write_bytes(os.urandom(1024)) + + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_file(str(f), bucket, "unused-key", multipart_chunksize=0) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_fileobj_zero_chunksize_raises(algorithm_suite, commitment_policy): + """upload_fileobj with multipart_chunksize=0 must raise.""" + s3ec = _make_client(algorithm_suite, commitment_policy) + + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_fileobj(io.BytesIO(b"data"), bucket, "unused-key", multipart_chunksize=0) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_chunksize_below_5mb_raises(algorithm_suite, commitment_policy, tmp_path): + """upload_file with chunksize below S3's 5 MB minimum must raise.""" + s3ec = _make_client(algorithm_suite, commitment_policy) + f = tmp_path / "test.bin" + f.write_bytes(os.urandom(1024)) + + with pytest.raises(S3EncryptionClientError, match="at least.*5 MB"): + s3ec.upload_file(str(f), bucket, "unused-key", multipart_chunksize=1024 * 1024) + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_fileobj_chunksize_below_5mb_raises(algorithm_suite, commitment_policy): + """upload_fileobj with chunksize below S3's 5 MB minimum must raise.""" + s3ec = _make_client(algorithm_suite, commitment_policy) + + with pytest.raises(S3EncryptionClientError, match="at least.*5 MB"): + s3ec.upload_fileobj( + io.BytesIO(b"data"), bucket, "unused-key", multipart_chunksize=4 * ONE_MB + ) + + +# --------------------------------------------------------------------------- +# S3 parameters forwarded through upload_file to create_multipart_upload +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_file_forwards_content_type(algorithm_suite, commitment_policy, tmp_path): + """upload_file must forward ContentType to the multipart upload.""" + key = _unique_key("tm-content-type-") + data = os.urandom(9 * ONE_MB) + tmp = _write_temp_file(data) + + try: + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_file(tmp, bucket, key, ContentType="application/octet-stream") + + # Verify ContentType was set on the object + plain_s3 = boto3.client("s3") + head = plain_s3.head_object(Bucket=bucket, Key=key) + assert head["ContentType"] == "application/octet-stream" + + # Verify data round-trips + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data + finally: + os.unlink(tmp) + + +# --------------------------------------------------------------------------- +# upload_fileobj does not close the caller's file object +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", ALGORITHM_CONFIGS) +def test_upload_fileobj_does_not_close_caller_stream(algorithm_suite, commitment_policy): + """upload_fileobj must not close the caller's file-like object.""" + key = _unique_key("tm-no-close-") + data = os.urandom(9 * ONE_MB) + buf = io.BytesIO(data) + + s3ec = _make_client(algorithm_suite, commitment_policy) + s3ec.upload_fileobj(buf, bucket, key) + + assert not buf.closed + + # Verify the upload worked + assert s3ec.get_object(Bucket=bucket, Key=key)["Body"].read() == data diff --git a/test/test_multipart.py b/test/test_multipart.py new file mode 100644 index 00000000..ca14149c --- /dev/null +++ b/test/test_multipart.py @@ -0,0 +1,780 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for multipart upload encryption pipeline and client methods.""" + +import io +import os +import threading +from unittest.mock import MagicMock + +import pytest + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.materials.crypto_materials_manager import DefaultCryptoMaterialsManager +from s3_encryption.materials.encrypted_data_key import EncryptedDataKey +from s3_encryption.materials.materials import ( + AlgorithmSuite, + CommitmentPolicy, +) +from s3_encryption.pipelines import MultipartUploadPipeline + + +def _mock_keyring(): + """Create a mock keyring that returns a fixed data key.""" + key = os.urandom(32) + keyring = MagicMock() + + def on_encrypt(mats): + + mats.plaintext_data_key = key + mats.encrypted_data_key = EncryptedDataKey( + key_provider_id=b"S3Keyring", + key_provider_info="kms+context", + encrypted_data_key=os.urandom(64), + ) + return mats + + keyring.on_encrypt = on_encrypt + return keyring, key + + +def _make_client(algorithm_suite=None, commitment_policy=None): + """Create an S3EncryptionClient with a mock keyring and mock S3 client.""" + keyring, _ = _mock_keyring() + algo = algorithm_suite or AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY + policy = commitment_policy or CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT + config = S3EncryptionClientConfig( + keyring=keyring, + encryption_algorithm=algo, + commitment_policy=policy, + ) + mock_s3 = MagicMock() + mock_s3.meta.config.user_agent_extra = "" + mock_s3.meta.events = MagicMock() + return S3EncryptionClient(mock_s3, config) + + +class TestMultipartUploadPipeline: + """Unit tests for the MultipartUploadPipeline cipher logic.""" + + @pytest.fixture( + params=[ + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + ] + ) + def pipeline(self, request): + + keyring, _ = _mock_keyring() + cmm = DefaultCryptoMaterialsManager(keyring) + return MultipartUploadPipeline( + cmm=cmm, + encryption_algorithm=request.param, + ) + + def test_encrypt_single_part(self, pipeline): + data = b"hello world" + ct = pipeline.encrypt_part(1, data, is_last=True) + # Ciphertext should be data + 16-byte GCM tag + assert len(ct) == len(data) + 16 + assert pipeline.has_final_part_been_seen + + def test_encrypt_multiple_parts(self, pipeline): + part1 = pipeline.encrypt_part(1, b"A" * 1024) + part2 = pipeline.encrypt_part(2, b"B" * 512, is_last=True) + assert len(part1) == 1024 + assert len(part2) == 512 + 16 # data + tag on last part + assert pipeline.has_final_part_been_seen + + def test_out_of_order_raises(self, pipeline): + with pytest.raises(S3EncryptionClientError, match="sequence"): + pipeline.encrypt_part(2, b"data") + + def test_part_after_final_raises(self, pipeline): + pipeline.encrypt_part(1, b"data", is_last=True) + with pytest.raises(S3EncryptionClientError, match="after the final part"): + pipeline.encrypt_part(2, b"more data") + + def test_empty_part(self, pipeline): + ct = pipeline.encrypt_part(1, b"", is_last=True) + # Empty data + 16-byte tag + assert len(ct) == 16 + + def test_metadata_present(self, pipeline): + assert pipeline.metadata + # Should have encryption metadata keys + assert len(pipeline.metadata) > 0 + + def test_string_body_converted(self, pipeline): + ct = pipeline.encrypt_part(1, "hello", is_last=True) + assert len(ct) == len(b"hello") + 16 + + +class TestS3EncryptionClientMultipart: + """Unit tests for the S3EncryptionClient multipart methods.""" + + def test_create_multipart_upload(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "test-upload-id", + "Bucket": "bucket", + "Key": "key", + } + + resp = s3ec.create_multipart_upload(Bucket="bucket", Key="key") + assert resp["UploadId"] == "test-upload-id" + s3ec.wrapped_s3_client.create_multipart_upload.assert_called_once() + + def test_upload_part_unknown_upload_id(self): + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="No multipart upload found"): + s3ec.upload_part( + Bucket="bucket", Key="key", UploadId="nonexistent", PartNumber=1, Body=b"data" + ) + + def test_upload_part_encrypts(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-1", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"abc123"'} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + resp = s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-1", + PartNumber=1, + Body=b"data", + IsLastPart=True, + ) + + assert resp["ETag"] == '"abc123"' + # Verify the body passed to S3 is ciphertext (different from plaintext) + call_kwargs = s3ec.wrapped_s3_client.upload_part.call_args[1] + assert call_kwargs["Body"] != b"data" + + def test_complete_without_final_part_raises(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-2", + "Bucket": "bucket", + "Key": "key", + } + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + + with pytest.raises(S3EncryptionClientError, match="final part has not been uploaded"): + s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="uid-2", + MultipartUpload={"Parts": []}, + ) + + def test_complete_after_final_part_succeeds(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-3", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag1"'} + s3ec.wrapped_s3_client.complete_multipart_upload.return_value = {"Location": "s3://..."} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-3", + PartNumber=1, + Body=b"x" * 1024, + IsLastPart=True, + ) + resp = s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="uid-3", + MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": '"etag1"'}]}, + ) + assert resp["Location"] == "s3://..." + + def test_abort_cleans_up_state(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-4", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.abort_multipart_upload.return_value = {} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + s3ec.abort_multipart_upload(Bucket="bucket", Key="key", UploadId="uid-4") + + # After abort, upload_part should fail + with pytest.raises(S3EncryptionClientError, match="No multipart upload found"): + s3ec.upload_part( + Bucket="bucket", Key="key", UploadId="uid-4", PartNumber=1, Body=b"data" + ) + + def test_complete_unknown_upload_id_raises(self): + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="No multipart upload found"): + s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="nonexistent", + MultipartUpload={"Parts": []}, + ) + + def test_create_multipart_with_encryption_context(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-ec", + "Bucket": "bucket", + "Key": "key", + } + + s3ec.create_multipart_upload(Bucket="bucket", Key="key", EncryptionContext={"env": "test"}) + + # EncryptionContext should not be passed to S3 (it's consumed by the pipeline) + call_kwargs = s3ec.wrapped_s3_client.create_multipart_upload.call_args[1] + assert "EncryptionContext" not in call_kwargs + + def test_metadata_merged_on_create(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-meta", + "Bucket": "bucket", + "Key": "key", + } + + s3ec.create_multipart_upload( + Bucket="bucket", Key="key", Metadata={"user-key": "user-value"} + ) + + call_kwargs = s3ec.wrapped_s3_client.create_multipart_upload.call_args[1] + metadata = call_kwargs["Metadata"] + # User metadata preserved + assert metadata["user-key"] == "user-value" + # Encryption metadata also present + assert len(metadata) > 1 + + +class TestUploadFileAndFileobj: + """Unit tests for upload_file and upload_fileobj high-level methods.""" + + def _setup_client(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-file", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag"'} + s3ec.wrapped_s3_client.complete_multipart_upload.return_value = {"Location": "s3://..."} + return s3ec + + def test_upload_file_below_threshold_uses_put_object(self, tmp_path): + s3ec = _make_client() + # Mock put_object on the event-based path + s3ec.wrapped_s3_client.put_object.return_value = {} + + f = tmp_path / "small.bin" + f.write_bytes(b"small data") + + s3ec.upload_file(str(f), "bucket", "key", multipart_threshold=1024 * 1024) + + # put_object should have been called (via the event system) + s3ec.wrapped_s3_client.put_object.assert_called_once() + s3ec.wrapped_s3_client.create_multipart_upload.assert_not_called() + + def test_upload_file_above_threshold_uses_multipart(self, tmp_path): + s3ec = self._setup_client() + + f = tmp_path / "large.bin" + f.write_bytes(os.urandom(2048)) + + s3ec.upload_file( + str(f), "bucket", "key", multipart_threshold=1024, multipart_chunksize=5 * 1024 * 1024 + ) + + s3ec.wrapped_s3_client.create_multipart_upload.assert_called_once() + assert s3ec.wrapped_s3_client.upload_part.call_count >= 1 + s3ec.wrapped_s3_client.complete_multipart_upload.assert_called_once() + + def test_upload_fileobj_uses_multipart(self): + + s3ec = self._setup_client() + data = os.urandom(2048) + + s3ec.upload_fileobj(io.BytesIO(data), "bucket", "key", multipart_chunksize=5 * 1024 * 1024) + + s3ec.wrapped_s3_client.create_multipart_upload.assert_called_once() + assert s3ec.wrapped_s3_client.upload_part.call_count >= 1 + s3ec.wrapped_s3_client.complete_multipart_upload.assert_called_once() + + def test_upload_file_aborts_on_failure(self, tmp_path): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-fail", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.side_effect = Exception("network error") + s3ec.wrapped_s3_client.abort_multipart_upload.return_value = {} + + f = tmp_path / "fail.bin" + f.write_bytes(os.urandom(2048)) + + with pytest.raises(Exception): + s3ec.upload_file( + str(f), + "bucket", + "key", + multipart_threshold=1024, + multipart_chunksize=5 * 1024 * 1024, + ) + + s3ec.wrapped_s3_client.abort_multipart_upload.assert_called_once() + + def test_upload_file_passes_encryption_context(self, tmp_path): + s3ec = self._setup_client() + + f = tmp_path / "ec.bin" + f.write_bytes(os.urandom(2048)) + + s3ec.upload_file( + str(f), + "bucket", + "key", + multipart_threshold=1024, + multipart_chunksize=5 * 1024 * 1024, + EncryptionContext={"env": "test"}, + ) + + # EncryptionContext consumed by create_multipart_upload, not passed to S3 + create_kwargs = s3ec.wrapped_s3_client.create_multipart_upload.call_args[1] + assert "EncryptionContext" not in create_kwargs + + def test_upload_file_passes_user_metadata(self, tmp_path): + s3ec = self._setup_client() + + f = tmp_path / "meta.bin" + f.write_bytes(os.urandom(2048)) + + s3ec.upload_file( + str(f), + "bucket", + "key", + multipart_threshold=1024, + multipart_chunksize=5 * 1024 * 1024, + Metadata={"author": "test"}, + ) + + create_kwargs = s3ec.wrapped_s3_client.create_multipart_upload.call_args[1] + assert create_kwargs["Metadata"]["author"] == "test" + + +class TestMultipartEncryptionContextValidation: + """Unit tests for encryption context validation in create_multipart_upload.""" + + def test_non_ascii_value_rejected(self): + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="US-ASCII"): + s3ec.create_multipart_upload( + Bucket="bucket", Key="key", EncryptionContext={"key": "válue"} + ) + + def test_non_ascii_key_rejected(self): + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="US-ASCII"): + s3ec.create_multipart_upload( + Bucket="bucket", Key="key", EncryptionContext={"clé": "value"} + ) + + def test_emoji_rejected(self): + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="US-ASCII"): + s3ec.create_multipart_upload( + Bucket="bucket", Key="key", EncryptionContext={"emoji": "🔑"} + ) + + def test_ascii_context_accepted(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-ascii", + "Bucket": "bucket", + "Key": "key", + } + # Should not raise + resp = s3ec.create_multipart_upload( + Bucket="bucket", Key="key", EncryptionContext={"env": "test"} + ) + assert resp["UploadId"] == "uid-ascii" + + def test_caller_metadata_dict_not_mutated(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-nomutate", + "Bucket": "bucket", + "Key": "key", + } + + caller_metadata = {"author": "test"} + original_keys = set(caller_metadata.keys()) + + s3ec.create_multipart_upload(Bucket="bucket", Key="key", Metadata=caller_metadata) + + # Caller's dict should not have been modified with encryption metadata + assert set(caller_metadata.keys()) == original_keys + + +class TestMultipartPipelineLock: + """Unit tests verifying per-upload lock prevents concurrent encrypt_part races.""" + + def test_concurrent_encrypt_part_same_pipeline_serialized(self): + """Concurrent calls to encrypt_part on the same pipeline are serialized by the lock.""" + keyring, _ = _mock_keyring() + cmm = DefaultCryptoMaterialsManager(keyring) + pipeline = MultipartUploadPipeline( + cmm=cmm, + encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + ) + + results = {} + errors = [] + barrier = threading.Barrier(2) + + def upload_part_1(): + try: + barrier.wait(timeout=5) + ct = pipeline.encrypt_part(1, b"A" * 1024) + results[1] = ct + except Exception as e: + errors.append(("part1", e)) + + def upload_part_2(): + try: + barrier.wait(timeout=5) + ct = pipeline.encrypt_part(2, b"B" * 512, is_last=True) + results[2] = ct + except Exception as e: + errors.append(("part2", e)) + + t1 = threading.Thread(target=upload_part_1) + t2 = threading.Thread(target=upload_part_2) + t1.start() + t2.start() + t1.join(timeout=10) + t2.join(timeout=10) + + # One of two outcomes is valid: + # 1. Both succeed in order (part 1 acquired lock first) + # 2. Part 2 fails with sequence error (part 2 acquired lock first) + if errors: + # If there's an error, it must be a sequence error on part 2 + assert any("sequence" in str(e).lower() for _, e in errors) + else: + # Both succeeded means part 1 ran first + assert 1 in results and 2 in results + assert len(results[1]) == 1024 + assert len(results[2]) == 512 + 16 + + def test_upload_part_forwards_extra_kwargs(self): + """upload_part must forward extra S3 parameters (e.g. RequestPayer) to the S3 client.""" + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-fwd", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag"'} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-fwd", + PartNumber=1, + Body=b"data", + IsLastPart=True, + RequestPayer="requester", + ExpectedBucketOwner="123456789012", + ) + + call_kwargs = s3ec.wrapped_s3_client.upload_part.call_args[1] + assert call_kwargs["RequestPayer"] == "requester" + assert call_kwargs["ExpectedBucketOwner"] == "123456789012" + # IsLastPart should NOT be forwarded to S3 + assert "IsLastPart" not in call_kwargs + + def test_upload_part_does_not_forward_is_last_part(self): + """IsLastPart is consumed by the client and must not reach S3.""" + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-nolast", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag"'} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-nolast", + PartNumber=1, + Body=b"x", + IsLastPart=True, + ) + + call_kwargs = s3ec.wrapped_s3_client.upload_part.call_args[1] + assert "IsLastPart" not in call_kwargs + + def test_complete_failure_preserves_state_for_retry(self): + """If complete_multipart_upload fails, the upload state is preserved for retry.""" + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-retry", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag1"'} + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-retry", + PartNumber=1, + Body=b"data", + IsLastPart=True, + ) + + # First complete fails + s3ec.wrapped_s3_client.complete_multipart_upload.side_effect = Exception("network timeout") + with pytest.raises(S3EncryptionClientError, match="network timeout"): + s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="uid-retry", + MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": '"etag1"'}]}, + ) + + # Retry should work (state not cleaned up) + s3ec.wrapped_s3_client.complete_multipart_upload.side_effect = None + s3ec.wrapped_s3_client.complete_multipart_upload.return_value = {"Location": "s3://ok"} + resp = s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="uid-retry", + MultipartUpload={"Parts": [{"PartNumber": 1, "ETag": '"etag1"'}]}, + ) + assert resp["Location"] == "s3://ok" + + # After success, state is cleaned up + with pytest.raises(S3EncryptionClientError, match="No multipart upload found"): + s3ec.complete_multipart_upload( + Bucket="bucket", + Key="key", + UploadId="uid-retry", + MultipartUpload={"Parts": []}, + ) + + +class TestUploadFileValidation: + """Unit tests for upload_file/upload_fileobj parameter validation.""" + + def test_zero_threshold_raises(self, tmp_path): + s3ec = _make_client() + f = tmp_path / "test.bin" + f.write_bytes(b"data") + with pytest.raises(S3EncryptionClientError, match="multipart_threshold must be a positive"): + s3ec.upload_file(str(f), "bucket", "key", multipart_threshold=0) + + def test_negative_threshold_raises(self, tmp_path): + s3ec = _make_client() + f = tmp_path / "test.bin" + f.write_bytes(b"data") + with pytest.raises(S3EncryptionClientError, match="multipart_threshold must be a positive"): + s3ec.upload_file(str(f), "bucket", "key", multipart_threshold=-1) + + def test_zero_chunksize_raises(self, tmp_path): + s3ec = _make_client() + f = tmp_path / "test.bin" + f.write_bytes(b"data") + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_file(str(f), "bucket", "key", multipart_chunksize=0) + + def test_negative_chunksize_raises(self, tmp_path): + s3ec = _make_client() + f = tmp_path / "test.bin" + f.write_bytes(b"data") + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_file(str(f), "bucket", "key", multipart_chunksize=-1) + + def test_upload_fileobj_zero_chunksize_raises(self): + + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_fileobj(io.BytesIO(b"data"), "bucket", "key", multipart_chunksize=0) + + def test_upload_fileobj_negative_chunksize_raises(self): + + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="multipart_chunksize must be a positive"): + s3ec.upload_fileobj(io.BytesIO(b"data"), "bucket", "key", multipart_chunksize=-1) + + def test_chunksize_below_5mb_raises(self, tmp_path): + s3ec = _make_client() + f = tmp_path / "test.bin" + f.write_bytes(os.urandom(1024)) + with pytest.raises(S3EncryptionClientError, match="at least.*5 MB"): + s3ec.upload_file(str(f), "bucket", "key", multipart_chunksize=1024 * 1024) + + def test_upload_fileobj_chunksize_below_5mb_raises(self): + + s3ec = _make_client() + with pytest.raises(S3EncryptionClientError, match="at least.*5 MB"): + s3ec.upload_fileobj( + io.BytesIO(b"data"), "bucket", "key", multipart_chunksize=4 * 1024 * 1024 + ) + + def test_upload_file_forwards_s3_params_to_create(self, tmp_path): + """upload_file must forward S3 params like ContentType to create_multipart_upload.""" + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-fwd-create", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag"'} + s3ec.wrapped_s3_client.complete_multipart_upload.return_value = {"Location": "s3://..."} + + f = tmp_path / "typed.bin" + f.write_bytes(os.urandom(2048)) + + s3ec.upload_file( + str(f), + "bucket", + "key", + multipart_threshold=1024, + multipart_chunksize=5 * 1024 * 1024, + ContentType="application/json", + Tagging="env=test", + ) + + create_kwargs = s3ec.wrapped_s3_client.create_multipart_upload.call_args[1] + assert create_kwargs["ContentType"] == "application/json" + assert create_kwargs["Tagging"] == "env=test" + + +class TestFileobjLifecycle: + """Unit tests verifying upload_fileobj does not close the caller's file object.""" + + def _setup_client(self): + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-lifecycle", + "Bucket": "bucket", + "Key": "key", + } + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag"'} + s3ec.wrapped_s3_client.complete_multipart_upload.return_value = {"Location": "s3://..."} + return s3ec + + def test_upload_fileobj_does_not_close_caller_stream(self): + + s3ec = self._setup_client() + buf = io.BytesIO(os.urandom(1024)) + + s3ec.upload_fileobj(buf, "bucket", "key") + + assert not buf.closed + + def test_upload_file_closes_its_own_stream(self, tmp_path): + """upload_file opens the file internally and must close it after.""" + s3ec = self._setup_client() + + f = tmp_path / "owned.bin" + f.write_bytes(os.urandom(2048)) + + s3ec.upload_file( + str(f), "bucket", "key", multipart_threshold=1024, multipart_chunksize=5 * 1024 * 1024 + ) + + # We can't directly check the internal file handle is closed, + # but we can verify the upload completed without error and the + # file is still readable (not locked) + assert f.read_bytes() == f.read_bytes() + + +class TestMultipartPartRetry: + """Unit tests for retrying a failed upload_part call.""" + + @pytest.fixture + def pipeline(self): + + keyring, _ = _mock_keyring() + cmm = DefaultCryptoMaterialsManager(keyring) + return MultipartUploadPipeline( + cmm=cmm, + encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + ) + + def test_retry_same_part_returns_cached_ciphertext(self, pipeline): + ct1 = pipeline.encrypt_part(1, b"hello") + ct2 = pipeline.encrypt_part(1, b"hello") + assert ct1 == ct2 + + def test_retry_last_part_returns_cached_ciphertext(self, pipeline): + pipeline.encrypt_part(1, b"part one") + ct2 = pipeline.encrypt_part(2, b"part two", is_last=True) + ct2_retry = pipeline.encrypt_part(2, b"part two", is_last=True) + assert ct2 == ct2_retry + + def test_retry_does_not_block_next_part(self, pipeline): + pipeline.encrypt_part(1, b"first") + # Retry part 1 + pipeline.encrypt_part(1, b"first") + # Part 2 should still work + ct = pipeline.encrypt_part(2, b"second", is_last=True) + assert len(ct) == len(b"second") + 16 + + def test_client_upload_part_retry_after_s3_failure(self): + """If S3 upload_part fails, retrying the same part number succeeds.""" + s3ec = _make_client() + s3ec.wrapped_s3_client.create_multipart_upload.return_value = { + "UploadId": "uid-retry-part", + "Bucket": "bucket", + "Key": "key", + } + + s3ec.create_multipart_upload(Bucket="bucket", Key="key") + + # First attempt fails at S3 level + s3ec.wrapped_s3_client.upload_part.side_effect = Exception("network error") + with pytest.raises(Exception, match="network error"): + s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-retry-part", + PartNumber=1, + Body=b"data", + ) + + # Retry succeeds + s3ec.wrapped_s3_client.upload_part.side_effect = None + s3ec.wrapped_s3_client.upload_part.return_value = {"ETag": '"etag1"'} + resp = s3ec.upload_part( + Bucket="bucket", + Key="key", + UploadId="uid-retry-part", + PartNumber=1, + Body=b"data", + ) + assert resp["ETag"] == '"etag1"' diff --git a/test/test_stream.py b/test/test_stream.py index 692c8b00..ffa43e1c 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -218,14 +218,20 @@ def test_enter_returns_self(self): def test_wrong_key_raises_error(self): plaintext = b"wrong key test!!" ciphertext, _key, iv = _encrypt_cbc(plaintext) - wrong_key = os.urandom(32) - stream = DecryptingStream( - _make_streaming_body(ciphertext), - _make_cbc_decryptor(wrong_key, iv, len(ciphertext)), - content_length=len(ciphertext), - ) - with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): - stream.read() + # ~1/256 chance random garbage has valid PKCS7 padding, so retry + for _ in range(10): + wrong_key = os.urandom(32) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(wrong_key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + try: + stream.read() + except S3EncryptionClientSecurityError as e: + assert "Failed to decrypt CBC content" in str(e) + return + pytest.fail("Wrong key did not produce CBC decryption error after 10 attempts") def test_empty_ciphertext(self): key = os.urandom(32)