diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index f2af6d7a..9b8772d6 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -34,7 +34,29 @@ @define class S3EncryptionClientConfig: - """Configuration object for the S3 Encryption Client.""" + """Configuration for the S3 Encryption Client. + + Attributes: + keyring: Keyring used for encrypting/decrypting data keys. + encryption_algorithm: Algorithm suite for encryption. Defaults to + ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY (V3 key-committing). + commitment_policy: Key commitment policy for encryption and decryption. + Defaults to REQUIRE_ENCRYPT_REQUIRE_DECRYPT. + enable_legacy_unauthenticated_modes: If True, allow decryption of objects + encrypted with legacy CBC algorithm suites. Defaults to False. + cmm: Crypto materials manager. Defaults to a DefaultCryptoMaterialsManager + wrapping the provided keyring. + instruction_file_suffix: Suffix appended to the S3 object key when + fetching instruction files. Defaults to ".instruction". + enable_delayed_authentication: If True, release plaintext from streams + before GCM tag verification. Defaults to False. Has no effect for + CBC encrypted ciphertext, which is always streamed as there is no + authentication tag. + + Raises: + S3EncryptionClientError: If the encryption algorithm is legacy, or if + the algorithm suite is incompatible with the commitment policy. + """ keyring: AbstractKeyring encryption_algorithm: AlgorithmSuite = field( @@ -60,6 +82,15 @@ class S3EncryptionClientConfig: ##% as its associated object suffixed with ".instruction". instruction_file_suffix: str = field(default=".instruction") + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implication + ##% Delayed Authentication mode MUST be set to false by default. + enable_delayed_authentication: bool = field(default=False) + @cmm.default def _default_cmm_for_keyring(self): return DefaultCryptoMaterialsManager(self.keyring) @@ -197,10 +228,18 @@ def on_get_object_after_call(self, parsed, **kwargs): # The parsed response already has the Body as a StreamingBody # We need to read it, decrypt it, and replace it + # content_length is going to the cipher-text's content length + content_length = parsed.get("ContentLength") + if content_length is None: + obj_key = getattr(self._context, _CTX_KEY, None) + raise S3EncryptionClientError( + f"S3 response is missing ContentLength and is invalid. Key: {obj_key}" + ) # Create a response dict that matches what the pipeline expects response = { "Body": parsed.get("Body"), "Metadata": parsed.get("Metadata", {}), + "ContentLength": content_length, } # Create a pipeline and decrypt the data @@ -212,18 +251,15 @@ def on_get_object_after_call(self, parsed, **kwargs): ) decrypted_data = pipeline.decrypt( response, - encryption_context, + instruction_suffix=self.config.instruction_file_suffix, + enable_delayed_authentication=self.config.enable_delayed_authentication, + encryption_context=encryption_context, bucket=getattr(self._context, _CTX_BUCKET, None), key=getattr(self._context, _CTX_KEY, None), - instruction_suffix=self.config.instruction_file_suffix, ) - # Create a new streaming body with the decrypted data - stream = io.BytesIO(decrypted_data) - streaming_body = StreamingBody(stream, len(decrypted_data)) - - # Replace body with decrypted data - parsed["Body"] = streaming_body + # Replace body with decrypting stream + parsed["Body"] = decrypted_data def process_instruction_file(self, parsed): """Process instruction file in plaintext mode. diff --git a/src/s3_encryption/buffered_decrypt.py b/src/s3_encryption/buffered_decrypt.py new file mode 100644 index 00000000..6c305751 --- /dev/null +++ b/src/s3_encryption/buffered_decrypt.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""One Shot decryption into a buffer.""" + +from io import BytesIO + +from botocore.response import StreamingBody + +from s3_encryption.decryptor import Decryptor + + +def one_shot_decrypt(streaming_body: StreamingBody, decryptor: Decryptor): + """Decrypt a streaming object. + + Args: + streaming_body (object): A streaming object. + decryptor (Decryptor): Decryptor object. + """ + plaintext = decryptor.finalize(streaming_body.read()) + return StreamingBody(BytesIO(plaintext), len(plaintext)) diff --git a/src/s3_encryption/decryptor.py b/src/s3_encryption/decryptor.py new file mode 100644 index 00000000..e3d4eece --- /dev/null +++ b/src/s3_encryption/decryptor.py @@ -0,0 +1,141 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Decryptor abstractions for S3 Encryption Client.""" + +from abc import ABC, abstractmethod + +from attrs import define, field +from cryptography.exceptions import InvalidTag + +from .exceptions import S3EncryptionClientError, S3EncryptionClientSecurityError + + +class Decryptor(ABC): + """Abstract base class for content decryption. + + Implementations own all cipher and padding state, presenting a uniform + streaming interface to the decrypting stream classes. + """ + + @property + @abstractmethod + def content_length(self) -> int: + """Total byte length of the encrypted content (ciphertext + any trailing tag).""" + + @property + @abstractmethod + def amount_read(self) -> int: + """Number of ciphertext bytes consumed so far.""" + + @abstractmethod + def update(self, data: bytes) -> bytes: + """Process a chunk of ciphertext, returning any available plaintext.""" + + @abstractmethod + def finalize(self, data: bytes) -> bytes: + """Process the final chunk of ciphertext and finalize decryption.""" + + +@define +class AesCbcDecryptor(Decryptor): + """AES-CBC decryptor that owns both the cipher and PKCS7 unpadder. + + Args: + decryptor: A cryptography CBC cipher decryptor context. + unpadder: A cryptography PKCS7 unpadding context. + content_length: Total byte length of the CBC ciphertext. + """ + + _decryptor: object = field() + _unpadder: object = field() + _content_length: int = field() + _amount_read: int = field(init=False, default=0) + + @property + def content_length(self) -> int: # noqa: D102 + return self._content_length + + @property + def amount_read(self) -> int: # noqa: D102 + return self._amount_read + + def update(self, data: bytes) -> bytes: + """Decrypt a chunk and unpad incrementally.""" + self._amount_read += len(data) + plaintext = self._decryptor.update(data) + return self._unpadder.update(plaintext) + + def finalize(self, data: bytes) -> bytes: + """Finalize CBC decryption and flush the unpadder.""" + try: + self._amount_read += len(data) + plaintext = self._decryptor.update(data) if data else b"" + plaintext += self._decryptor.finalize() + return self._unpadder.update(plaintext) + self._unpadder.finalize() + except Exception as e: + raise S3EncryptionClientSecurityError(f"Failed to decrypt CBC content: {e}") from e + + +@define +class AesGcmDecryptor(Decryptor): + """AES-GCM decryptor that handles trailing auth tag verification. + + Args: + decryptor: A cryptography GCM cipher decryptor context. + tag_length: Length of the GCM authentication tag in bytes. + content_length: Total byte length of the encrypted content (ciphertext + tag). + """ + + _decryptor: object = field() + _tag_length: int = field() + _content_length: int = field() + _amount_read: int = field(init=False, default=0) + _tail: bytes = field(init=False, default=b"") + + @property + def content_length(self) -> int: # noqa: D102 + return self._content_length + + @property + def amount_read(self) -> int: # noqa: D102 + return self._amount_read + + @property + def tag_length(self) -> int: + """Length of the GCM authentication tag in bytes.""" + return self._tag_length + + def update(self, data: bytes) -> bytes: + """Decrypt a chunk, holding back the last tag_length bytes. + + A rolling _tail buffer always retains the last tag_length bytes + so the GCM tag is never passed to the cipher's update(). + """ + self._amount_read += len(data) + buf = self._tail + data + if len(buf) <= self._tag_length: + self._tail = buf + return b"" + self._tail = buf[-self._tag_length :] + return self._decryptor.update(buf[: -self._tag_length]) + + def finalize(self, data: bytes) -> bytes: + """Finalize decryption using the buffered tag.""" + try: + self._amount_read += len(data) + buf = self._tail + data + if len(buf) < self._tag_length: + raise S3EncryptionClientError( + f"Incomplete GCM data: expected at least {self._tag_length} " + f"tag bytes, got {len(buf)} total remaining bytes." + ) + tag = buf[-self._tag_length :] + ciphertext = buf[: -self._tag_length] + plaintext = self._decryptor.update(ciphertext) if ciphertext else b"" + return plaintext + self._decryptor.finalize_with_tag(tag) + except S3EncryptionClientError: + raise + except InvalidTag as e: + raise S3EncryptionClientSecurityError(f"Failed to decrypt Object: {e}") from e + except Exception as e: + raise S3EncryptionClientError(f"Failed to decrypt Object: {e}") from e diff --git a/src/s3_encryption/materials/materials.py b/src/s3_encryption/materials/materials.py index f2e8fd4f..80f682f0 100644 --- a/src/s3_encryption/materials/materials.py +++ b/src/s3_encryption/materials/materials.py @@ -172,6 +172,26 @@ def kc_gcm_iv(self) -> bytes: ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. return b"\x01" * self.cipher_iv_length_bytes + @property + def cipher_block_size_bits(self) -> int: + """Block size of the cipher in bits.""" + return self._cipher_block_size_bits + + @property + def cipher_block_size_bytes(self) -> int: + """Block size of the cipher in bytes.""" + return self._cipher_block_size_bits // 8 + + @property + def cipher_tag_length_bits(self) -> int: + """Authentication tag length of the cipher in bits.""" + return self._cipher_tag_length_bits + + @property + def cipher_tag_length_bytes(self) -> int: + """Authentication tag length of the cipher in bytes.""" + return self._cipher_tag_length_bits // 8 + class CommitmentPolicy(Enum): """Commitment policies controlling key-commitment behavior.""" diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index d0e9ba79..5561047a 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -11,11 +11,14 @@ import os from attrs import define, field +from botocore.response import StreamingBody from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 -from .exceptions import S3EncryptionClientError, S3EncryptionClientSecurityError +from .buffered_decrypt import one_shot_decrypt +from .decryptor import AesCbcDecryptor, AesGcmDecryptor +from .exceptions import S3EncryptionClientError from .instruction_file import fetch_instruction_file from .key_derivation import derive_keys, verify_commitment from .materials.crypto_materials_manager import AbstractCryptoMaterialsManager @@ -27,6 +30,7 @@ EncryptionMaterials, ) from .metadata import ObjectMetadata +from .stream import DecryptingStream @define @@ -222,26 +226,28 @@ def _determine_algorithm_suite(self, metadata) -> AlgorithmSuite: def decrypt( self, response, + instruction_suffix, + enable_delayed_authentication, encryption_context=None, bucket=None, key=None, - instruction_suffix=".instruction", - ): + ) -> StreamingBody: """Decrypt the data after it is retrieved from S3. Args: response (dict): The response from S3 containing the encrypted data and metadata + instruction_suffix (str): suffix for instruction file + enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. encryption_context (dict, optional): Additional context for decryption bucket (str, optional): S3 bucket name (required for instruction file) key (str, optional): S3 object key (required for instruction file) - instruction_suffix(str, optional): suffix for instruction file; defaults to ".instruction". Returns: - bytes: The decrypted data + A botocore.response.StreamingBody of plain-text """ # Convert the metadata dictionary to an ObjectMetadata instance - # TODO: Stream + Buffered Decryption - encrypted_data = response.get("Body").read() + streaming_body: StreamingBody = response.get("Body") + content_length = response.get("ContentLength") encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) @@ -254,10 +260,12 @@ def decrypt( if self.s3_client is None: raise S3EncryptionClientError("s3_client required to fetch instruction file") - # TODO: we should validate that these parameters must be None - # when not in instruction file mode. if bucket is None or key is None: raise S3EncryptionClientError("Bucket and key required to fetch instruction file") + if instruction_suffix is None: + raise S3EncryptionClientError( + "instruction_suffix required to fetch instruction file" + ) instruction_key = key + instruction_suffix instruction_metadata = fetch_instruction_file(self.s3_client, bucket, instruction_key) @@ -380,24 +388,145 @@ def decrypt( ##= type=implementation ##% When the commitment policy is REQUIRE_ENCRYPT_ALLOW_DECRYPT, the S3EC MUST allow decryption using algorithm suites which do not support key commitment. - # Perform decryption based on algorithm suite + if enable_delayed_authentication is None: + raise S3EncryptionClientError("enable_delayed_authentication must be explicitly set") + + # Build decryptor and return streaming wrapper based on algorithm suite match dec_materials.algorithm_suite: case AlgorithmSuite.ALG_AES_256_CBC_IV16_NO_KDF: - return self._decrypt_cbc_content(dec_materials, encrypted_data) + return self._decrypt_cbc_streaming(dec_materials, streaming_body, content_length) case AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: - ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf - ##= type=implementation - ##% The client MUST NOT provide any AAD when encrypting with - ##% ALG_AES_256_GCM_IV12_TAG16_NO_KDF. - aesgcm = AESGCM(dec_materials.plaintext_data_key) - return aesgcm.decrypt( - nonce=dec_materials.iv, data=encrypted_data, associated_data=None + return self._decrypt_gcm_streaming( + dec_materials, + streaming_body, + enable_delayed_authentication, + content_length, ) case AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY: - return self._decrypt_kc_gcm_content(dec_materials, encrypted_data, metadata) + return self._decrypt_kc_gcm_streaming( + dec_materials, + metadata, + streaming_body, + enable_delayed_authentication, + content_length, + ) case _: raise S3EncryptionClientError("Unknown algorithm suite!") + @staticmethod + def _decrypt_cbc_streaming(dec_materials, streaming_body, content_length): + """Decrypt content encrypted with ALG_AES_256_CBC_IV16_NO_KDF. + + CBC is always streamed (no buffered mode) since it has no auth tag. + """ + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and + ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, + ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or + ##% PKCS7Padding compatible padding for a 16-byte block cipher + ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If the cipher object cannot be created as described above, + ##% Decryption MUST fail. + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% The error SHOULD detail why the cipher could not be initialized + ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). + cipher = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), + modes.CBC(dec_materials.iv), + ) + # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) + unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() + decryptor = AesCbcDecryptor(cipher.decryptor(), unpadder, content_length=content_length) + return DecryptingStream(streaming_body, decryptor, content_length=content_length) + + @staticmethod + def _decrypt_gcm_streaming( + dec_materials, streaming_body, enable_delayed_authentication, content_length + ): + """Decrypt content encrypted with ALG_AES_256_GCM_IV12_TAG16_NO_KDF.""" + ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf + ##= type=implementation + ##% The client MUST NOT provide any AAD when encrypting with + ##% ALG_AES_256_GCM_IV12_TAG16_NO_KDF. + cipher = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), modes.GCM(dec_materials.iv) + ) + decryptor = AesGcmDecryptor( + cipher.decryptor(), + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, + content_length=content_length, + ) + if enable_delayed_authentication: + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. + return DecryptingStream(streaming_body, decryptor, content_length=content_length) + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + return one_shot_decrypt(streaming_body, decryptor) + + def _decrypt_kc_gcm_streaming( + self, dec_materials, metadata, streaming_body, enable_delayed_authentication, content_length + ): + """Decrypt content encrypted with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. + + Performs HKDF key derivation, key commitment verification, then returns + a streaming decryptor. + """ + message_id = base64.b64decode(metadata.message_id_v3) + stored_commitment = base64.b64decode(metadata.key_commitment_v3) + + ##= specification/s3-encryption/decryption.md#decrypting-with-commitment + ##= type=implementation + ##% When using an algorithm suite which supports key commitment, the client MUST verify + ##% that the [derived key commitment](./key-derivation.md#hkdf-operation) contains the + ##% same bytes as the stored key commitment retrieved from the stored object's metadata. + ##= specification/s3-encryption/decryption.md#decrypting-with-commitment + ##= type=implementation + ##% When using an algorithm suite which supports key commitment, the client MUST verify the key commitment values match before deriving + ##% the [derived encryption key](./key-derivation.md#hkdf-operation). + derived_encryption_key, derived_commitment = derive_keys( + dec_materials.plaintext_data_key, message_id, dec_materials.algorithm_suite + ) + verify_commitment(stored_commitment, derived_commitment) + + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% When encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The IV's total length MUST match the IV length defined by the algorithm suite. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The client MUST initialize the cipher, or call an AES-GCM encryption API, with the derived encryption key, an IV containing only bytes with the value 0x01, + ##% and the tag length defined in the Algorithm Suite when encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The client MUST set the AAD to the Algorithm Suite ID represented as bytes. + cipher = Cipher( + algorithms.AES(derived_encryption_key), + modes.GCM(dec_materials.algorithm_suite.kc_gcm_iv), + ) + cipher_decryptor = cipher.decryptor() + cipher_decryptor.authenticate_additional_data(dec_materials.algorithm_suite.suite_id_bytes) + decryptor = AesGcmDecryptor( + cipher_decryptor, + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, + content_length=content_length, + ) + if enable_delayed_authentication: + return DecryptingStream(streaming_body, decryptor, content_length=content_length) + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + return one_shot_decrypt(streaming_body, decryptor) + def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: """Prepare V2 decryption materials.""" return self._decrypt_v1_v2( @@ -440,40 +569,6 @@ def _decrypt_v1_v2( return self.cmm.decrypt_materials(dec_materials) - def _decrypt_cbc_content(self, dec_materials, encrypted_data): - """Decrypt content encrypted with ALG_AES_256_CBC_IV16_NO_KDF.""" - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and - ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, - ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or - ##% PKCS7Padding compatible padding for a 16-byte block cipher - ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If the cipher object cannot be created as described above, - ##% Decryption MUST fail. - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% The error SHOULD detail why the cipher could not be initialized - ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). - try: - cipher = Cipher( - algorithms.AES(dec_materials.plaintext_data_key), - modes.CBC(dec_materials.iv), - ) - decryptor = cipher.decryptor() - padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize() - - # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) - unpadder = PKCS7(128).unpadder() - return unpadder.update(padded_plaintext) + unpadder.finalize() - except Exception as e: - raise S3EncryptionClientSecurityError( - f"Failed to decrypt CBC content: {e}. " - "Ensure the underlying crypto provider supports AES/CBC/PKCS7Padding." - ) from e - ##= specification/s3-encryption/data-format/content-metadata.md#v3-only ##% The V3 format uses compression here such that each wrapping algorithm is represented by a two digit string. ##= specification/s3-encryption/data-format/content-metadata.md#v3-only @@ -527,50 +622,3 @@ def _decrypt_v3(self, metadata, encryption_context) -> DecryptionMaterials: ) return self.cmm.decrypt_materials(dec_materials) - - def _decrypt_kc_gcm_content(self, dec_materials, encrypted_data, metadata): - """Decrypt content encrypted with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. - - Performs HKDF key derivation, key commitment verification, and AES-GCM decryption. - """ - message_id = base64.b64decode(metadata.message_id_v3) - stored_commitment = base64.b64decode(metadata.key_commitment_v3) - - ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-hkdf-sha512-commit-key - ##= type=implementation - ##% The client MUST use HKDF to derive the key commitment value and the derived encrypting key as described in [Key Derivation](key-derivation.md). - derived_encryption_key, derived_commitment = derive_keys( - dec_materials.plaintext_data_key, message_id, dec_materials.algorithm_suite - ) - - ##= specification/s3-encryption/decryption.md#decrypting-with-commitment - ##= type=implementation - ##% When using an algorithm suite which supports key commitment, the client MUST verify - ##% that the [derived key commitment](./key-derivation.md#hkdf-operation) contains the - ##% same bytes as the stored key commitment retrieved from the stored object's metadata. - ##= specification/s3-encryption/decryption.md#decrypting-with-commitment - ##= type=implementation - ##% When using an algorithm suite which supports key commitment, the client MUST verify the key commitment values match before deriving - ##% the [derived encryption key](./key-derivation.md#hkdf-operation). - verify_commitment(stored_commitment, derived_commitment) - - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% When encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, - ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The IV's total length MUST match the IV length defined by the algorithm suite. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The client MUST initialize the cipher, or call an AES-GCM encryption API, with the derived encryption key, an IV containing only bytes with the value 0x01, - ##% and the tag length defined in the Algorithm Suite when encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The client MUST set the AAD to the Algorithm Suite ID represented as bytes. - aesgcm = AESGCM(derived_encryption_key) - return aesgcm.decrypt( - nonce=dec_materials.algorithm_suite.kc_gcm_iv, - data=encrypted_data, - associated_data=dec_materials.algorithm_suite.suite_id_bytes, - ) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py new file mode 100644 index 00000000..a4e85a74 --- /dev/null +++ b/src/s3_encryption/stream.py @@ -0,0 +1,189 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Streaming decryption support for S3 Encryption Client.""" + +import io + +from attrs import define, field +from botocore.exceptions import IncompleteReadError +from botocore.response import StreamingBody + +from .decryptor import Decryptor + +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% The S3EC SHOULD accept a configurable buffer size which refers to the maximum ciphertext length in bytes to store in memory when Delayed Authentication mode is disabled. +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% If Delayed Authentication mode is enabled, and the buffer size has been set to a value other than its default, the S3EC MUST throw an exception. +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% If Delayed Authentication mode is disabled, and no buffer size is provided, the S3EC MUST set the buffer size to a reasonable default. + + +_DEFAULT_CHUNK_SIZE = 1024 + + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=implementation +##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. +# slots=False because StreamingBody extends IOBase which already has __weakref__. +@define(slots=False) +class DecryptingStream(StreamingBody): + """A stream that releases plaintext incrementally before full authentication. + + Extends botocore's StreamingBody so it can be used as a drop-in replacement + for parsed["Body"]. All StreamingBody methods are explicitly overridden. + """ + + # This stream is ALMOST cipher-agnostic — the Decryptor handles ALMOST all algorithm details. + # Ciphertext is fed through decryptor.update() incrementally, and + # decryptor.finalize() is called with any trailing data when the body is exhausted. + # + # ALMOST :: The AES-GCM tag is problematic when combined with iterators that can split + # the tag over two reads. To accommodate this, read() has a while loop with 3 return conditions. + # See inline comments of read for more details. + + _body: object = field() + _decryptor: Decryptor = field() + _content_length: int = field() + _bytes_consumed: int = field(init=False, default=0) + _finalized: bool = field(init=False, default=False) + + def __attrs_post_init__(self): # noqa: D105 + super().__init__(io.BytesIO(), content_length=self._content_length) + + def readable(self): # noqa: D102 + return not self._finalized + + def read(self, amt=None): + """Read and decrypt ciphertext, releasing plaintext incrementally. + + Args: + amt: Number of bytes to read. If None, reads all remaining data. + + Returns: + bytes: Decrypted plaintext bytes. + """ + if self._finalized: + return b"" + + # Loop until the decryptor produces non-empty plaintext. + # The GCM decryptor's tail buffer may absorb small reads entirely + # (returning b"" from update) while it holds back the trailing auth + # tag. Looping prevents callers from seeing spurious empty bytes + # mid-stream, which would break `while chunk := stream.read(amt)`. + result = b"" + while not result: + remaining = self._content_length - self._bytes_consumed + if remaining <= 0: + # All content_length bytes consumed — finalize with no extra data. + return self._finalize(b"") + + # Never read past content_length; cap at amt if provided. + to_read = remaining if amt is None else min(amt, remaining) + raw = self._body.read(to_read) + + if not raw: + # Underlying stream exhausted early — finalize with what we have. + return self._finalize(b"") + + self._bytes_consumed += len(raw) + remaining = self._content_length - self._bytes_consumed + + if remaining <= 0: + # This is the last chunk — pass it to finalize so the decryptor + # can split off the GCM tag (or flush CBC padding) and verify. + return self._finalize(raw) + + # Feed ciphertext to the decryptor. For GCM, the tail buffer holds + # back the last tag_length bytes, so update() may return b"" if + # the chunk was entirely absorbed into the buffer. + result = self._decryptor.update(raw) + return result + + def _finalize(self, trailing_data): + """Finalize decryption with any trailing data.""" + if self._finalized: + return b"" + self._finalized = True + plaintext = self._decryptor.finalize(trailing_data) + self._verify_content_length() + return plaintext + + def readinto(self, b): + """Read bytes into a pre-allocated, writable bytes-like object b. + + Returns the number of bytes decrypted. + Note: CBC Padding and GCM tag will be removed, so bytes read MAYBE greater than bytes decrypted. + """ + data = self.read(len(b)) + n = len(data) + b[:n] = data + return n + + def readlines(self): # noqa: D102 + return self.read().splitlines(True) + + def __iter__(self): + """Return an iterator to yield 1k chunks from the decryption stream.""" + return self + + def __next__(self): + """Return the next 1k chunk from the decryption stream.""" + chunk = self.read(_DEFAULT_CHUNK_SIZE) + if chunk: + return chunk + raise StopIteration() + + next = __next__ + + def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): + """Return an iterator to yield lines from the decryption stream. + + This is achieved by reading chunk of bytes (of size chunk_size) at a + time from the chipher-text stream, decrypting them, and then yielding lines from there. + """ + pending = b"" + for chunk in self.iter_chunks(chunk_size): + lines = (pending + chunk).splitlines(True) + for line in lines[:-1]: + yield line.splitlines(keepends)[0] + pending = lines[-1] + if pending: + yield pending.splitlines(keepends)[0] + + def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): + """Return an iterator to yield chunks of chunk_size bytes from the raw stream.""" + while True: + chunk = self.read(chunk_size) + if chunk == b"": + break + yield chunk + + def _verify_content_length(self): + """Verify that the decryptor consumed exactly content_length bytes.""" + if self._decryptor.content_length is not None and not ( + self._decryptor.amount_read == self._content_length + ): + raise IncompleteReadError( + actual_bytes=self._decryptor.amount_read, + expected_bytes=self._decryptor.content_length, + ) + + def tell(self): # noqa: D102 + return self._bytes_consumed + + def close(self): + """Close the underlying cipher-text stream.""" + if hasattr(self._body, "close"): + self._body.close() + + def __enter__(self): # noqa: D105 + return self + + def __exit__(self, *args): # noqa: D105 + self.close() diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 15133c05..36f826bd 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -252,3 +252,26 @@ def test_put_object_uses_configured_algorithm(algorithm_suite, commitment_policy meta_key, expected_value = _EXPECTED_ALGORITHM_METADATA[algorithm_suite] assert meta_key in metadata, f"Expected metadata key '{meta_key}' not found in {metadata}" assert metadata[meta_key] == expected_value + + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=test +##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. +@pytest.mark.parametrize("enable_delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_delayed_authentication_mode(enable_delayed_auth): + """S3EC MUST support enabling and disabling delayed authentication.""" + key = _unique_key("delayed-auth-mode-") + data = b"test delayed authentication mode" + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + enable_delayed_authentication=enable_delayed_auth, + ) + s3ec = S3EncryptionClient(wrapped_client, config) + + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index f4f70704..6c93d832 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -24,6 +24,8 @@ "v2_instruction_file": "static-v2-instruction-file-from-java-v4", "v3_instruction_file": "static-v3-instruction-file-from-java-v4", "negative_v2_instruction_file": "NEGATIVE-static-v2-instruction-file-test-from-java-v4", + "large_v2_instruction_file": "static-large-v2-instruction-file-from-java-v4-52428800", + "large_v3_instruction_file": "static-large-v3-instruction-file-from-java-v4-52428800", } @@ -53,7 +55,8 @@ def test_decrypt_v1_instruction_file(): print("Success! V1 instruction file decryption completed.") -def test_decrypt_v2_instruction_file(): +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_decrypt_v2_instruction_file(delayed_auth): """Test decrypting V2 object with instruction file. V2 format uses ALG_AES_256_GCM_IV12_TAG16_NO_KDF (no key commitment). @@ -68,6 +71,7 @@ def test_decrypt_v2_instruction_file(): keyring, encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + enable_delayed_authentication=delayed_auth, ) s3ec = S3EncryptionClient(wrapped_client, config) @@ -145,7 +149,8 @@ def test_decrypt_v3_instruction_file_custom_suffix(): print("Success! V3 custom suffix instruction file decryption completed.") -def test_decrypt_v2_instruction_file_custom_suffix(): +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_decrypt_v2_instruction_file_custom_suffix(delayed_auth): """Test decrypting V2 object with a custom instruction file suffix.""" key = TEST_OBJECTS["v2_instruction_file"] @@ -157,6 +162,7 @@ def test_decrypt_v2_instruction_file_custom_suffix(): encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, instruction_file_suffix=".custom-suffix-instruction", commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + enable_delayed_authentication=delayed_auth, ) s3ec = S3EncryptionClient(wrapped_client, config) @@ -165,3 +171,48 @@ def test_decrypt_v2_instruction_file_custom_suffix(): assert output == "static-v2-instruction-file-from-java-v4" print("Success! V2 custom suffix instruction file decryption completed.") + + +LARGE_FILE_SIZE = 52428800 # 50 MB + + +def test_decrypt_large_v2_instruction_file_delayed_auth(): + """Test streaming decryption of a 50 MB V2 object with delayed authentication.""" + key = TEST_OBJECTS["large_v2_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + enable_delayed_authentication=True, + encryption_algorithm=AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + ) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == LARGE_FILE_SIZE + + +@pytest.mark.skip(reason="V3 large file not yet written to static bucket") +def test_decrypt_large_v3_instruction_file_delayed_auth(): + """Test streaming decryption of a 50 MB V3 object with delayed authentication.""" + key = TEST_OBJECTS["large_v3_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == LARGE_FILE_SIZE diff --git a/test/integration/test_i_s3_encryption_streaming.py b/test/integration/test_i_s3_encryption_streaming.py new file mode 100644 index 00000000..530959bb --- /dev/null +++ b/test/integration/test_i_s3_encryption_streaming.py @@ -0,0 +1,193 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Integration tests for streaming decryption modes (buffered vs delayed-auth). + +These tests verify that BufferedDecryptingGCMStream and DelayedAuthGCMDecryptingStream +produce correct plaintext for real S3 round-trips across algorithm suites. +""" + +import os +from datetime import datetime + +import boto3 +import pytest +from botocore.response import StreamingBody + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.materials.kms_keyring import KmsKeyring +from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy +from s3_encryption.stream import DecryptingStream + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + +GCM_CONFIGS = [ + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + id="AES_GCM", + ), + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + id="KC_GCM", + ), +] + + +def _make_client(algorithm_suite, commitment_policy, delayed_auth): + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + encryption_algorithm=algorithm_suite, + commitment_policy=commitment_policy, + enable_delayed_authentication=delayed_auth, + ) + return S3EncryptionClient(wrapped_client, config) + + +def _unique_key(prefix): + return prefix + datetime.now().strftime("%Y-%m-%d-%H:%M:%S-%f") + + +# --------------------------------------------------------------------------- +# Buffered mode: verifies tag before releasing plaintext +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_roundtrip(algorithm_suite, commitment_policy): + """Buffered mode decrypts correctly for a simple round-trip.""" + key = _unique_key("buffered-rt-") + data = b"buffered mode round trip test data" + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + body = response["Body"] + assert isinstance(body, StreamingBody) + assert body.read() == data + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_partial_reads(algorithm_suite, commitment_policy): + """Buffered mode supports partial read(amt) calls.""" + key = _unique_key("buffered-partial-") + data = os.urandom(1024) + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(100): + result += chunk + assert result == data + + +# --------------------------------------------------------------------------- +# Delayed-auth mode: releases plaintext before tag verification +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_roundtrip(algorithm_suite, commitment_policy): + """Delayed-auth mode decrypts correctly for a simple round-trip.""" + key = _unique_key("delayed-rt-") + data = b"delayed auth round trip test data" + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + body = response["Body"] + assert isinstance(body, DecryptingStream) + assert body.read() == data + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_chunked_reads(algorithm_suite, commitment_policy): + """Delayed-auth mode supports chunked streaming reads.""" + key = _unique_key("delayed-chunked-") + data = os.urandom(4096) + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(256): + result += chunk + assert result == data + + +# --------------------------------------------------------------------------- +# Both modes produce identical plaintext +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_and_delayed_produce_same_plaintext(algorithm_suite, commitment_policy): + """Both streaming modes must produce identical plaintext for the same object.""" + key = _unique_key("same-plaintext-") + data = os.urandom(2048) + + # Encrypt once + writer = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + writer.put_object(Bucket=bucket, Key=key, Body=data) + + # Decrypt with buffered + buffered = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + resp_buf = buffered.get_object(Bucket=bucket, Key=key) + plaintext_buf = resp_buf["Body"].read() + + # Decrypt with delayed-auth + delayed = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + resp_del = delayed.get_object(Bucket=bucket, Key=key) + plaintext_del = resp_del["Body"].read() + + assert plaintext_buf == plaintext_del == data + + +# --------------------------------------------------------------------------- +# Empty body +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_empty_body_roundtrip(algorithm_suite, commitment_policy, delayed_auth): + """Both modes handle empty plaintext correctly.""" + key = _unique_key("empty-stream-") + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=delayed_auth) + s3ec.put_object(Bucket=bucket, Key=key, Body=b"") + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == b"" + + +# --------------------------------------------------------------------------- +# Large object streaming +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_large_object(algorithm_suite, commitment_policy): + """Delayed-auth streams a 1 MB object correctly via chunked reads.""" + key = _unique_key("delayed-large-") + data = os.urandom(1024 * 1024) # 1 MB + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(65536): + result += chunk + assert result == data diff --git a/test/test_decryption.py b/test/test_decryption.py index 6d22d439..7f9a51d1 100644 --- a/test/test_decryption.py +++ b/test/test_decryption.py @@ -74,7 +74,7 @@ def _v2_gcm_metadata(): def _response(metadata, body=b"ciphertext"): - return {"Body": BytesIO(body), "Metadata": metadata} + return {"Body": BytesIO(body), "Metadata": metadata, "ContentLength": len(body)} # --------------------------------------------------------------------------- @@ -106,7 +106,9 @@ def test_cbc_object_rejected_when_legacy_disabled(self): ) with pytest.raises(S3EncryptionClientError, match="ALG_AES_256_CBC_IV16_NO_KDF"): - pipeline.decrypt(_response(_v1_cbc_metadata())) + pipeline.decrypt( + _response(_v1_cbc_metadata()), ".instruction", enable_delayed_authentication=False + ) ##= specification/s3-encryption/decryption.md#cbc-decryption ##= type=test @@ -148,8 +150,10 @@ def test_cbc_decryption_succeeds_when_legacy_enabled(self): keyring_return=dec_mats, ) - result = pipeline.decrypt(_response(metadata, ciphertext)) - assert result == plaintext + result = pipeline.decrypt( + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False + ) + assert result.read() == plaintext ##= specification/s3-encryption/decryption.md#cbc-decryption ##= type=test @@ -193,7 +197,9 @@ def test_cbc_decryption_fails_with_wrong_key(self): ) with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): - pipeline.decrypt(_response(metadata, ciphertext)) + pipeline.decrypt( + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False + ).read() # --------------------------------------------------------------------------- @@ -289,7 +295,11 @@ def test_commitment_verified_before_content_decryption(self): with pytest.raises( S3EncryptionClientSecurityError, match="Key commitment verification failed" ): - pipeline.decrypt(_response(metadata, b"fake-ciphertext")) + pipeline.decrypt( + _response(metadata, b"fake-ciphertext"), + ".instruction", + enable_delayed_authentication=False, + ) # --------------------------------------------------------------------------- @@ -322,7 +332,9 @@ def test_require_decrypt_rejects_non_committing_suite(self): ) with pytest.raises(S3EncryptionClientError, match="cannot decrypt non-key-committing"): - pipeline.decrypt(_response(_v2_gcm_metadata())) + pipeline.decrypt( + _response(_v2_gcm_metadata()), ".instruction", enable_delayed_authentication=False + ) def test_allow_decrypt_accepts_non_committing_suite(self): """REQUIRE_ENCRYPT_ALLOW_DECRYPT MUST allow non-committing algorithm suites.""" @@ -352,8 +364,10 @@ def test_allow_decrypt_accepts_non_committing_suite(self): keyring_return=dec_mats, ) - result = pipeline.decrypt(_response(metadata, ciphertext)) - assert result == plaintext + result = pipeline.decrypt( + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False + ) + assert result.read() == plaintext # --------------------------------------------------------------------------- @@ -387,4 +401,6 @@ def test_legacy_cbc_rejected_by_default(self): ) with pytest.raises(S3EncryptionClientError, match="not configured to decrypt"): - pipeline.decrypt(_response(_v1_cbc_metadata())) + pipeline.decrypt( + _response(_v1_cbc_metadata()), ".instruction", enable_delayed_authentication=False + ) diff --git a/test/test_default_algorithm_commitment.py b/test/test_default_algorithm_commitment.py index 1640451a..7e61ed7f 100644 --- a/test/test_default_algorithm_commitment.py +++ b/test/test_default_algorithm_commitment.py @@ -83,6 +83,7 @@ def test_default_encryption_decryptable_with_require_decrypt(self): response = { "Body": BytesIO(ciphertext), "Metadata": metadata, + "ContentLength": len(ciphertext), } # Decrypt with REQUIRE_ENCRYPT_REQUIRE_DECRYPT — this will reject @@ -91,5 +92,7 @@ def test_default_encryption_decryptable_with_require_decrypt(self): cmm, commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, ) - result = decrypt_pipeline.decrypt(response) - assert result == plaintext + result = decrypt_pipeline.decrypt( + response, ".instruction", enable_delayed_authentication=False + ) + assert result.read() == plaintext diff --git a/test/test_key_commitment.py b/test/test_key_commitment.py index a82fc9fd..a0be12ce 100644 --- a/test/test_key_commitment.py +++ b/test/test_key_commitment.py @@ -66,7 +66,11 @@ def _v2_gcm_response(key, plaintext=b"test data"): plaintext_data_key=key, algorithm_suite=AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, ) - return {"Body": BytesIO(ciphertext), "Metadata": metadata}, dec_mats, plaintext + return ( + {"Body": BytesIO(ciphertext), "Metadata": metadata, "ContentLength": len(ciphertext)}, + dec_mats, + plaintext, + ) def _v3_kc_gcm_response(key, plaintext=b"test data"): @@ -88,7 +92,11 @@ def _v3_kc_gcm_response(key, plaintext=b"test data"): plaintext_data_key=key, algorithm_suite=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, ) - return {"Body": BytesIO(ciphertext), "Metadata": metadata}, dec_mats, plaintext + return ( + {"Body": BytesIO(ciphertext), "Metadata": metadata, "ContentLength": len(ciphertext)}, + dec_mats, + plaintext, + ) # --------------------------------------------------------------------------- @@ -111,8 +119,8 @@ def test_forbid_encrypt_allows_non_committing_decrypt(self): commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response) - assert result == plaintext + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) + assert result.read() == plaintext ##= specification/s3-encryption/key-commitment.md#commitment-policy ##= type=test @@ -126,8 +134,8 @@ def test_require_encrypt_allow_decrypt_allows_non_committing_decrypt(self): commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_ALLOW_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response) - assert result == plaintext + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) + assert result.read() == plaintext ##= specification/s3-encryption/key-commitment.md#commitment-policy ##= type=test @@ -142,7 +150,7 @@ def test_require_require_rejects_non_committing_decrypt(self): keyring_return=dec_mats, ) with pytest.raises(S3EncryptionClientError, match="cannot decrypt non-key-committing"): - pipeline.decrypt(response) + pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) def test_require_require_allows_committing_decrypt(self): """REQUIRE_ENCRYPT_REQUIRE_DECRYPT MUST allow decryption with committing suites.""" @@ -153,5 +161,5 @@ def test_require_require_allows_committing_decrypt(self): commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response) - assert result == plaintext + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) + assert result.read() == plaintext diff --git a/test/test_pipelines.py b/test/test_pipelines.py index 3d32b8cc..e3d34e35 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -66,7 +66,13 @@ def test_decrypt_v1_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt( + mock_response, + instruction_suffix=".instruction", + enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -125,7 +131,13 @@ def test_decrypt_v2_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt( + mock_response, + instruction_suffix=".instruction", + enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -199,7 +211,13 @@ def test_decrypt_v3_from_instruction_file(self): with pytest.raises( S3EncryptionClientSecurityError, match="Key commitment verification failed" ): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt( + mock_response, + instruction_suffix=".instruction", + enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -250,9 +268,10 @@ def test_decrypt_with_custom_instruction_file_suffix(self): with pytest.raises(Exception, match="Keyring called"): pipeline.decrypt( mock_response, + instruction_suffix=".custom-suffix", + enable_delayed_authentication=False, bucket="test-bucket", key="test-key", - instruction_suffix=".custom-suffix", ) mock_s3_client.get_object.assert_called_once_with( @@ -290,4 +309,4 @@ def test_decrypt_v3_unsupported_wrap_alg(self): with pytest.raises( S3EncryptionClientError, match="AES/GCM is not a valid key wrapping algorithm" ): - pipeline.decrypt(mock_response) + pipeline.decrypt(mock_response, ".instruction", enable_delayed_authentication=False) diff --git a/test/test_s3_encryption_client_plugin.py b/test/test_s3_encryption_client_plugin.py index bdc48c79..cbc8cd80 100644 --- a/test/test_s3_encryption_client_plugin.py +++ b/test/test_s3_encryption_client_plugin.py @@ -139,3 +139,18 @@ def test_instruction_file_mode_invalid_keys_raises_error(self): # Should raise error with pytest.raises(S3EncryptionClientError, match="Instruction file contains invalid keys"): plugin.on_get_object_after_call(parsed) + + def test_missing_content_length_raises_error(self): + """Test that a missing ContentLength in the S3 response raises an error.""" + mock_keyring = Mock(spec=S3Keyring) + config = S3EncryptionClientConfig(keyring=mock_keyring) + plugin = S3EncryptionClientPlugin(config) + plugin._context.key = "my-object" + + parsed = { + "Body": StreamingBody(io.BytesIO(b"data"), 4), + "Metadata": {}, + } + + with pytest.raises(S3EncryptionClientError, match="missing ContentLength.*Key: my-object"): + plugin.on_get_object_after_call(parsed) diff --git a/test/test_stream.py b/test/test_stream.py new file mode 100644 index 00000000..dc0f2eb9 --- /dev/null +++ b/test/test_stream.py @@ -0,0 +1,655 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for streaming decryption behavior.""" + +import os +from io import BytesIO +from unittest.mock import Mock + +import pytest +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.padding import PKCS7 + +from s3_encryption.buffered_decrypt import one_shot_decrypt +from s3_encryption.decryptor import AesCbcDecryptor, AesGcmDecryptor +from s3_encryption.exceptions import S3EncryptionClientSecurityError +from s3_encryption.stream import DecryptingStream + + +def _encrypt_gcm(plaintext: bytes): + """Encrypt plaintext with AES-GCM, return (ciphertext_with_tag, key, nonce).""" + key = os.urandom(32) + nonce = os.urandom(12) + ciphertext_with_tag = AESGCM(key).encrypt(nonce, plaintext, None) + return ciphertext_with_tag, key, nonce + + +def _make_gcm_decryptor(key, nonce, content_length): + """Create an AesGcmDecryptor.""" + cipher_decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() + return AesGcmDecryptor(cipher_decryptor, tag_length=16, content_length=content_length) + + +def _encrypt_cbc(plaintext: bytes): + """Encrypt plaintext with AES-CBC + PKCS7 padding, return (ciphertext, key, iv).""" + key = os.urandom(32) + iv = os.urandom(16) + padder = PKCS7(128).padder() + padded = padder.update(plaintext) + padder.finalize() + encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor() + ciphertext = encryptor.update(padded) + encryptor.finalize() + return ciphertext, key, iv + + +def _make_cbc_decryptor(key, iv, content_length): + """Create an AesCbcDecryptor.""" + cipher_decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor() + unpadder = PKCS7(128).unpadder() + return AesCbcDecryptor(cipher_decryptor, unpadder, content_length=content_length) + + +def _make_streaming_body(data: bytes): + """Create a mock StreamingBody wrapping data.""" + body = Mock() + stream = BytesIO(data) + body.read = stream.read + body.close = Mock() + body._stream = stream + return body + + +class TestDelayedAuthReleasesBeforeVerification: + """Delayed auth releases plaintext before the GCM tag is verified.""" + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=test + ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. + def test_delayed_auth_releases_plaintext_before_tag_verification(self): + plaintext = os.urandom(4096) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + + stream = DecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + # read(256) decrypts a partial chunk via cipher.update(), releasing + # plaintext without consuming the full ciphertext stream. The GCM tag + # at the end of the stream has not been reached yet. + chunk = stream.read(256) + + # Plaintext was returned before the stream was fully consumed + assert len(chunk) > 0 + # _finalized is False: the GCM tag has NOT been verified yet + assert not stream._finalized + # Ciphertext remains unread in the underlying stream + assert body._stream.tell() < len(ct) + + # Finish reading the stream and verify full plaintext matches + remaining = stream.read() + assert chunk + remaining == plaintext + + +class TestBufferedWithholdsUntilVerification: + """Buffered mode does not release plaintext until the GCM tag is verified.""" + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=test + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + def test_buffered_verifies_tag_before_releasing_any_plaintext(self): + plaintext = os.urandom(4096) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + + decryptor = _make_gcm_decryptor(key, nonce, len(ct)) + original_finalize = decryptor.finalize + finalize_called = [] + + def spy_finalize(data): + result = original_finalize(data) + finalize_called.append(True) + return result + + decryptor.finalize = spy_finalize + + stream = one_shot_decrypt(body, decryptor) + + # one_shot_decrypt calls finalize() eagerly — tag is verified + # before any read() call on the returned stream. + assert finalize_called, "finalize (tag verification) must happen before read()" + chunk = stream.read(1) + assert chunk == plaintext[:1] + + +class TestDelayedAuthCBCDecryption: + + def test_roundtrip(self): + plaintext = b"hello world, this is a CBC test!!" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + assert stream.read() == plaintext + + def test_chunked_read(self): + plaintext = b"A" * 256 + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + result = b"" + while chunk := stream.read(64): + result += chunk + assert result == plaintext + + def test_finalize_called(self): + plaintext = b"finalize me" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + actual = stream.read() + assert stream._finalized + assert actual == plaintext + + def test_no_trailing_padding_bytes(self): + plaintext = b"short" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + assert stream.read() == plaintext + + def test_read_after_finalized_returns_empty(self): + plaintext = b"done" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + stream.read() + assert stream.read() == b"" + + def test_readable_false_after_finalized(self): + plaintext = b"readable" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + assert stream.readable() + actual = stream.read() + assert not stream.readable() + assert actual == plaintext + + def test_close_delegates_to_body(self): + plaintext = b"close me" + ciphertext, key, iv = _encrypt_cbc(plaintext) + body = _make_streaming_body(ciphertext) + stream = DecryptingStream( + body, + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + stream.close() + body.close.assert_called_once() + + def test_enter_returns_self(self): + plaintext = b"ctx" + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + assert stream.__enter__() is stream + + def test_wrong_key_raises_error(self): + plaintext = b"wrong key test!!" + ciphertext, _key, iv = _encrypt_cbc(plaintext) + wrong_key = os.urandom(32) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(wrong_key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): + stream.read() + + def test_empty_ciphertext(self): + key = os.urandom(32) + iv = os.urandom(16) + stream = DecryptingStream( + _make_streaming_body(b""), + _make_cbc_decryptor(key, iv, 0), + content_length=0, + ) + # Empty stream finalize will fail because CBC expects at least one block + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): + stream.read() + + +class TestBufferedDecryptingStream: + + def test_full_read(self): + plaintext = os.urandom(1024) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)) + ) + assert stream.read() == plaintext + + def test_partial_reads(self): + plaintext = os.urandom(512) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + result = b"" + while chunk := stream.read(100): + result += chunk + assert result == plaintext + + def test_read_triggers_full_decrypt(self): + plaintext = os.urandom(256) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + decryptor = _make_gcm_decryptor(key, nonce, len(ct)) + finalize_called = [] + original_finalize = decryptor.finalize + decryptor.finalize = lambda data: (finalize_called.append(True), original_finalize(data))[1] + + stream = one_shot_decrypt(body, decryptor) + # one_shot_decrypt eagerly decrypts — finalize called at construction + assert finalize_called + # Entire ciphertext consumed from the body + assert body._stream.tell() == len(ct) + assert stream.read(1) == plaintext[:1] + + def test_tell(self): + plaintext = os.urandom(200) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + stream.read(50) + assert stream.tell() == 50 + + def test_readable(self): + plaintext = b"readable test" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + assert stream.readable() + + def test_readinto(self): + """Asserts that readinto is implemented.""" + plaintext = os.urandom(64) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + buf = bytearray(64) + n = stream.readinto(buf) + assert n == 64 + assert bytes(buf) == plaintext + + def test_enter_returns_stream(self): + plaintext = b"enter" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + with stream as s: + assert s.read() == plaintext + + def test_close(self): + """Asserts that close does not raise.""" + plaintext = b"close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + stream = one_shot_decrypt( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + ) + stream.close() # should not raise + + def test_close_without_close_attr(self): + """Asserts that close handles bodies without close.""" + plaintext = b"no close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = Mock() + del body.close + body.read = BytesIO(ct).read + stream = one_shot_decrypt( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + ) + stream.close() # should not raise + + def test_wrong_key_raises_error(self): + plaintext = b"wrong key" + ct, _key, nonce = _encrypt_gcm(plaintext) + wrong_key = os.urandom(32) + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): + one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(wrong_key, nonce, len(ct)), + ) + + def test_tampered_ciphertext_raises_error(self): + plaintext = b"tamper test" + ct, key, nonce = _encrypt_gcm(plaintext) + tampered = bytearray(ct) + tampered[0] ^= 0xFF + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): + one_shot_decrypt( + _make_streaming_body(bytes(tampered)), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + + def test_idempotent_decrypt(self): + plaintext = os.urandom(128) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + first = stream.read(63) + second = stream.read(65) + assert first + second == plaintext + + +class TestDelayedAuthGCMDecryption: + + def test_full_read(self): + plaintext = os.urandom(1024) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + assert stream.read() == plaintext + + def test_chunked_read(self): + plaintext = os.urandom(512) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + result = b"" + while chunk := stream.read(64): + result += chunk + assert result == plaintext + + def test_read_after_finalized_returns_empty(self): + plaintext = os.urandom(128) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + actual = stream.read() + assert stream._finalized + assert stream.read() == b"" + assert actual == plaintext + + def test_readable_false_after_finalized(self): + plaintext = b"readable" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + assert stream.readable() + stream.read() + assert not stream.readable() + + def test_close_delegates(self): + plaintext = b"close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + stream = DecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + stream.close() + body.close.assert_called_once() + + def test_enter_returns_self(self): + plaintext = b"ctx" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + assert stream.__enter__() is stream + + def test_wrong_key_raises_error(self): + plaintext = b"wrong key" + ct, _key, nonce = _encrypt_gcm(plaintext) + wrong_key = os.urandom(32) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(wrong_key, nonce, len(ct)), + content_length=len(ct), + ) + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): + stream.read() + + def test_tampered_tag_raises_error(self): + plaintext = b"tamper tag" + ct, key, nonce = _encrypt_gcm(plaintext) + tampered = bytearray(ct) + tampered[-1] ^= 0xFF # flip last byte (part of tag) + stream = DecryptingStream( + _make_streaming_body(bytes(tampered)), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): + stream.read() + + def test_small_data_less_than_tag_length(self): + """Data exactly equal to tag length — only tag, no ciphertext.""" + plaintext = b"" + ct, key, nonce = _encrypt_gcm(plaintext) + # For empty plaintext, ct is just the 16-byte tag + assert len(ct) == 16 + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + assert stream.read() == b"" + + def test_large_data(self): + plaintext = os.urandom(1024 * 1024) # 1 MB + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + result = b"" + while chunk := stream.read(65536): + result += chunk + assert result == plaintext + + +# --------------------------------------------------------------------------- +# Parameterized edge-case plaintext lengths +# --------------------------------------------------------------------------- +# Lengths chosen around AES block size (16) and two-block (32) boundaries, +# plus zero and one byte, to exercise padding, tag-splitting, and empty-data paths. +EDGE_CASE_LENGTHS = [0, 1, 8, 15, 16, 17, 31, 32, 33, 47, 48, 49, 300] + + +class TestEdgeCasePlaintextLengths: + + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) + def test_buffered_gcm(self, length): + plaintext = os.urandom(length) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + ) + assert stream.read() == plaintext + + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) + def test_delayed_auth_gcm(self, length): + plaintext = os.urandom(length) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + result = b"" + while chunk := stream.read(7): + result += chunk + assert result == plaintext + + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) + def test_delayed_auth_cbc(self, length): + plaintext = os.urandom(length) + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), + ) + result = b"" + while chunk := stream.read(7): + result += chunk + assert result == plaintext + + +class TestDecryptingStreamIterators: + """Tests for iter_chunks, iter_lines, __iter__, __next__, readinto, and readlines.""" + + def _make_gcm_stream(self, plaintext): + ct, key, nonce = _encrypt_gcm(plaintext) + return DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + + @pytest.mark.parametrize("chunk_size", EDGE_CASE_LENGTHS[1:]) + def test_iter_chunks(self, chunk_size): + plaintext = os.urandom(300) + stream = self._make_gcm_stream(plaintext) + result = b"" + for chunk in stream.iter_chunks(chunk_size): + assert ( + len(chunk) <= chunk_size or not result + ) # first chunk may vary due to GCM buffering + result += chunk + assert result == plaintext + + def test_iter_chunks_default_size(self): + plaintext = os.urandom(2048) + stream = self._make_gcm_stream(plaintext) + result = b"".join(stream.iter_chunks()) + assert result == plaintext + + def test_iter_chunks_empty(self): + stream = self._make_gcm_stream(b"") + assert list(stream.iter_chunks()) == [] + + def test_iter(self): + plaintext = os.urandom(2048) + stream = self._make_gcm_stream(plaintext) + result = b"".join(stream) + assert result == plaintext + + def test_next(self): + plaintext = os.urandom(100) + stream = self._make_gcm_stream(plaintext) + first = next(stream) + assert len(first) > 0 + # drain the rest + rest = b"" + for chunk in stream: + rest += chunk + assert first + rest == plaintext + + def test_next_raises_stop_iteration(self): + stream = self._make_gcm_stream(b"") + with pytest.raises(StopIteration): + next(stream) + + def test_iter_lines(self): + plaintext = b"line1\nline2\nline3\n" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines()) + assert lines == [b"line1", b"line2", b"line3"] + + def test_iter_lines_keepends(self): + plaintext = b"line1\nline2\nline3\n" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines(keepends=True)) + assert lines == [b"line1\n", b"line2\n", b"line3\n"] + + def test_iter_lines_no_trailing_newline(self): + plaintext = b"first\nsecond" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines()) + assert lines == [b"first", b"second"] + + def test_iter_lines_empty(self): + stream = self._make_gcm_stream(b"") + assert list(stream.iter_lines()) == [] + + def test_readinto(self): + plaintext = os.urandom(64) + stream = self._make_gcm_stream(plaintext) + buf = bytearray(64) + n = stream.readinto(buf) + assert bytes(buf[:n]) == plaintext[:n] + + def test_readinto_partial(self): + plaintext = os.urandom(200) + stream = self._make_gcm_stream(plaintext) + buf = bytearray(50) + result = b"" + while n := stream.readinto(buf): + result += bytes(buf[:n]) + assert result == plaintext + + def test_readlines(self): + plaintext = b"aaa\nbbb\nccc\n" + stream = self._make_gcm_stream(plaintext) + assert stream.readlines() == [b"aaa\n", b"bbb\n", b"ccc\n"] + + def test_readlines_no_trailing_newline(self): + plaintext = b"aaa\nbbb" + stream = self._make_gcm_stream(plaintext) + assert stream.readlines() == [b"aaa\n", b"bbb"]