From aa38719ae7bc9d32ae3bd036ab2fb9b40713cadb Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:13:22 -0600 Subject: [PATCH 01/31] feat(decryption): add buffered streaming decryption for AES-GCM Introduce BufferedDecryptingStream that wraps the S3 StreamingBody and decrypts lazily on first read. No plaintext is released until the entire ciphertext is read and the GCM auth tag is verified, matching the Java S3EC's BufferedCipherSubscriber behavior. - Add stream.py with BufferedDecryptingStream (read, iter_chunks, close) - Pipeline returns BufferedDecryptingStream instead of decrypted bytes - Event handler passes stream directly as parsed["Body"] --- src/s3_encryption/__init__.py | 8 +--- src/s3_encryption/pipelines.py | 14 +++--- src/s3_encryption/stream.py | 80 ++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 12 deletions(-) create mode 100644 src/s3_encryption/stream.py diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index a3558195..f5b7f684 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -147,12 +147,8 @@ def on_get_object_after_call(self, parsed, **kwargs): instruction_suffix=self.config.instruction_file_suffix, ) - # Create a new streaming body with the decrypted data - stream = io.BytesIO(decrypted_data) - streaming_body = StreamingBody(stream, len(decrypted_data)) - - # Replace body with decrypted data - parsed["Body"] = streaming_body + # Replace body with decrypting stream + parsed["Body"] = decrypted_data def process_instruction_file(self, parsed): """Process instruction file in plaintext mode. diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 02a5a9c9..dfaea539 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -18,6 +18,7 @@ from .materials.encrypted_data_key import EncryptedDataKey from .materials.materials import DecryptionMaterials, EncryptionMaterials from .metadata import ObjectMetadata +from .stream import BufferedDecryptingStream @define @@ -112,11 +113,10 @@ def decrypt( instruction_suffix(str, optional): suffix for instruction file; defaults to ".instruction". Returns: - bytes: The decrypted data + BufferedDecryptingStream: A stream that decrypts data lazily on first read. """ # Convert the metadata dictionary to an ObjectMetadata instance - # TODO: Stream + Buffered Decryption - encrypted_data = response.get("Body").read() + streaming_body = response.get("Body") encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) @@ -171,9 +171,11 @@ def decrypt( ##% the S3EC MUST throw an error which details that client was ##% not configured to decrypt objects with ALG_AES_256_CBC_IV16_NO_KDF. - # Perform decryption - aesgcm = AESGCM(dec_materials.plaintext_data_key) - return aesgcm.decrypt(nonce=dec_materials.iv, data=encrypted_data, associated_data=None) + # Return a buffered decrypting stream — no plaintext is released + # until the entire ciphertext is read and the GCM tag is verified. + return BufferedDecryptingStream( + streaming_body, dec_materials.plaintext_data_key, dec_materials.iv + ) def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: """Prepare V2 decryption materials.""" diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py new file mode 100644 index 00000000..bdbff82c --- /dev/null +++ b/src/s3_encryption/stream.py @@ -0,0 +1,80 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Streaming decryption support for S3 Encryption Client.""" + +import io + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from .exceptions import S3EncryptionClientError + + +class BufferedDecryptingStream: + """A stream that buffers all ciphertext, verifies the GCM auth tag, then releases plaintext. + + This matches the Java S3EC's BufferedCipherSubscriber behavior: no plaintext + is released until the entire ciphertext has been read and authenticated. + + Implements the same read interface as botocore's StreamingBody so it can be + used as a drop-in replacement for parsed["Body"]. + """ + + def __init__(self, streaming_body, key, nonce): + """Initialize the buffered decrypting stream. + + Args: + streaming_body: The original StreamingBody containing ciphertext. + key: The plaintext data key (bytes). + nonce: The IV/nonce for AES-GCM decryption (bytes). + """ + self._body = streaming_body + self._key = key + self._nonce = nonce + self._plaintext = None + + def _decrypt(self): + """Read all ciphertext, decrypt and verify, cache plaintext.""" + if self._plaintext is not None: + return + try: + ciphertext = self._body.read() + aesgcm = AESGCM(self._key) + decrypted = aesgcm.decrypt(nonce=self._nonce, data=ciphertext, associated_data=None) + except Exception as e: + raise S3EncryptionClientError(f"Failed to decrypt object: {e}") from e + self._plaintext = io.BytesIO(decrypted) + + def read(self, amt=None): + """Read decrypted data. + + Args: + amt: Number of bytes to read. If None, reads all remaining data. + + Returns: + bytes: Decrypted plaintext bytes. + """ + self._decrypt() + if amt is None: + return self._plaintext.read() + return self._plaintext.read(amt) + + def iter_chunks(self, chunk_size=1024): + """Iterate over decrypted data in chunks. + + Args: + chunk_size: Size of each chunk in bytes. + + Yields: + bytes: Chunks of decrypted plaintext. + """ + self._decrypt() + while True: + chunk = self._plaintext.read(chunk_size) + if not chunk: + break + yield chunk + + def close(self): + """Close the underlying stream.""" + if hasattr(self._body, "close"): + self._body.close() From 24eae317157fc8af8b5dfd364f3f94907e1e76ed Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:43:10 -0700 Subject: [PATCH 02/31] feat(config): add enable_delayed_authentication option with duvet citations Add enable_delayed_authentication field to S3EncryptionClientConfig, defaulting to False. Includes duvet specification citations from client.md#enable-delayed-authentication. --- src/s3_encryption/__init__.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index f5b7f684..b1d98683 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -37,6 +37,15 @@ class S3EncryptionClientConfig: ##% as its associated object suffixed with ".instruction". instruction_file_suffix: str = field(default=".instruction") + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implication + ##% Delayed Authentication mode MUST be set to false by default. + enable_delayed_authentication: bool = field(default=False) + @cmm.default def _default_cmm_for_keyring(self): return DefaultCryptoMaterialsManager(self.keyring) From 44aac6c3383f5d7367be2c59e2433a5b4bc49094 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:43:37 -0700 Subject: [PATCH 03/31] feat(stream): add DelayedAuthDecryptingStream for unauthenticated streaming Add DelayedAuthDecryptingStream that releases plaintext incrementally via AES-GCM cipher.update() before tag verification. The GCM tag (last 16 bytes) is held back and verified on stream exhaustion via finalize_with_tag(). Matches Java S3EC CipherSubscriber pattern. --- src/s3_encryption/stream.py | 79 +++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index bdbff82c..915d5bb2 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -4,10 +4,13 @@ import io +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM from .exceptions import S3EncryptionClientError +GCM_TAG_LENGTH = 16 + class BufferedDecryptingStream: """A stream that buffers all ciphertext, verifies the GCM auth tag, then releases plaintext. @@ -78,3 +81,79 @@ def close(self): """Close the underlying stream.""" if hasattr(self._body, "close"): self._body.close() + + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=implementation +##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. +class DelayedAuthDecryptingStream: + """A stream that releases plaintext before GCM tag verification. + + Matches the Java S3EC's CipherSubscriber: plaintext is released incrementally + via cipher.update(), and the GCM tag is only verified when the stream is fully + consumed. Data read before finalization is unauthenticated. + """ + + def __init__(self, streaming_body, key, nonce): + """Initialize the delayed-auth decrypting stream. + + Args: + streaming_body: The original StreamingBody containing ciphertext. + key: The plaintext data key (bytes). + nonce: The IV/nonce for AES-GCM decryption (bytes). + """ + self._body = streaming_body + self._decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() + self._tag_buffer = b"" + self._finalized = False + + def read(self, amt=None): + """Read and decrypt data, releasing plaintext before authentication. + + The last 16 bytes of ciphertext are the GCM tag. We hold back a + rolling buffer of 16 bytes so the tag is never passed to update(). + """ + if self._finalized: + return b"" + + raw = self._body.read(amt) + if not raw and not self._tag_buffer: + return b"" + + data = self._tag_buffer + raw + if len(data) <= GCM_TAG_LENGTH: + if raw: + self._tag_buffer = data + return b"" + return self._finalize(tag=data) + + self._tag_buffer = data[-GCM_TAG_LENGTH:] + ciphertext = data[:-GCM_TAG_LENGTH] + plaintext = self._decryptor.update(ciphertext) + + # Check if underlying stream is exhausted + peek = self._body.read(1) + if peek: + self._tag_buffer = self._tag_buffer + peek + if len(self._tag_buffer) > GCM_TAG_LENGTH: + extra_ct = self._tag_buffer[:-GCM_TAG_LENGTH] + self._tag_buffer = self._tag_buffer[-GCM_TAG_LENGTH:] + plaintext += self._decryptor.update(extra_ct) + else: + plaintext += self._finalize(tag=self._tag_buffer) + + return plaintext + + def _finalize(self, tag): + """Verify the GCM tag and finalize decryption.""" + self._finalized = True + self._tag_buffer = b"" + try: + return self._decryptor.finalize_with_tag(tag) + except Exception as e: + raise S3EncryptionClientError(f"GCM tag verification failed: {e}") from e + + def close(self): + """Close the underlying stream.""" + if hasattr(self._body, "close"): + self._body.close() From f5eae606db61f37033885ce0f3abfd29e0ed0702 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:44:01 -0700 Subject: [PATCH 04/31] feat(decrypt): wire delayed authentication through pipeline and event handler MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Pipeline.decrypt() accepts enable_delayed_authentication param and returns DelayedAuthDecryptingStream when True, BufferedDecryptingStream when False. Raises error if param is None (must be explicitly set). - Event handler passes config flag to pipeline. - Remove duplicated defaults from pipeline params — config is single source of truth. - Update unit tests to pass instruction_suffix explicitly. --- src/s3_encryption/__init__.py | 1 + src/s3_encryption/pipelines.py | 15 +++++++++++++-- test/test_pipelines.py | 6 +++--- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index b1d98683..08c801a7 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -154,6 +154,7 @@ def on_get_object_after_call(self, parsed, **kwargs): bucket=getattr(self._context, "bucket", None), key=getattr(self._context, "key", None), instruction_suffix=self.config.instruction_file_suffix, + enable_delayed_authentication=self.config.enable_delayed_authentication, ) # Replace body with decrypting stream diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index dfaea539..96dceaf4 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -18,7 +18,7 @@ from .materials.encrypted_data_key import EncryptedDataKey from .materials.materials import DecryptionMaterials, EncryptionMaterials from .metadata import ObjectMetadata -from .stream import BufferedDecryptingStream +from .stream import BufferedDecryptingStream, DelayedAuthDecryptingStream @define @@ -101,7 +101,8 @@ def decrypt( encryption_context=None, bucket=None, key=None, - instruction_suffix=".instruction", + instruction_suffix=None, + enable_delayed_authentication=None, ): """Decrypt the data after it is retrieved from S3. @@ -111,6 +112,7 @@ def decrypt( bucket (str, optional): S3 bucket name (required for instruction file) key (str, optional): S3 object key (required for instruction file) instruction_suffix(str, optional): suffix for instruction file; defaults to ".instruction". + enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. Returns: BufferedDecryptingStream: A stream that decrypts data lazily on first read. @@ -173,6 +175,15 @@ def decrypt( # Return a buffered decrypting stream — no plaintext is released # until the entire ciphertext is read and the GCM tag is verified. + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + if enable_delayed_authentication is None: + raise S3EncryptionClientError("enable_delayed_authentication must be explicitly set") + if enable_delayed_authentication: + return DelayedAuthDecryptingStream( + streaming_body, dec_materials.plaintext_data_key, dec_materials.iv + ) return BufferedDecryptingStream( streaming_body, dec_materials.plaintext_data_key, dec_materials.iv ) diff --git a/test/test_pipelines.py b/test/test_pipelines.py index 9f40cd5c..92bdff50 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -60,7 +60,7 @@ def test_decrypt_v1_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -115,7 +115,7 @@ def test_decrypt_v2_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -184,7 +184,7 @@ def test_decrypt_v3_from_instruction_file(self): # This should fail with NotImplementedError since V3 decryption isn't implemented yet with pytest.raises(NotImplementedError, match="V3 decryption not yet implemented"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key") + pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( From f348251600cd3e7c3bd0b41fd593e2954e8118f4 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:44:19 -0700 Subject: [PATCH 05/31] test(integration): add parametrized roundtrip tests for buffered and delayed auth Replace individual delayed-auth tests with pytest.mark.parametrize covering both buffered and delayed-auth modes across ascii, empty, unicode, utf-8, latin-1, binary data, and no-body cases. --- test/integration/test_i_s3_encryption.py | 49 ++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 616f8da4..073c2fb0 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -470,3 +470,52 @@ def test_encryption_context_missing_on_decrypt(): print(f"Unexpected error type: {type(e).__name__}") print(f"Error message: {str(e)}") raise + + +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +@pytest.mark.parametrize( + "key_prefix, data, encoding", + [ + ("simple-rt", "test input for simple v3 round trip", "utf-8"), + ("empty-string-rt", "", "utf-8"), + ("unicode-rt", "Unicode test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!, ½⅓¼⅕⅙⅐⅛⅑⅒⅔⅖⅗⅘⅙⅚⅜⅝⅞", "utf-8"), + ("utf8-rt", "UTF-8 encoding test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!", "utf-8"), + ("latin1-rt", "Latin-1 encoding test: éèêë àâäãåá çñ ¿¡ øæå ØÆÅÉÈÊËÀÂÄÃÅÁ", "latin-1"), + ("binary-rt", bytes(range(256)), None), + ], + ids=["ascii", "empty", "unicode", "utf8", "latin1", "binary"], +) +def test_roundtrip(delayed_auth, key_prefix, data, encoding): + key = f"{key_prefix}-{'da' if delayed_auth else 'buf'}-" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + body = data.encode(encoding) if encoding and isinstance(data, str) else data + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=delayed_auth) + s3ec = S3EncryptionClient(wrapped_client, config) + s3ec.put_object(Bucket=bucket, Key=key, Body=body) + response = s3ec.get_object(Bucket=bucket, Key=key) + output = response["Body"].read() + + if encoding: + assert output.decode(encoding) == data + else: + assert output == data + + +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_no_body_roundtrip(delayed_auth): + key = f"no-body-rt-{'da' if delayed_auth else 'buf'}-" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=delayed_auth) + s3ec = S3EncryptionClient(wrapped_client, config) + s3ec.put_object(Bucket=bucket, Key=key) + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == b"" From 622b5a1d8d5e242fd03e60f632a510c8ab133ff7 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 12:08:39 -0700 Subject: [PATCH 06/31] test(duvet): add streaming decryption tests for delayed authentication citations Add unit tests that verify the behavioral contract of both stream modes: - DelayedAuthDecryptingStream releases plaintext before GCM tag verification - BufferedDecryptingStream withholds all plaintext until tag is verified Includes duvet type=test citations for enable-delayed-authentication spec. --- test/integration/test_i_s3_encryption.py | 15 ++++- test/test_pipelines.py | 21 ++++++- test/test_stream.py | 76 ++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 test/test_stream.py diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 073c2fb0..441cdf7c 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -472,14 +472,25 @@ def test_encryption_context_missing_on_decrypt(): raise +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=test +##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. @pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) @pytest.mark.parametrize( "key_prefix, data, encoding", [ ("simple-rt", "test input for simple v3 round trip", "utf-8"), ("empty-string-rt", "", "utf-8"), - ("unicode-rt", "Unicode test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!, ½⅓¼⅕⅙⅐⅛⅑⅒⅔⅖⅗⅘⅙⅚⅜⅝⅞", "utf-8"), - ("utf8-rt", "UTF-8 encoding test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!", "utf-8"), + ( + "unicode-rt", + "Unicode test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!, ½⅓¼⅕⅙⅐⅛⅑⅒⅔⅖⅗⅘⅙⅚⅜⅝⅞", + "utf-8", + ), + ( + "utf8-rt", + "UTF-8 encoding test: 你好, こんにちは, 안녕하세요, Привет, مرحبا, ¡Hola!", + "utf-8", + ), ("latin1-rt", "Latin-1 encoding test: éèêë àâäãåá çñ ¿¡ øæå ØÆÅÉÈÊËÀÂÄÃÅÁ", "latin-1"), ("binary-rt", bytes(range(256)), None), ], diff --git a/test/test_pipelines.py b/test/test_pipelines.py index 92bdff50..3865913f 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -60,7 +60,12 @@ def test_decrypt_v1_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") + pipeline.decrypt( + mock_response, + bucket="test-bucket", + key="test-key", + instruction_suffix=".instruction", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -115,7 +120,12 @@ def test_decrypt_v2_from_instruction_file(self): # Should fail when trying to decrypt (proving instruction file was fetched) with pytest.raises(Exception, match="Keyring called"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") + pipeline.decrypt( + mock_response, + bucket="test-bucket", + key="test-key", + instruction_suffix=".instruction", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( @@ -184,7 +194,12 @@ def test_decrypt_v3_from_instruction_file(self): # This should fail with NotImplementedError since V3 decryption isn't implemented yet with pytest.raises(NotImplementedError, match="V3 decryption not yet implemented"): - pipeline.decrypt(mock_response, bucket="test-bucket", key="test-key", instruction_suffix=".instruction") + pipeline.decrypt( + mock_response, + bucket="test-bucket", + key="test-key", + instruction_suffix=".instruction", + ) # Verify instruction file was fetched mock_s3_client.get_object.assert_called_once_with( diff --git a/test/test_stream.py b/test/test_stream.py new file mode 100644 index 00000000..099f07a6 --- /dev/null +++ b/test/test_stream.py @@ -0,0 +1,76 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for streaming decryption behavior.""" + +import os +from io import BytesIO +from unittest.mock import Mock + +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from s3_encryption.stream import BufferedDecryptingStream, DelayedAuthDecryptingStream + + +def _encrypt(plaintext: bytes): + """Encrypt plaintext with AES-GCM, return (ciphertext_with_tag, key, nonce).""" + key = os.urandom(32) + nonce = os.urandom(12) + ciphertext_with_tag = AESGCM(key).encrypt(nonce, plaintext, None) + return ciphertext_with_tag, key, nonce + + +def _make_streaming_body(data: bytes): + """Create a mock StreamingBody wrapping data.""" + body = Mock() + stream = BytesIO(data) + body.read = stream.read + body.close = Mock() + body._stream = stream + return body + + +class TestDelayedAuthReleasesBeforeVerification: + """Delayed auth releases plaintext before the GCM tag is verified.""" + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=test + ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. + def test_delayed_auth_releases_plaintext_before_tag_verification(self): + plaintext = os.urandom(4096) + ciphertext_with_tag, key, nonce = _encrypt(plaintext) + body = _make_streaming_body(ciphertext_with_tag) + + stream = DelayedAuthDecryptingStream(body, key, nonce) + # read(256) decrypts a partial chunk via cipher.update(), releasing + # plaintext without consuming the full ciphertext stream. The GCM tag + # at the end of the stream has not been reached yet. + chunk = stream.read(256) + + # Plaintext was returned before the stream was fully consumed + assert len(chunk) > 0 + # _finalized is False: the GCM tag has NOT been verified yet + assert not stream._finalized + # Ciphertext remains unread in the underlying stream + assert body._stream.tell() < len(ciphertext_with_tag) + + +class TestBufferedWithholdsUntilVerification: + """Buffered mode does not release plaintext until the GCM tag is verified.""" + + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=test + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + def test_buffered_verifies_tag_before_releasing_any_plaintext(self): + plaintext = os.urandom(4096) + ciphertext_with_tag, key, nonce = _encrypt(plaintext) + body = _make_streaming_body(ciphertext_with_tag) + + stream = BufferedDecryptingStream(body, key, nonce) + # read(1) triggers _decrypt(), which calls self._body.read() with no amt, + # consuming the entire ciphertext and verifying the GCM tag before + # returning even 1 byte of plaintext. + chunk = stream.read(1) + + assert chunk == plaintext[:1] + # _plaintext being set confirms full decrypt+verify already happened + assert stream._plaintext is not None From 5869c255006bd262413a3efc7804ed74efe3cd81 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 13:11:07 -0700 Subject: [PATCH 07/31] test(integration): parametrize instruction file tests with buffered and delayed-auth modes --- .../test_i_s3_encryption_instruction_file.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index 467ddc7a..ea3fda08 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -49,7 +49,8 @@ def test_decrypt_v1_instruction_file(): print("Success! V1 instruction file decryption completed.") -def test_decrypt_v2_instruction_file(): +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_decrypt_v2_instruction_file(delayed_auth): """Test decrypting V2 object with instruction file. V2 format uses ALG_AES_256_GCM_IV12_TAG16_NO_KDF (no key commitment). @@ -60,7 +61,7 @@ def test_decrypt_v2_instruction_file(): kms_client = boto3.client("kms", region_name=region) keyring = KmsKeyring(kms_client, kms_key_id) wrapped_client = boto3.client("s3") - config = S3EncryptionClientConfig(keyring) + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=delayed_auth) s3ec = S3EncryptionClient(wrapped_client, config) response = s3ec.get_object(Bucket=bucket, Key=key) @@ -134,14 +135,19 @@ def test_decrypt_v3_instruction_file_custom_suffix(): print("Success! V3 custom suffix instruction file decryption completed.") -def test_decrypt_v2_instruction_file_custom_suffix(): +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_decrypt_v2_instruction_file_custom_suffix(delayed_auth): """Test decrypting V2 object with a custom instruction file suffix.""" key = TEST_OBJECTS["v2_instruction_file"] kms_client = boto3.client("kms", region_name=region) keyring = KmsKeyring(kms_client, kms_key_id) wrapped_client = boto3.client("s3") - config = S3EncryptionClientConfig(keyring, instruction_file_suffix=".custom-suffix-instruction") + config = S3EncryptionClientConfig( + keyring, + instruction_file_suffix=".custom-suffix-instruction", + enable_delayed_authentication=delayed_auth, + ) s3ec = S3EncryptionClient(wrapped_client, config) response = s3ec.get_object(Bucket=bucket, Key=key) From 377004baaa9c500de456fc2d6c68d22dde9f6e8d Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:24:52 -0700 Subject: [PATCH 08/31] test(integration): add large file and 61 GiB placeholder tests for delayed-auth streaming - Add 50 MB V2 delayed-auth streaming decryption test against static object - Add 50 MB V3 test (skipped, V3 not yet implemented) - Add 61 GiB V2/V3 placeholder tests marked @pytest.mark.large (skipped, static objects not yet created) - Parametrize existing instruction file tests with buffered/delayed-auth modes --- .../test_i_s3_encryption_instruction_file.py | 88 +++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index ea3fda08..7c734098 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -23,6 +23,10 @@ "v2_instruction_file": "static-v2-instruction-file-from-java-v4", "v3_instruction_file": "static-v3-instruction-file-from-java-v4", "negative_v2_instruction_file": "NEGATIVE-static-v2-instruction-file-test-from-java-v4", + "large_v2_instruction_file": "static-large-v2-instruction-file-from-java-v4-52428800", + "large_v3_instruction_file": "static-large-v3-instruction-file-from-java-v4-52428800", + "xlarge_v2_instruction_file": "TODO-static-xlarge-v2-instruction-file-61GiB", + "xlarge_v3_instruction_file": "TODO-static-xlarge-v3-instruction-file-61GiB", } @@ -155,3 +159,87 @@ def test_decrypt_v2_instruction_file_custom_suffix(delayed_auth): assert output == "static-v2-instruction-file-from-java-v4" print("Success! V2 custom suffix instruction file decryption completed.") + + +LARGE_FILE_SIZE = 52428800 # 50 MB + + +def test_decrypt_large_v2_instruction_file_delayed_auth(): + """Test streaming decryption of a 50 MB V2 object with delayed authentication.""" + key = TEST_OBJECTS["large_v2_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == LARGE_FILE_SIZE + + +# TODO(v3): enable once V3 decryption is implemented +@pytest.mark.skip(reason="V3 decryption not yet implemented") +def test_decrypt_large_v3_instruction_file_delayed_auth(): + """Test streaming decryption of a 50 MB V3 object with delayed authentication.""" + key = TEST_OBJECTS["large_v3_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == LARGE_FILE_SIZE + + +XLARGE_FILE_SIZE = 61 * 1024 * 1024 * 1024 # 61 GiB + + +@pytest.mark.large +@pytest.mark.skip(reason="61 GiB static test object not yet created") +def test_decrypt_xlarge_v2_instruction_file_delayed_auth(): + """Test streaming decryption of a 61 GiB V2 object with delayed authentication.""" + key = TEST_OBJECTS["xlarge_v2_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == XLARGE_FILE_SIZE + + +@pytest.mark.large +@pytest.mark.skip(reason="61 GiB static test object not yet created") +def test_decrypt_xlarge_v3_instruction_file_delayed_auth(): + """Test streaming decryption of a 61 GiB V3 object with delayed authentication.""" + key = TEST_OBJECTS["xlarge_v3_instruction_file"] + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) + s3ec = S3EncryptionClient(wrapped_client, config) + + response = s3ec.get_object(Bucket=bucket, Key=key) + total = 0 + while chunk := response["Body"].read(65536): + total += len(chunk) + + assert total == XLARGE_FILE_SIZE From 182d228f2791af292a6f81504ed09a67ef535b20 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:37:10 -0700 Subject: [PATCH 09/31] fix: update decrypt return type docstring and verify full delayed-auth plaintext - Fix pipeline decrypt() docstring to reflect both return types - Add assertion that full delayed-auth stream output matches expected plaintext --- src/s3_encryption/pipelines.py | 2 +- test/test_stream.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 96dceaf4..703512e5 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -115,7 +115,7 @@ def decrypt( enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. Returns: - BufferedDecryptingStream: A stream that decrypts data lazily on first read. + A decrypting stream (BufferedDecryptingStream or DelayedAuthDecryptingStream). """ # Convert the metadata dictionary to an ObjectMetadata instance streaming_body = response.get("Body") diff --git a/test/test_stream.py b/test/test_stream.py index 099f07a6..79e8f7b5 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -53,6 +53,10 @@ def test_delayed_auth_releases_plaintext_before_tag_verification(self): # Ciphertext remains unread in the underlying stream assert body._stream.tell() < len(ciphertext_with_tag) + # Finish reading the stream and verify full plaintext matches + remaining = stream.read() + assert chunk + remaining == plaintext + class TestBufferedWithholdsUntilVerification: """Buffered mode does not release plaintext until the GCM tag is verified.""" From 4aaddc5508ee9dee937aea1ab4d7fb9f87cfb273 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 9 Mar 2026 15:46:38 -0700 Subject: [PATCH 10/31] chore: register large pytest mark in pyproject.toml --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a5ab41ef..f59bc57e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,11 @@ line-length = 100 target-version = ["py311"] include = '\.pyi?$' +[tool.pytest.ini_options] +markers = [ + "large: marks tests requiring large static test objects (deselect with '-m \"not large\"')", +] + [tool.ruff] line-length = 100 target-version = "py311" From ef178325b6729d205ff282f7f052e1250b7015e0 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 10 Mar 2026 10:44:40 -0700 Subject: [PATCH 11/31] test(integration): remove xlarge placeholder tests and large pytest mark - Remove 61 GiB V2/V3 placeholder tests (static objects not yet created) - Remove large pytest mark registration (no longer used) --- pyproject.toml | 5 --- .../test_i_s3_encryption_instruction_file.py | 45 ------------------- 2 files changed, 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f59bc57e..a5ab41ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,11 +36,6 @@ line-length = 100 target-version = ["py311"] include = '\.pyi?$' -[tool.pytest.ini_options] -markers = [ - "large: marks tests requiring large static test objects (deselect with '-m \"not large\"')", -] - [tool.ruff] line-length = 100 target-version = "py311" diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index 7c734098..9b0841fe 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -25,8 +25,6 @@ "negative_v2_instruction_file": "NEGATIVE-static-v2-instruction-file-test-from-java-v4", "large_v2_instruction_file": "static-large-v2-instruction-file-from-java-v4-52428800", "large_v3_instruction_file": "static-large-v3-instruction-file-from-java-v4-52428800", - "xlarge_v2_instruction_file": "TODO-static-xlarge-v2-instruction-file-61GiB", - "xlarge_v3_instruction_file": "TODO-static-xlarge-v3-instruction-file-61GiB", } @@ -200,46 +198,3 @@ def test_decrypt_large_v3_instruction_file_delayed_auth(): total += len(chunk) assert total == LARGE_FILE_SIZE - - -XLARGE_FILE_SIZE = 61 * 1024 * 1024 * 1024 # 61 GiB - - -@pytest.mark.large -@pytest.mark.skip(reason="61 GiB static test object not yet created") -def test_decrypt_xlarge_v2_instruction_file_delayed_auth(): - """Test streaming decryption of a 61 GiB V2 object with delayed authentication.""" - key = TEST_OBJECTS["xlarge_v2_instruction_file"] - - kms_client = boto3.client("kms", region_name=region) - keyring = KmsKeyring(kms_client, kms_key_id) - wrapped_client = boto3.client("s3") - config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) - s3ec = S3EncryptionClient(wrapped_client, config) - - response = s3ec.get_object(Bucket=bucket, Key=key) - total = 0 - while chunk := response["Body"].read(65536): - total += len(chunk) - - assert total == XLARGE_FILE_SIZE - - -@pytest.mark.large -@pytest.mark.skip(reason="61 GiB static test object not yet created") -def test_decrypt_xlarge_v3_instruction_file_delayed_auth(): - """Test streaming decryption of a 61 GiB V3 object with delayed authentication.""" - key = TEST_OBJECTS["xlarge_v3_instruction_file"] - - kms_client = boto3.client("kms", region_name=region) - keyring = KmsKeyring(kms_client, kms_key_id) - wrapped_client = boto3.client("s3") - config = S3EncryptionClientConfig(keyring, enable_delayed_authentication=True) - s3ec = S3EncryptionClient(wrapped_client, config) - - response = s3ec.get_object(Bucket=bucket, Key=key) - total = 0 - while chunk := response["Body"].read(65536): - total += len(chunk) - - assert total == XLARGE_FILE_SIZE From edfcfb89acdcdf0506229052e65b0750764938b2 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 10 Mar 2026 10:46:36 -0700 Subject: [PATCH 12/31] fix(test): use tuple for parametrize args and remove duplicate test_no_body_roundtrip --- test/integration/test_i_s3_encryption.py | 36 +----------------------- 1 file changed, 1 insertion(+), 35 deletions(-) diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 441cdf7c..f3f2aa7b 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -71,40 +71,6 @@ def test_empty_string_roundtrip(): print("Success! Empty string encrypted and decrypted correctly.") -def test_no_body_roundtrip(): - key = "no-body-rt" - key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") - - # Expected data when no Body is provided (empty bytes) - expected_data = b"" - - kms_client = boto3.client("kms", region_name=region) - - keyring = KmsKeyring(kms_client, kms_key_id) - - wrapped_client = boto3.client("s3") - config = S3EncryptionClientConfig(keyring) - s3ec = S3EncryptionClient(wrapped_client, config) - - # Call put_object without providing a Body parameter - s3ec.put_object(Bucket=bucket, Key=key) - - get_req = {"Bucket": bucket, "Key": key} - response = s3ec.get_object(**get_req) - output = response["Body"].read() - - if output != expected_data: - print("Uh oh! Output doesn't match expected empty bytes!") - print("Expected:") - print(repr(expected_data)) - print("Output:") - print(repr(output)) - raise RuntimeError - print( - "Success! Object with no Body parameter encrypted and decrypted correctly as empty bytes." - ) - - def test_unicode_string_roundtrip(): key = "unicode-string-rt" key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") @@ -477,7 +443,7 @@ def test_encryption_context_missing_on_decrypt(): ##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. @pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) @pytest.mark.parametrize( - "key_prefix, data, encoding", + ("key_prefix", "data", "encoding"), [ ("simple-rt", "test input for simple v3 round trip", "utf-8"), ("empty-string-rt", "", "utf-8"), From cf00af7965a05005e9cd0b1f2cf88cc364e9466c Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 18 Mar 2026 14:47:25 -0700 Subject: [PATCH 13/31] refactor: streaming decryptors accept cipher object and tag_length BufferedDecryptingStream and DelayedAuthDecryptingStream now take a decryptor object and tag_length instead of raw key/nonce. This makes them reusable across algorithm suites (GCM, key-committing GCM, CBC). --- src/s3_encryption/pipelines.py | 17 ++++--- src/s3_encryption/stream.py | 84 ++++++++++++++++++++-------------- test/test_stream.py | 24 +++++++--- 3 files changed, 78 insertions(+), 47 deletions(-) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 703512e5..218faee1 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -10,6 +10,7 @@ import os from attrs import define, field +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM from .exceptions import S3EncryptionClientError @@ -18,7 +19,7 @@ from .materials.encrypted_data_key import EncryptedDataKey from .materials.materials import DecryptionMaterials, EncryptionMaterials from .metadata import ObjectMetadata -from .stream import BufferedDecryptingStream, DelayedAuthDecryptingStream +from .stream import GCM_TAG_LENGTH, BufferedDecryptingStream, DelayedAuthDecryptingStream @define @@ -180,13 +181,15 @@ def decrypt( ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. if enable_delayed_authentication is None: raise S3EncryptionClientError("enable_delayed_authentication must be explicitly set") + + decryptor = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), + modes.GCM(dec_materials.iv), + ).decryptor() + if enable_delayed_authentication: - return DelayedAuthDecryptingStream( - streaming_body, dec_materials.plaintext_data_key, dec_materials.iv - ) - return BufferedDecryptingStream( - streaming_body, dec_materials.plaintext_data_key, dec_materials.iv - ) + return DelayedAuthDecryptingStream(streaming_body, decryptor, tag_length=GCM_TAG_LENGTH) + return BufferedDecryptingStream(streaming_body, decryptor, tag_length=GCM_TAG_LENGTH) def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: """Prepare V2 decryption materials.""" diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 915d5bb2..33547dc1 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -4,35 +4,34 @@ import io -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives.ciphers.aead import AESGCM - from .exceptions import S3EncryptionClientError GCM_TAG_LENGTH = 16 class BufferedDecryptingStream: - """A stream that buffers all ciphertext, verifies the GCM auth tag, then releases plaintext. + """A stream that buffers all ciphertext, decrypts, then releases plaintext. - This matches the Java S3EC's BufferedCipherSubscriber behavior: no plaintext - is released until the entire ciphertext has been read and authenticated. + For authenticated ciphers (GCM), no plaintext is released until the entire + ciphertext has been read and the auth tag verified. For unauthenticated + ciphers (CBC), all ciphertext is still buffered before decryption. Implements the same read interface as botocore's StreamingBody so it can be used as a drop-in replacement for parsed["Body"]. """ - def __init__(self, streaming_body, key, nonce): + def __init__(self, streaming_body, decryptor, tag_length=0): """Initialize the buffered decrypting stream. Args: streaming_body: The original StreamingBody containing ciphertext. - key: The plaintext data key (bytes). - nonce: The IV/nonce for AES-GCM decryption (bytes). + decryptor: A cipher decryptor object supporting update()/finalize() + (or finalize_with_tag() when tag_length > 0). + tag_length: Length of the auth tag appended to ciphertext (0 for CBC). """ self._body = streaming_body - self._key = key - self._nonce = nonce + self._decryptor = decryptor + self._tag_length = tag_length self._plaintext = None def _decrypt(self): @@ -40,12 +39,17 @@ def _decrypt(self): if self._plaintext is not None: return try: - ciphertext = self._body.read() - aesgcm = AESGCM(self._key) - decrypted = aesgcm.decrypt(nonce=self._nonce, data=ciphertext, associated_data=None) + data = self._body.read() + if self._tag_length > 0: + ciphertext, tag = data[: -self._tag_length], data[-self._tag_length :] + plaintext = self._decryptor.update(ciphertext) + self._decryptor.finalize_with_tag( + tag + ) + else: + plaintext = self._decryptor.update(data) + self._decryptor.finalize() except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt object: {e}") from e - self._plaintext = io.BytesIO(decrypted) + self._plaintext = io.BytesIO(plaintext) def read(self, amt=None): """Read decrypted data. @@ -87,31 +91,35 @@ def close(self): ##= type=implementation ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. class DelayedAuthDecryptingStream: - """A stream that releases plaintext before GCM tag verification. + """A stream that releases plaintext before full verification. - Matches the Java S3EC's CipherSubscriber: plaintext is released incrementally - via cipher.update(), and the GCM tag is only verified when the stream is fully - consumed. Data read before finalization is unauthenticated. + Plaintext is released incrementally via cipher.update(). For authenticated + ciphers (GCM), the auth tag is only verified when the stream is fully + consumed. For unauthenticated ciphers (CBC), this behaves identically + to streaming decryption with no tag holdback. """ - def __init__(self, streaming_body, key, nonce): + def __init__(self, streaming_body, decryptor, tag_length=0): """Initialize the delayed-auth decrypting stream. Args: streaming_body: The original StreamingBody containing ciphertext. - key: The plaintext data key (bytes). - nonce: The IV/nonce for AES-GCM decryption (bytes). + decryptor: A cipher decryptor object supporting update()/finalize() + (or finalize_with_tag() when tag_length > 0). + tag_length: Length of the auth tag appended to ciphertext (0 for CBC). """ self._body = streaming_body - self._decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() + self._decryptor = decryptor + self._tag_length = tag_length self._tag_buffer = b"" self._finalized = False def read(self, amt=None): """Read and decrypt data, releasing plaintext before authentication. - The last 16 bytes of ciphertext are the GCM tag. We hold back a - rolling buffer of 16 bytes so the tag is never passed to update(). + When tag_length > 0, the last tag_length bytes of ciphertext are the + auth tag. We hold back a rolling buffer so the tag is never passed + to update(). """ if self._finalized: return b"" @@ -120,24 +128,30 @@ def read(self, amt=None): if not raw and not self._tag_buffer: return b"" + if self._tag_length == 0: + # No tag to hold back (e.g. CBC) + if not raw: + return self._finalize(tag=b"") + return self._decryptor.update(raw) + data = self._tag_buffer + raw - if len(data) <= GCM_TAG_LENGTH: + if len(data) <= self._tag_length: if raw: self._tag_buffer = data return b"" return self._finalize(tag=data) - self._tag_buffer = data[-GCM_TAG_LENGTH:] - ciphertext = data[:-GCM_TAG_LENGTH] + self._tag_buffer = data[-self._tag_length :] + ciphertext = data[: -self._tag_length] plaintext = self._decryptor.update(ciphertext) # Check if underlying stream is exhausted peek = self._body.read(1) if peek: self._tag_buffer = self._tag_buffer + peek - if len(self._tag_buffer) > GCM_TAG_LENGTH: - extra_ct = self._tag_buffer[:-GCM_TAG_LENGTH] - self._tag_buffer = self._tag_buffer[-GCM_TAG_LENGTH:] + if len(self._tag_buffer) > self._tag_length: + extra_ct = self._tag_buffer[: -self._tag_length] + self._tag_buffer = self._tag_buffer[-self._tag_length :] plaintext += self._decryptor.update(extra_ct) else: plaintext += self._finalize(tag=self._tag_buffer) @@ -145,13 +159,15 @@ def read(self, amt=None): return plaintext def _finalize(self, tag): - """Verify the GCM tag and finalize decryption.""" + """Finalize decryption, verifying the auth tag if present.""" self._finalized = True self._tag_buffer = b"" try: - return self._decryptor.finalize_with_tag(tag) + if tag: + return self._decryptor.finalize_with_tag(tag) + return self._decryptor.finalize() except Exception as e: - raise S3EncryptionClientError(f"GCM tag verification failed: {e}") from e + raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e def close(self): """Close the underlying stream.""" diff --git a/test/test_stream.py b/test/test_stream.py index 79e8f7b5..16579a66 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -6,12 +6,17 @@ from io import BytesIO from unittest.mock import Mock +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from s3_encryption.stream import BufferedDecryptingStream, DelayedAuthDecryptingStream +from s3_encryption.stream import ( + GCM_TAG_LENGTH, + BufferedDecryptingStream, + DelayedAuthDecryptingStream, +) -def _encrypt(plaintext: bytes): +def _encrypt_gcm(plaintext: bytes): """Encrypt plaintext with AES-GCM, return (ciphertext_with_tag, key, nonce).""" key = os.urandom(32) nonce = os.urandom(12) @@ -19,6 +24,11 @@ def _encrypt(plaintext: bytes): return ciphertext_with_tag, key, nonce +def _make_gcm_decryptor(key, nonce): + """Create a GCM decryptor object.""" + return Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() + + def _make_streaming_body(data: bytes): """Create a mock StreamingBody wrapping data.""" body = Mock() @@ -37,10 +47,11 @@ class TestDelayedAuthReleasesBeforeVerification: ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. def test_delayed_auth_releases_plaintext_before_tag_verification(self): plaintext = os.urandom(4096) - ciphertext_with_tag, key, nonce = _encrypt(plaintext) + ciphertext_with_tag, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ciphertext_with_tag) - stream = DelayedAuthDecryptingStream(body, key, nonce) + decryptor = _make_gcm_decryptor(key, nonce) + stream = DelayedAuthDecryptingStream(body, decryptor, tag_length=GCM_TAG_LENGTH) # read(256) decrypts a partial chunk via cipher.update(), releasing # plaintext without consuming the full ciphertext stream. The GCM tag # at the end of the stream has not been reached yet. @@ -66,10 +77,11 @@ class TestBufferedWithholdsUntilVerification: ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. def test_buffered_verifies_tag_before_releasing_any_plaintext(self): plaintext = os.urandom(4096) - ciphertext_with_tag, key, nonce = _encrypt(plaintext) + ciphertext_with_tag, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ciphertext_with_tag) - stream = BufferedDecryptingStream(body, key, nonce) + decryptor = _make_gcm_decryptor(key, nonce) + stream = BufferedDecryptingStream(body, decryptor, tag_length=GCM_TAG_LENGTH) # read(1) triggers _decrypt(), which calls self._body.read() with no amt, # consuming the entire ciphertext and verifying the GCM tag before # returning even 1 byte of plaintext. From 8c3fbd6d49033c9e431a9e55988f0117d3070c41 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Thu, 19 Mar 2026 11:50:58 -0700 Subject: [PATCH 14/31] refactor: use AlgorithmSuite properties for tag length and block size - Add cipher_tag_length_bytes and cipher_block_size_bytes properties to AlgorithmSuite - Replace hardcoded GCM_TAG_LENGTH and PKCS7(128) with algorithm suite properties - Remove dead code: _decrypt_cbc_content() - Make _make_decrypting_stream and _decrypt_kc_gcm_content static methods - Remove GCM_TAG_LENGTH constant from stream.py - Make tag_length required (no default) on DelayedAuthDecryptingStream --- src/s3_encryption/materials/materials.py | 20 ++++++ src/s3_encryption/pipelines.py | 89 ++++++++++-------------- src/s3_encryption/stream.py | 8 +-- test/test_stream.py | 14 +++- 4 files changed, 68 insertions(+), 63 deletions(-) diff --git a/src/s3_encryption/materials/materials.py b/src/s3_encryption/materials/materials.py index f2e8fd4f..80f682f0 100644 --- a/src/s3_encryption/materials/materials.py +++ b/src/s3_encryption/materials/materials.py @@ -172,6 +172,26 @@ def kc_gcm_iv(self) -> bytes: ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. return b"\x01" * self.cipher_iv_length_bytes + @property + def cipher_block_size_bits(self) -> int: + """Block size of the cipher in bits.""" + return self._cipher_block_size_bits + + @property + def cipher_block_size_bytes(self) -> int: + """Block size of the cipher in bytes.""" + return self._cipher_block_size_bits // 8 + + @property + def cipher_tag_length_bits(self) -> int: + """Authentication tag length of the cipher in bits.""" + return self._cipher_tag_length_bits + + @property + def cipher_tag_length_bytes(self) -> int: + """Authentication tag length of the cipher in bytes.""" + return self._cipher_tag_length_bits // 8 + class CommitmentPolicy(Enum): """Commitment policies controlling key-commitment behavior.""" diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 014b781c..587cd5c1 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -15,7 +15,7 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 -from .exceptions import S3EncryptionClientError, S3EncryptionClientSecurityError +from .exceptions import S3EncryptionClientError from .instruction_file import fetch_instruction_file from .key_derivation import derive_keys, verify_commitment from .materials.crypto_materials_manager import AbstractCryptoMaterialsManager @@ -27,7 +27,7 @@ EncryptionMaterials, ) from .metadata import ObjectMetadata -from .stream import GCM_TAG_LENGTH, BufferedDecryptingStream, DelayedAuthDecryptingStream +from .stream import BufferedDecryptingStream, DelayedAuthDecryptingStream @define @@ -256,8 +256,6 @@ def decrypt( if self.s3_client is None: raise S3EncryptionClientError("s3_client required to fetch instruction file") - # TODO: we should validate that these parameters must be None - # when not in instruction file mode. if bucket is None or key is None: raise S3EncryptionClientError("Bucket and key required to fetch instruction file") if instruction_suffix is None: @@ -395,15 +393,32 @@ def decrypt( # Build cipher decryptor and return streaming wrapper based on algorithm suite match dec_materials.algorithm_suite: case AlgorithmSuite.ALG_AES_256_CBC_IV16_NO_KDF: - decryptor = Cipher( + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and + ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, + ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or + ##% PKCS7Padding compatible padding for a 16-byte block cipher + ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If the cipher object cannot be created as described above, + ##% Decryption MUST fail. + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% The error SHOULD detail why the cipher could not be initialized + ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). + cipher = Cipher( algorithms.AES(dec_materials.plaintext_data_key), modes.CBC(dec_materials.iv), - ).decryptor() - unpadder = PKCS7(128).unpadder() + ) + decryptor = cipher.decryptor() + # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) + unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() return self._make_decrypting_stream( streaming_body, decryptor, - tag_length=0, + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, enable_delayed_authentication=enable_delayed_authentication, unpadder=unpadder, ) @@ -412,14 +427,14 @@ def decrypt( ##= type=implementation ##% The client MUST NOT provide any AAD when encrypting with ##% ALG_AES_256_GCM_IV12_TAG16_NO_KDF. - decryptor = Cipher( - algorithms.AES(dec_materials.plaintext_data_key), - modes.GCM(dec_materials.iv), - ).decryptor() + cipher = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), modes.GCM(dec_materials.iv) + ) + decryptor = cipher.decryptor() return self._make_decrypting_stream( streaming_body, decryptor, - tag_length=GCM_TAG_LENGTH, + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, enable_delayed_authentication=enable_delayed_authentication, ) case AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY: @@ -432,8 +447,9 @@ def decrypt( case _: raise S3EncryptionClientError("Unknown algorithm suite!") + @staticmethod def _make_decrypting_stream( - self, streaming_body, decryptor, tag_length, enable_delayed_authentication, unpadder=None + streaming_body, decryptor, tag_length, enable_delayed_authentication, unpadder=None ): """Return a BufferedDecryptingStream or DelayedAuthDecryptingStream.""" if enable_delayed_authentication: @@ -460,16 +476,16 @@ def _decrypt_kc_gcm_streaming( ) verify_commitment(stored_commitment, derived_commitment) - decryptor = Cipher( + cipher = Cipher( algorithms.AES(derived_encryption_key), modes.GCM(dec_materials.algorithm_suite.kc_gcm_iv), - ).decryptor() + ) + decryptor = cipher.decryptor() decryptor.authenticate_additional_data(dec_materials.algorithm_suite.suite_id_bytes) - return self._make_decrypting_stream( streaming_body, decryptor, - tag_length=GCM_TAG_LENGTH, + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, enable_delayed_authentication=enable_delayed_authentication, ) @@ -515,40 +531,6 @@ def _decrypt_v1_v2( return self.cmm.decrypt_materials(dec_materials) - def _decrypt_cbc_content(self, dec_materials, encrypted_data): - """Decrypt content encrypted with ALG_AES_256_CBC_IV16_NO_KDF.""" - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and - ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, - ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or - ##% PKCS7Padding compatible padding for a 16-byte block cipher - ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If the cipher object cannot be created as described above, - ##% Decryption MUST fail. - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% The error SHOULD detail why the cipher could not be initialized - ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). - try: - cipher = Cipher( - algorithms.AES(dec_materials.plaintext_data_key), - modes.CBC(dec_materials.iv), - ) - decryptor = cipher.decryptor() - padded_plaintext = decryptor.update(encrypted_data) + decryptor.finalize() - - # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) - unpadder = PKCS7(128).unpadder() - return unpadder.update(padded_plaintext) + unpadder.finalize() - except Exception as e: - raise S3EncryptionClientSecurityError( - f"Failed to decrypt CBC content: {e}. " - "Ensure the underlying crypto provider supports AES/CBC/PKCS7Padding." - ) from e - ##= specification/s3-encryption/data-format/content-metadata.md#v3-only ##% The V3 format uses compression here such that each wrapping algorithm is represented by a two digit string. ##= specification/s3-encryption/data-format/content-metadata.md#v3-only @@ -603,7 +585,8 @@ def _decrypt_v3(self, metadata, encryption_context) -> DecryptionMaterials: return self.cmm.decrypt_materials(dec_materials) - def _decrypt_kc_gcm_content(self, dec_materials, encrypted_data, metadata): + @staticmethod + def _decrypt_kc_gcm_content(dec_materials, encrypted_data, metadata): """Decrypt content encrypted with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. Performs HKDF key derivation, key commitment verification, and AES-GCM decryption. diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 97dc3465..d1ef3b3b 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -6,8 +6,6 @@ from .exceptions import S3EncryptionClientError -GCM_TAG_LENGTH = 16 - def _unpad(plaintext, unpadder): """Apply unpadder if provided, otherwise return plaintext as-is.""" @@ -19,10 +17,6 @@ def _unpad(plaintext, unpadder): class BufferedDecryptingStream: """A stream that buffers all ciphertext, decrypts, then releases plaintext. - For authenticated ciphers (GCM), no plaintext is released until the entire - ciphertext has been read and the auth tag verified. For unauthenticated - ciphers (CBC), all ciphertext is still buffered before decryption. - Implements the same read interface as botocore's StreamingBody so it can be used as a drop-in replacement for parsed["Body"]. """ @@ -109,7 +103,7 @@ class DelayedAuthDecryptingStream: to streaming decryption with no tag holdback. """ - def __init__(self, streaming_body, decryptor, tag_length=0, unpadder=None): + def __init__(self, streaming_body, decryptor, tag_length, unpadder=None): """Initialize the delayed-auth decrypting stream. Args: diff --git a/test/test_stream.py b/test/test_stream.py index 16579a66..c1956e70 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -9,8 +9,8 @@ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from s3_encryption.materials import AlgorithmSuite from s3_encryption.stream import ( - GCM_TAG_LENGTH, BufferedDecryptingStream, DelayedAuthDecryptingStream, ) @@ -51,7 +51,11 @@ def test_delayed_auth_releases_plaintext_before_tag_verification(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = DelayedAuthDecryptingStream(body, decryptor, tag_length=GCM_TAG_LENGTH) + stream = DelayedAuthDecryptingStream( + body, + decryptor, + tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, + ) # read(256) decrypts a partial chunk via cipher.update(), releasing # plaintext without consuming the full ciphertext stream. The GCM tag # at the end of the stream has not been reached yet. @@ -81,7 +85,11 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = BufferedDecryptingStream(body, decryptor, tag_length=GCM_TAG_LENGTH) + stream = BufferedDecryptingStream( + body, + decryptor, + tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, + ) # read(1) triggers _decrypt(), which calls self._body.read() with no amt, # consuming the entire ciphertext and verifying the GCM tag before # returning even 1 byte of plaintext. From a6fb1362736fc1e3b3983c71db3154f0423fca8c Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Fri, 20 Mar 2026 12:07:35 -0700 Subject: [PATCH 15/31] fix: address PR #150 review comments from kessplas - Make instruction_suffix and enable_delayed_authentication positional args - Move duvet annotation to BufferedDecryptingStream return - Hardcode CBC to always stream (no auth tag, matches Java behavior) - Move duvet annotations from _decrypt_kc_gcm_content to _decrypt_kc_gcm_streaming - Remove unused _decrypt_kc_gcm_content method - Fix DelayedAuthDecryptingStream CBC unpadding (peek + incremental unpadder) - Add CBC unit tests for both stream types (roundtrip, chunked, finalize, padding) - Add delayed authentication mode integration test with duvet citation --- src/s3_encryption/__init__.py | 6 +- src/s3_encryption/pipelines.py | 89 ++++++++------------- src/s3_encryption/stream.py | 37 +++++++-- test/integration/test_i_s3_encryption.py | 23 ++++++ test/test_decryption.py | 24 ++++-- test/test_default_algorithm_commitment.py | 4 +- test/test_key_commitment.py | 8 +- test/test_pipelines.py | 18 ++--- test/test_stream.py | 94 +++++++++++++++++++++++ 9 files changed, 217 insertions(+), 86 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index ee7a93ab..01207c4d 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -221,11 +221,11 @@ def on_get_object_after_call(self, parsed, **kwargs): ) decrypted_data = pipeline.decrypt( response, - encryption_context, - bucket=getattr(self._context, _CTX_BUCKET, None), - key=getattr(self._context, _CTX_KEY, None), instruction_suffix=self.config.instruction_file_suffix, enable_delayed_authentication=self.config.enable_delayed_authentication, + encryption_context=encryption_context, + bucket=getattr(self._context, _CTX_BUCKET, None), + key=getattr(self._context, _CTX_KEY, None), ) # Replace body with decrypting stream diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 587cd5c1..7a0a61e3 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -223,20 +223,20 @@ def _determine_algorithm_suite(self, metadata) -> AlgorithmSuite: def decrypt( self, response, + instruction_suffix, + enable_delayed_authentication, encryption_context=None, bucket=None, key=None, - instruction_suffix=None, - enable_delayed_authentication=None, ): """Decrypt the data after it is retrieved from S3. Args: response (dict): The response from S3 containing the encrypted data and metadata + instruction_suffix(str): suffix for instruction file encryption_context (dict, optional): Additional context for decryption bucket (str, optional): S3 bucket name (required for instruction file) key (str, optional): S3 object key (required for instruction file) - instruction_suffix(str, optional): suffix for instruction file; defaults to ".instruction". enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. Returns: @@ -384,9 +384,6 @@ def decrypt( ##= type=implementation ##% When the commitment policy is REQUIRE_ENCRYPT_ALLOW_DECRYPT, the S3EC MUST allow decryption using algorithm suites which do not support key commitment. - ##= specification/s3-encryption/client.md#enable-delayed-authentication - ##= type=implementation - ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. if enable_delayed_authentication is None: raise S3EncryptionClientError("enable_delayed_authentication must be explicitly set") @@ -419,7 +416,8 @@ def decrypt( streaming_body, decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, - enable_delayed_authentication=enable_delayed_authentication, + # AES-CBC does not have an Auth tag, and thus should always be streamed. + enable_delayed_authentication=True, unpadder=unpadder, ) case AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: @@ -456,6 +454,9 @@ def _make_decrypting_stream( return DelayedAuthDecryptingStream( streaming_body, decryptor, tag_length=tag_length, unpadder=unpadder ) + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. return BufferedDecryptingStream( streaming_body, decryptor, tag_length=tag_length, unpadder=unpadder ) @@ -468,14 +469,40 @@ def _decrypt_kc_gcm_streaming( Performs HKDF key derivation, key commitment verification, then returns a streaming decryptor. """ + ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-hkdf-sha512-commit-key + ##= type=implementation + ##% The client MUST use HKDF to derive the key commitment value and the derived encrypting key as described in [Key Derivation](key-derivation.md). message_id = base64.b64decode(metadata.message_id_v3) stored_commitment = base64.b64decode(metadata.key_commitment_v3) + ##= specification/s3-encryption/decryption.md#decrypting-with-commitment + ##= type=implementation + ##% When using an algorithm suite which supports key commitment, the client MUST verify + ##% that the [derived key commitment](./key-derivation.md#hkdf-operation) contains the + ##% same bytes as the stored key commitment retrieved from the stored object's metadata. + ##= specification/s3-encryption/decryption.md#decrypting-with-commitment + ##= type=implementation + ##% When using an algorithm suite which supports key commitment, the client MUST verify the key commitment values match before deriving + ##% the [derived encryption key](./key-derivation.md#hkdf-operation). derived_encryption_key, derived_commitment = derive_keys( dec_materials.plaintext_data_key, message_id, dec_materials.algorithm_suite ) verify_commitment(stored_commitment, derived_commitment) + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% When encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The IV's total length MUST match the IV length defined by the algorithm suite. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The client MUST initialize the cipher, or call an AES-GCM encryption API, with the derived encryption key, an IV containing only bytes with the value 0x01, + ##% and the tag length defined in the Algorithm Suite when encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. + ##= specification/s3-encryption/key-derivation.md#hkdf-operation + ##= type=implementation + ##% The client MUST set the AAD to the Algorithm Suite ID represented as bytes. cipher = Cipher( algorithms.AES(derived_encryption_key), modes.GCM(dec_materials.algorithm_suite.kc_gcm_iv), @@ -584,51 +611,3 @@ def _decrypt_v3(self, metadata, encryption_context) -> DecryptionMaterials: ) return self.cmm.decrypt_materials(dec_materials) - - @staticmethod - def _decrypt_kc_gcm_content(dec_materials, encrypted_data, metadata): - """Decrypt content encrypted with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. - - Performs HKDF key derivation, key commitment verification, and AES-GCM decryption. - """ - message_id = base64.b64decode(metadata.message_id_v3) - stored_commitment = base64.b64decode(metadata.key_commitment_v3) - - ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-hkdf-sha512-commit-key - ##= type=implementation - ##% The client MUST use HKDF to derive the key commitment value and the derived encrypting key as described in [Key Derivation](key-derivation.md). - derived_encryption_key, derived_commitment = derive_keys( - dec_materials.plaintext_data_key, message_id, dec_materials.algorithm_suite - ) - - ##= specification/s3-encryption/decryption.md#decrypting-with-commitment - ##= type=implementation - ##% When using an algorithm suite which supports key commitment, the client MUST verify - ##% that the [derived key commitment](./key-derivation.md#hkdf-operation) contains the - ##% same bytes as the stored key commitment retrieved from the stored object's metadata. - ##= specification/s3-encryption/decryption.md#decrypting-with-commitment - ##= type=implementation - ##% When using an algorithm suite which supports key commitment, the client MUST verify the key commitment values match before deriving - ##% the [derived encryption key](./key-derivation.md#hkdf-operation). - verify_commitment(stored_commitment, derived_commitment) - - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% When encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, - ##% the IV used in the AES-GCM content encryption/decryption MUST consist entirely of bytes with the value 0x01. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The IV's total length MUST match the IV length defined by the algorithm suite. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The client MUST initialize the cipher, or call an AES-GCM encryption API, with the derived encryption key, an IV containing only bytes with the value 0x01, - ##% and the tag length defined in the Algorithm Suite when encrypting or decrypting with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. - ##= specification/s3-encryption/key-derivation.md#hkdf-operation - ##= type=implementation - ##% The client MUST set the AAD to the Algorithm Suite ID represented as bytes. - aesgcm = AESGCM(derived_encryption_key) - return aesgcm.decrypt( - nonce=dec_materials.algorithm_suite.kc_gcm_iv, - data=encrypted_data, - associated_data=dec_materials.algorithm_suite.suite_id_bytes, - ) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index d1ef3b3b..45ca58a4 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -6,6 +6,19 @@ from .exceptions import S3EncryptionClientError +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% The S3EC SHOULD accept a configurable buffer size which refers to the maximum ciphertext length in bytes to store in memory when Delayed Authentication mode is disabled. +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% If Delayed Authentication mode is enabled, and the buffer size has been set to a value other than its default, the S3EC MUST throw an exception. +##= specification/s3-encryption/client.md#set-buffer-size +##= type=exception +##= reason=Optional Feature that is a two-way door to implement later +##% If Delayed Authentication mode is disabled, and no buffer size is provided, the S3EC MUST set the buffer size to a reasonable default. + def _unpad(plaintext, unpadder): """Apply unpadder if provided, otherwise return plaintext as-is.""" @@ -21,7 +34,7 @@ class BufferedDecryptingStream: used as a drop-in replacement for parsed["Body"]. """ - def __init__(self, streaming_body, decryptor, tag_length=0, unpadder=None): + def __init__(self, streaming_body, decryptor, tag_length, unpadder=None): """Initialize the buffered decrypting stream. Args: @@ -56,7 +69,7 @@ def _decrypt(self): self._plaintext = io.BytesIO(plaintext) def read(self, amt=None): - """Read decrypted data. + """Reads the entire ciphertext stream and then returns decrypted data. Args: amt: Number of bytes to read. If None, reads all remaining data. @@ -70,7 +83,7 @@ def read(self, amt=None): return self._plaintext.read(amt) def iter_chunks(self, chunk_size=1024): - """Iterate over decrypted data in chunks. + """Reads the entire ciphertext stream and then iterates over decrypted data in chunks. Args: chunk_size: Size of each chunk in bytes. @@ -136,9 +149,19 @@ def read(self, amt=None): if self._tag_length == 0: # No tag to hold back (e.g. CBC) - if not raw: + data = self._tag_buffer + raw + self._tag_buffer = b"" + if not data: return self._finalize(tag=b"") - return self._decryptor.update(raw) + plaintext = self._decryptor.update(data) + if self._unpadder: + plaintext = self._unpadder.update(plaintext) + peek = self._body.read(1) + if peek: + self._tag_buffer = peek + else: + plaintext += self._finalize(tag=b"") + return plaintext data = self._tag_buffer + raw if len(data) <= self._tag_length: @@ -173,7 +196,9 @@ def _finalize(self, tag): plaintext = self._decryptor.finalize_with_tag(tag) else: plaintext = self._decryptor.finalize() - return _unpad(plaintext, self._unpadder) + if self._unpadder: + plaintext = self._unpadder.update(plaintext) + self._unpadder.finalize() + return plaintext except Exception as e: raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 15133c05..36f826bd 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -252,3 +252,26 @@ def test_put_object_uses_configured_algorithm(algorithm_suite, commitment_policy meta_key, expected_value = _EXPECTED_ALGORITHM_METADATA[algorithm_suite] assert meta_key in metadata, f"Expected metadata key '{meta_key}' not found in {metadata}" assert metadata[meta_key] == expected_value + + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=test +##% The S3EC MUST support the option to enable or disable Delayed Authentication mode. +@pytest.mark.parametrize("enable_delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +def test_delayed_authentication_mode(enable_delayed_auth): + """S3EC MUST support enabling and disabling delayed authentication.""" + key = _unique_key("delayed-auth-mode-") + data = b"test delayed authentication mode" + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + enable_delayed_authentication=enable_delayed_auth, + ) + s3ec = S3EncryptionClient(wrapped_client, config) + + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == data diff --git a/test/test_decryption.py b/test/test_decryption.py index a96b9e24..3c4e5ae1 100644 --- a/test/test_decryption.py +++ b/test/test_decryption.py @@ -106,7 +106,9 @@ def test_cbc_object_rejected_when_legacy_disabled(self): ) with pytest.raises(S3EncryptionClientError, match="ALG_AES_256_CBC_IV16_NO_KDF"): - pipeline.decrypt(_response(_v1_cbc_metadata()), enable_delayed_authentication=False) + pipeline.decrypt( + _response(_v1_cbc_metadata()), ".instruction", enable_delayed_authentication=False + ) ##= specification/s3-encryption/decryption.md#cbc-decryption ##= type=test @@ -149,7 +151,7 @@ def test_cbc_decryption_succeeds_when_legacy_enabled(self): ) result = pipeline.decrypt( - _response(metadata, ciphertext), enable_delayed_authentication=False + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False ) assert result.read() == plaintext @@ -194,9 +196,9 @@ def test_cbc_decryption_fails_with_wrong_key(self): keyring_return=dec_mats, ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): + with pytest.raises(S3EncryptionClientError, match="Decryption finalization failed"): pipeline.decrypt( - _response(metadata, ciphertext), enable_delayed_authentication=False + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False ).read() @@ -294,7 +296,9 @@ def test_commitment_verified_before_content_decryption(self): S3EncryptionClientSecurityError, match="Key commitment verification failed" ): pipeline.decrypt( - _response(metadata, b"fake-ciphertext"), enable_delayed_authentication=False + _response(metadata, b"fake-ciphertext"), + ".instruction", + enable_delayed_authentication=False, ) @@ -328,7 +332,9 @@ def test_require_decrypt_rejects_non_committing_suite(self): ) with pytest.raises(S3EncryptionClientError, match="cannot decrypt non-key-committing"): - pipeline.decrypt(_response(_v2_gcm_metadata()), enable_delayed_authentication=False) + pipeline.decrypt( + _response(_v2_gcm_metadata()), ".instruction", enable_delayed_authentication=False + ) def test_allow_decrypt_accepts_non_committing_suite(self): """REQUIRE_ENCRYPT_ALLOW_DECRYPT MUST allow non-committing algorithm suites.""" @@ -359,7 +365,7 @@ def test_allow_decrypt_accepts_non_committing_suite(self): ) result = pipeline.decrypt( - _response(metadata, ciphertext), enable_delayed_authentication=False + _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False ) assert result.read() == plaintext @@ -395,4 +401,6 @@ def test_legacy_cbc_rejected_by_default(self): ) with pytest.raises(S3EncryptionClientError, match="not configured to decrypt"): - pipeline.decrypt(_response(_v1_cbc_metadata()), enable_delayed_authentication=False) + pipeline.decrypt( + _response(_v1_cbc_metadata()), ".instruction", enable_delayed_authentication=False + ) diff --git a/test/test_default_algorithm_commitment.py b/test/test_default_algorithm_commitment.py index 03c9a841..0b55b9aa 100644 --- a/test/test_default_algorithm_commitment.py +++ b/test/test_default_algorithm_commitment.py @@ -91,5 +91,7 @@ def test_default_encryption_decryptable_with_require_decrypt(self): cmm, commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, ) - result = decrypt_pipeline.decrypt(response, enable_delayed_authentication=False) + result = decrypt_pipeline.decrypt( + response, ".instruction", enable_delayed_authentication=False + ) assert result.read() == plaintext diff --git a/test/test_key_commitment.py b/test/test_key_commitment.py index 73ff44c9..79b50b9b 100644 --- a/test/test_key_commitment.py +++ b/test/test_key_commitment.py @@ -111,7 +111,7 @@ def test_forbid_encrypt_allows_non_committing_decrypt(self): commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response, enable_delayed_authentication=False) + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) assert result.read() == plaintext ##= specification/s3-encryption/key-commitment.md#commitment-policy @@ -126,7 +126,7 @@ def test_require_encrypt_allow_decrypt_allows_non_committing_decrypt(self): commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_ALLOW_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response, enable_delayed_authentication=False) + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) assert result.read() == plaintext ##= specification/s3-encryption/key-commitment.md#commitment-policy @@ -142,7 +142,7 @@ def test_require_require_rejects_non_committing_decrypt(self): keyring_return=dec_mats, ) with pytest.raises(S3EncryptionClientError, match="cannot decrypt non-key-committing"): - pipeline.decrypt(response, enable_delayed_authentication=False) + pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) def test_require_require_allows_committing_decrypt(self): """REQUIRE_ENCRYPT_REQUIRE_DECRYPT MUST allow decryption with committing suites.""" @@ -153,5 +153,5 @@ def test_require_require_allows_committing_decrypt(self): commitment_policy=CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, keyring_return=dec_mats, ) - result = pipeline.decrypt(response, enable_delayed_authentication=False) + result = pipeline.decrypt(response, ".instruction", enable_delayed_authentication=False) assert result.read() == plaintext diff --git a/test/test_pipelines.py b/test/test_pipelines.py index 9488634c..e3d34e35 100644 --- a/test/test_pipelines.py +++ b/test/test_pipelines.py @@ -68,10 +68,10 @@ def test_decrypt_v1_from_instruction_file(self): with pytest.raises(Exception, match="Keyring called"): pipeline.decrypt( mock_response, - bucket="test-bucket", - key="test-key", instruction_suffix=".instruction", enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", ) # Verify instruction file was fetched @@ -133,10 +133,10 @@ def test_decrypt_v2_from_instruction_file(self): with pytest.raises(Exception, match="Keyring called"): pipeline.decrypt( mock_response, - bucket="test-bucket", - key="test-key", instruction_suffix=".instruction", enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", ) # Verify instruction file was fetched @@ -213,10 +213,10 @@ def test_decrypt_v3_from_instruction_file(self): ): pipeline.decrypt( mock_response, - bucket="test-bucket", - key="test-key", instruction_suffix=".instruction", enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", ) # Verify instruction file was fetched @@ -268,10 +268,10 @@ def test_decrypt_with_custom_instruction_file_suffix(self): with pytest.raises(Exception, match="Keyring called"): pipeline.decrypt( mock_response, - bucket="test-bucket", - key="test-key", instruction_suffix=".custom-suffix", enable_delayed_authentication=False, + bucket="test-bucket", + key="test-key", ) mock_s3_client.get_object.assert_called_once_with( @@ -309,4 +309,4 @@ def test_decrypt_v3_unsupported_wrap_alg(self): with pytest.raises( S3EncryptionClientError, match="AES/GCM is not a valid key wrapping algorithm" ): - pipeline.decrypt(mock_response, enable_delayed_authentication=False) + pipeline.decrypt(mock_response, ".instruction", enable_delayed_authentication=False) diff --git a/test/test_stream.py b/test/test_stream.py index c1956e70..2fc2f10d 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -98,3 +98,97 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): assert chunk == plaintext[:1] # _plaintext being set confirms full decrypt+verify already happened assert stream._plaintext is not None + + +def _encrypt_cbc(plaintext: bytes): + """Encrypt plaintext with AES-CBC + PKCS7 padding, return (ciphertext, key, iv, unpadder).""" + from cryptography.hazmat.primitives.padding import PKCS7 + + key = os.urandom(32) + iv = os.urandom(16) + padder = PKCS7(128).padder() + padded = padder.update(plaintext) + padder.finalize() + encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor() + ciphertext = encryptor.update(padded) + encryptor.finalize() + unpadder = PKCS7(128).unpadder() + return ciphertext, key, iv, unpadder + + +def _make_cbc_decryptor(key, iv): + return Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor() + + +class TestBufferedCBCDecryption: + + def test_roundtrip(self): + plaintext = b"hello world, this is a CBC test!!" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = BufferedDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + assert stream.read() == plaintext + + def test_no_trailing_padding_bytes(self): + plaintext = b"short" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = BufferedDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + assert stream.read() == plaintext + + +class TestDelayedAuthCBCDecryption: + + def test_roundtrip(self): + plaintext = b"hello world, this is a CBC test!!" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + assert stream.read() == plaintext + + def test_chunked_read(self): + plaintext = b"A" * 256 + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + result = b"" + while chunk := stream.read(64): + result += chunk + assert result == plaintext + + def test_finalize_called(self): + plaintext = b"finalize me" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + stream.read() + assert stream._finalized + + def test_no_trailing_padding_bytes(self): + plaintext = b"short" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + tag_length=0, + unpadder=unpadder, + ) + assert stream.read() == plaintext From 8d4dbdce31d8032fa1b3cb2dcd0ff250116b50c3 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 23 Mar 2026 14:47:38 -0700 Subject: [PATCH 16/31] refactor: split DelayedAuthDecryptingStream into CBC and GCM classes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split DelayedAuthDecryptingStream into DelayedAuthCBCDecryptingStream and DelayedAuthGCMDecryptingStream. CBC and GCM are mutually exclusive paths — CBC uses an unpadder with no auth tag, GCM uses a rolling tag buffer with no padding — so the single-class design carried impossible field combinations and conditional branching in read(). All three stream classes (BufferedDecryptingStream and the two new delayed-auth classes) now extend botocore's StreamingBody with @define(slots=False), inheriting iter_chunks, iter_lines, __iter__, and __next__ for free. Updated _make_decrypting_stream dispatch in pipelines.py and test constructors in test_stream.py. --- src/s3_encryption/pipelines.py | 19 ++- src/s3_encryption/stream.py | 246 ++++++++++++++++++++++----------- test/test_stream.py | 17 +-- 3 files changed, 188 insertions(+), 94 deletions(-) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 7a0a61e3..7be8b507 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -27,7 +27,11 @@ EncryptionMaterials, ) from .metadata import ObjectMetadata -from .stream import BufferedDecryptingStream, DelayedAuthDecryptingStream +from .stream import ( + BufferedDecryptingStream, + DelayedAuthCBCDecryptingStream, + DelayedAuthGCMDecryptingStream, +) @define @@ -449,11 +453,16 @@ def decrypt( def _make_decrypting_stream( streaming_body, decryptor, tag_length, enable_delayed_authentication, unpadder=None ): - """Return a BufferedDecryptingStream or DelayedAuthDecryptingStream.""" + """Return the appropriate decrypting stream. + + When delayed auth is disabled, BufferedDecryptingStream buffers all + ciphertext and verifies before releasing any plaintext. + When delayed auth is enabled, the CBC or GCM specific stream is used. + """ if enable_delayed_authentication: - return DelayedAuthDecryptingStream( - streaming_body, decryptor, tag_length=tag_length, unpadder=unpadder - ) + if tag_length == 0: + return DelayedAuthCBCDecryptingStream(streaming_body, decryptor, unpadder=unpadder) + return DelayedAuthGCMDecryptingStream(streaming_body, decryptor, tag_length=tag_length) ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 45ca58a4..fdab5670 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -4,6 +4,9 @@ import io +from attrs import define, field +from botocore.response import StreamingBody + from .exceptions import S3EncryptionClientError ##= specification/s3-encryption/client.md#set-buffer-size @@ -27,28 +30,25 @@ def _unpad(plaintext, unpadder): return unpadder.update(plaintext) + unpadder.finalize() -class BufferedDecryptingStream: +# slots=False because StreamingBody extends IOBase which already has __weakref__. +@define(slots=False) +class BufferedDecryptingStream(StreamingBody): """A stream that buffers all ciphertext, decrypts, then releases plaintext. - Implements the same read interface as botocore's StreamingBody so it can be - used as a drop-in replacement for parsed["Body"]. + Extends botocore's StreamingBody so it can be used as a drop-in replacement + for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. """ - def __init__(self, streaming_body, decryptor, tag_length, unpadder=None): - """Initialize the buffered decrypting stream. + _body: object = field() + _decryptor: object = field() + _tag_length: int = field() + _unpadder: object = field(default=None) + _plaintext: object = field(init=False, default=None) - Args: - streaming_body: The original StreamingBody containing ciphertext. - decryptor: A cipher decryptor object supporting update()/finalize() - (or finalize_with_tag() when tag_length > 0). - tag_length: Length of the auth tag appended to ciphertext (0 for CBC). - unpadder: Optional PKCS7 unpadder for CBC mode. - """ - self._body = streaming_body - self._decryptor = decryptor - self._tag_length = tag_length - self._unpadder = unpadder - self._plaintext = None + def __attrs_post_init__(self): # noqa: D105 + # Initialize StreamingBody with a placeholder; _raw_stream is replaced + # on first read after decryption. + super().__init__(io.BytesIO(), content_length=None) def _decrypt(self): """Read all ciphertext, decrypt and verify, cache plaintext.""" @@ -67,6 +67,14 @@ def _decrypt(self): except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt object: {e}") from e self._plaintext = io.BytesIO(plaintext) + self._raw_stream = self._plaintext + + # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate + # to self.read(), which calls _decrypt(). No override needed. + + def readable(self): # noqa: D102 + self._decrypt() + return self._raw_stream.readable() def read(self, amt=None): """Reads the entire ciphertext stream and then returns decrypted data. @@ -82,21 +90,17 @@ def read(self, amt=None): return self._plaintext.read() return self._plaintext.read(amt) - def iter_chunks(self, chunk_size=1024): - """Reads the entire ciphertext stream and then iterates over decrypted data in chunks. + def readinto(self, b): # noqa: D102 + self._decrypt() + return self._raw_stream.readinto(b) - Args: - chunk_size: Size of each chunk in bytes. + def tell(self): # noqa: D102 + self._decrypt() + return self._raw_stream.tell() - Yields: - bytes: Chunks of decrypted plaintext. - """ + def __enter__(self): # noqa: D105 self._decrypt() - while True: - chunk = self._plaintext.read(chunk_size) - if not chunk: - break - yield chunk + return self._raw_stream def close(self): """Close the underlying stream.""" @@ -107,101 +111,185 @@ def close(self): ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. -class DelayedAuthDecryptingStream: - """A stream that releases plaintext before full verification. +# slots=False because StreamingBody extends IOBase which already has __weakref__. +@define(slots=False) +class DelayedAuthCBCDecryptingStream(StreamingBody): + """A delayed-auth stream for AES-CBC decryption. - Plaintext is released incrementally via cipher.update(). For authenticated - ciphers (GCM), the auth tag is only verified when the stream is fully - consumed. For unauthenticated ciphers (CBC), this behaves identically - to streaming decryption with no tag holdback. + Extends botocore's StreamingBody so it can be used as a drop-in replacement + for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. + + CBC has no auth tag, so plaintext is released incrementally via + cipher.update(). A 1-byte peek-ahead detects stream exhaustion so the + PKCS7 unpadder can be finalized. """ - def __init__(self, streaming_body, decryptor, tag_length, unpadder=None): - """Initialize the delayed-auth decrypting stream. + _body: object = field() + _decryptor: object = field() + _unpadder: object = field() + _peek_buffer: bytes = field(init=False, default=b"") + _finalized: bool = field(init=False, default=False) - Args: - streaming_body: The original StreamingBody containing ciphertext. - decryptor: A cipher decryptor object supporting update()/finalize() - (or finalize_with_tag() when tag_length > 0). - tag_length: Length of the auth tag appended to ciphertext (0 for CBC). - unpadder: Optional PKCS7 unpadder for CBC mode. - """ - self._body = streaming_body - self._decryptor = decryptor - self._tag_length = tag_length - self._unpadder = unpadder - self._tag_buffer = b"" - self._finalized = False + def __attrs_post_init__(self): # noqa: D105 + # Initialize StreamingBody; _raw_stream is unused since plaintext is + # produced incrementally via read(). + super().__init__(io.BytesIO(), content_length=None) + + # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate + # to self.read(). No override needed. + + def readable(self): # noqa: D102 + return not self._finalized def read(self, amt=None): - """Read and decrypt data, releasing plaintext before authentication. + """Read and decrypt CBC ciphertext, releasing plaintext incrementally.""" + # Stream already fully consumed and finalized; nothing left to return. + if self._finalized: + return b"" - When tag_length > 0, the last tag_length bytes of ciphertext are the - auth tag. We hold back a rolling buffer so the tag is never passed - to update(). - """ + # Read the next chunk of raw ciphertext from the underlying stream. + raw = self._body.read(amt) + + # Prepend any previously held-back peek byte to the new data. + data = self._peek_buffer + raw + self._peek_buffer = b"" + + # No data at all; the stream is empty. + if not data: + return self._finalize() + + # Decrypt incrementally; plaintext is released immediately. + plaintext = self._decryptor.update(data) + plaintext = self._unpadder.update(plaintext) + + # Peek 1 byte ahead to detect stream exhaustion. If the stream + # is exhausted we must finalize now to flush the unpadder. + peek = self._body.read(1) + if peek: + # Stream continues; stash the peeked byte for the next read. + self._peek_buffer = peek + else: + # Stream exhausted; finalize to flush any remaining padding. + plaintext += self._finalize() + + return plaintext + + def _finalize(self): + """Finalize CBC decryption and flush the unpadder.""" + self._finalized = True + try: + plaintext = self._decryptor.finalize() + plaintext = self._unpadder.update(plaintext) + self._unpadder.finalize() + return plaintext + except Exception as e: + raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e + + def __enter__(self): # noqa: D105 + return self + + def close(self): + """Close the underlying stream.""" + if hasattr(self._body, "close"): + self._body.close() + + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=implementation +##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. +# slots=False because StreamingBody extends IOBase which already has __weakref__. +@define(slots=False) +class DelayedAuthGCMDecryptingStream(StreamingBody): + """A delayed-auth stream for AES-GCM decryption. + + Extends botocore's StreamingBody so it can be used as a drop-in replacement + for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. + + Plaintext is released incrementally via cipher.update(). The last + tag_length bytes of ciphertext are the GCM auth tag, held back in a + rolling buffer. The tag is only verified via finalize_with_tag() when + the stream is fully consumed. + """ + + _body: object = field() + _decryptor: object = field() + _tag_length: int = field() + _tag_buffer: bytes = field(init=False, default=b"") + _finalized: bool = field(init=False, default=False) + + def __attrs_post_init__(self): # noqa: D105 + # Initialize StreamingBody; _raw_stream is unused since plaintext is + # produced incrementally via read(). + super().__init__(io.BytesIO(), content_length=None) + + # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate + # to self.read(). No override needed. + + def readable(self): # noqa: D102 + return not self._finalized + + def read(self, amt=None): + """Read and decrypt GCM ciphertext, holding back the trailing auth tag.""" + # Stream already fully consumed and finalized; nothing left to return. if self._finalized: return b"" + # Read the next chunk of raw ciphertext from the underlying stream. raw = self._body.read(amt) + + # No new data and no held-back bytes; the stream is empty. if not raw and not self._tag_buffer: return b"" - if self._tag_length == 0: - # No tag to hold back (e.g. CBC) - data = self._tag_buffer + raw - self._tag_buffer = b"" - if not data: - return self._finalize(tag=b"") - plaintext = self._decryptor.update(data) - if self._unpadder: - plaintext = self._unpadder.update(plaintext) - peek = self._body.read(1) - if peek: - self._tag_buffer = peek - else: - plaintext += self._finalize(tag=b"") - return plaintext - + # Combine any previously held-back bytes with the new data. data = self._tag_buffer + raw + + # Not enough data to separate ciphertext from tag yet. if len(data) <= self._tag_length: if raw: + # More data may arrive; buffer everything and wait. self._tag_buffer = data return b"" + # No more data coming; everything buffered is the tag. return self._finalize(tag=data) + # Split: the last tag_length bytes are the candidate tag; + # everything before is ciphertext safe to decrypt now. self._tag_buffer = data[-self._tag_length :] ciphertext = data[: -self._tag_length] plaintext = self._decryptor.update(ciphertext) - # Check if underlying stream is exhausted + # Peek 1 byte ahead to detect whether the underlying stream is + # exhausted. This determines if the current tag_buffer is truly + # the final GCM tag or just more ciphertext. peek = self._body.read(1) if peek: + # Stream continues; the peeked byte may shift what we thought + # was the tag back into ciphertext territory. self._tag_buffer = self._tag_buffer + peek if len(self._tag_buffer) > self._tag_length: + # Extra bytes beyond tag_length are ciphertext; decrypt them. extra_ct = self._tag_buffer[: -self._tag_length] self._tag_buffer = self._tag_buffer[-self._tag_length :] plaintext += self._decryptor.update(extra_ct) else: + # Stream exhausted; tag_buffer holds the final GCM auth tag. plaintext += self._finalize(tag=self._tag_buffer) return plaintext def _finalize(self, tag): - """Finalize decryption, verifying the auth tag if present.""" + """Finalize GCM decryption, verifying the auth tag.""" self._finalized = True self._tag_buffer = b"" try: - if tag: - plaintext = self._decryptor.finalize_with_tag(tag) - else: - plaintext = self._decryptor.finalize() - if self._unpadder: - plaintext = self._unpadder.update(plaintext) + self._unpadder.finalize() + plaintext = self._decryptor.finalize_with_tag(tag) return plaintext except Exception as e: raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e + def __enter__(self): # noqa: D105 + return self + def close(self): """Close the underlying stream.""" if hasattr(self._body, "close"): diff --git a/test/test_stream.py b/test/test_stream.py index 2fc2f10d..b824c80c 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -12,7 +12,8 @@ from s3_encryption.materials import AlgorithmSuite from s3_encryption.stream import ( BufferedDecryptingStream, - DelayedAuthDecryptingStream, + DelayedAuthCBCDecryptingStream, + DelayedAuthGCMDecryptingStream, ) @@ -51,7 +52,7 @@ def test_delayed_auth_releases_plaintext_before_tag_verification(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = DelayedAuthDecryptingStream( + stream = DelayedAuthGCMDecryptingStream( body, decryptor, tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, @@ -148,10 +149,9 @@ class TestDelayedAuthCBCDecryption: def test_roundtrip(self): plaintext = b"hello world, this is a CBC test!!" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, unpadder=unpadder, ) assert stream.read() == plaintext @@ -159,10 +159,9 @@ def test_roundtrip(self): def test_chunked_read(self): plaintext = b"A" * 256 ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, unpadder=unpadder, ) result = b"" @@ -173,10 +172,9 @@ def test_chunked_read(self): def test_finalize_called(self): plaintext = b"finalize me" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, unpadder=unpadder, ) stream.read() @@ -185,10 +183,9 @@ def test_finalize_called(self): def test_no_trailing_padding_bytes(self): plaintext = b"short" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, unpadder=unpadder, ) assert stream.read() == plaintext From 4991e1bee29ebe8077b5faadeabae62178a7c72b Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:26:43 -0700 Subject: [PATCH 17/31] refactor: rename BufferedDecryptingStream to BufferedDecryptingGCMStream and simplify dispatch --- src/s3_encryption/pipelines.py | 25 ++++++++----------------- src/s3_encryption/stream.py | 15 +++------------ 2 files changed, 11 insertions(+), 29 deletions(-) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 7be8b507..97a04a98 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -28,7 +28,7 @@ ) from .metadata import ObjectMetadata from .stream import ( - BufferedDecryptingStream, + BufferedDecryptingGCMStream, DelayedAuthCBCDecryptingStream, DelayedAuthGCMDecryptingStream, ) @@ -416,14 +416,7 @@ def decrypt( decryptor = cipher.decryptor() # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() - return self._make_decrypting_stream( - streaming_body, - decryptor, - tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, - # AES-CBC does not have an Auth tag, and thus should always be streamed. - enable_delayed_authentication=True, - unpadder=unpadder, - ) + return DelayedAuthCBCDecryptingStream(streaming_body, decryptor, unpadder=unpadder) case AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf ##= type=implementation @@ -433,7 +426,7 @@ def decrypt( algorithms.AES(dec_materials.plaintext_data_key), modes.GCM(dec_materials.iv) ) decryptor = cipher.decryptor() - return self._make_decrypting_stream( + return self._make_decrypting_gcm_stream( streaming_body, decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, @@ -450,8 +443,8 @@ def decrypt( raise S3EncryptionClientError("Unknown algorithm suite!") @staticmethod - def _make_decrypting_stream( - streaming_body, decryptor, tag_length, enable_delayed_authentication, unpadder=None + def _make_decrypting_gcm_stream( + streaming_body, decryptor, tag_length, enable_delayed_authentication ): """Return the appropriate decrypting stream. @@ -460,14 +453,12 @@ def _make_decrypting_stream( When delayed auth is enabled, the CBC or GCM specific stream is used. """ if enable_delayed_authentication: - if tag_length == 0: - return DelayedAuthCBCDecryptingStream(streaming_body, decryptor, unpadder=unpadder) return DelayedAuthGCMDecryptingStream(streaming_body, decryptor, tag_length=tag_length) ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. - return BufferedDecryptingStream( - streaming_body, decryptor, tag_length=tag_length, unpadder=unpadder + return BufferedDecryptingGCMStream( + streaming_body, decryptor, tag_length=tag_length ) def _decrypt_kc_gcm_streaming( @@ -518,7 +509,7 @@ def _decrypt_kc_gcm_streaming( ) decryptor = cipher.decryptor() decryptor.authenticate_additional_data(dec_materials.algorithm_suite.suite_id_bytes) - return self._make_decrypting_stream( + return self._make_decrypting_gcm_stream( streaming_body, decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index fdab5670..47ed62a8 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -23,16 +23,9 @@ ##% If Delayed Authentication mode is disabled, and no buffer size is provided, the S3EC MUST set the buffer size to a reasonable default. -def _unpad(plaintext, unpadder): - """Apply unpadder if provided, otherwise return plaintext as-is.""" - if unpadder is None: - return plaintext - return unpadder.update(plaintext) + unpadder.finalize() - - # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class BufferedDecryptingStream(StreamingBody): +class BufferedDecryptingGCMStream(StreamingBody): """A stream that buffers all ciphertext, decrypts, then releases plaintext. Extends botocore's StreamingBody so it can be used as a drop-in replacement @@ -42,7 +35,6 @@ class BufferedDecryptingStream(StreamingBody): _body: object = field() _decryptor: object = field() _tag_length: int = field() - _unpadder: object = field(default=None) _plaintext: object = field(init=False, default=None) def __attrs_post_init__(self): # noqa: D105 @@ -63,7 +55,6 @@ def _decrypt(self): ) else: plaintext = self._decryptor.update(data) + self._decryptor.finalize() - plaintext = _unpad(plaintext, self._unpadder) except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt object: {e}") from e self._plaintext = io.BytesIO(plaintext) @@ -182,7 +173,7 @@ def _finalize(self): plaintext = self._unpadder.update(plaintext) + self._unpadder.finalize() return plaintext except Exception as e: - raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e + raise S3EncryptionClientError(f"Failed to decrypt CBC content: {e}") from e def __enter__(self): # noqa: D105 return self @@ -285,7 +276,7 @@ def _finalize(self, tag): plaintext = self._decryptor.finalize_with_tag(tag) return plaintext except Exception as e: - raise S3EncryptionClientError(f"Decryption finalization failed: {e}") from e + raise S3EncryptionClientError(f"Failed to decrypt GCM content: {e}") from e def __enter__(self): # noqa: D105 return self From 07dbee4f77ff9a191dda5ac8f190508f697013ef Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:28:39 -0700 Subject: [PATCH 18/31] test: add unit and integration tests for CBC and GCM decrypting streams --- .../test_i_s3_encryption_streaming.py | 195 +++++++++ test/test_decryption.py | 2 +- test/test_stream.py | 374 ++++++++++++++++-- 3 files changed, 544 insertions(+), 27 deletions(-) create mode 100644 test/integration/test_i_s3_encryption_streaming.py diff --git a/test/integration/test_i_s3_encryption_streaming.py b/test/integration/test_i_s3_encryption_streaming.py new file mode 100644 index 00000000..0c26f90d --- /dev/null +++ b/test/integration/test_i_s3_encryption_streaming.py @@ -0,0 +1,195 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Integration tests for streaming decryption modes (buffered vs delayed-auth). + +These tests verify that BufferedDecryptingGCMStream and DelayedAuthGCMDecryptingStream +produce correct plaintext for real S3 round-trips across algorithm suites. +""" + +import os +from datetime import datetime + +import boto3 +import pytest + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.materials.kms_keyring import KmsKeyring +from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy +from s3_encryption.stream import ( + BufferedDecryptingGCMStream, + DelayedAuthGCMDecryptingStream, +) + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + +GCM_CONFIGS = [ + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, + CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT, + id="AES_GCM", + ), + pytest.param( + AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, + CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT, + id="KC_GCM", + ), +] + + +def _make_client(algorithm_suite, commitment_policy, delayed_auth): + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig( + keyring, + encryption_algorithm=algorithm_suite, + commitment_policy=commitment_policy, + enable_delayed_authentication=delayed_auth, + ) + return S3EncryptionClient(wrapped_client, config) + + +def _unique_key(prefix): + return prefix + datetime.now().strftime("%Y-%m-%d-%H:%M:%S-%f") + + +# --------------------------------------------------------------------------- +# Buffered mode: verifies tag before releasing plaintext +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_roundtrip(algorithm_suite, commitment_policy): + """Buffered mode decrypts correctly for a simple round-trip.""" + key = _unique_key("buffered-rt-") + data = b"buffered mode round trip test data" + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + body = response["Body"] + assert isinstance(body, BufferedDecryptingGCMStream) + assert body.read() == data + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_partial_reads(algorithm_suite, commitment_policy): + """Buffered mode supports partial read(amt) calls.""" + key = _unique_key("buffered-partial-") + data = os.urandom(1024) + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(100): + result += chunk + assert result == data + + +# --------------------------------------------------------------------------- +# Delayed-auth mode: releases plaintext before tag verification +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_roundtrip(algorithm_suite, commitment_policy): + """Delayed-auth mode decrypts correctly for a simple round-trip.""" + key = _unique_key("delayed-rt-") + data = b"delayed auth round trip test data" + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + body = response["Body"] + assert isinstance(body, DelayedAuthGCMDecryptingStream) + assert body.read() == data + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_chunked_reads(algorithm_suite, commitment_policy): + """Delayed-auth mode supports chunked streaming reads.""" + key = _unique_key("delayed-chunked-") + data = os.urandom(4096) + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(256): + result += chunk + assert result == data + + +# --------------------------------------------------------------------------- +# Both modes produce identical plaintext +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_buffered_and_delayed_produce_same_plaintext(algorithm_suite, commitment_policy): + """Both streaming modes must produce identical plaintext for the same object.""" + key = _unique_key("same-plaintext-") + data = os.urandom(2048) + + # Encrypt once + writer = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + writer.put_object(Bucket=bucket, Key=key, Body=data) + + # Decrypt with buffered + buffered = _make_client(algorithm_suite, commitment_policy, delayed_auth=False) + resp_buf = buffered.get_object(Bucket=bucket, Key=key) + plaintext_buf = resp_buf["Body"].read() + + # Decrypt with delayed-auth + delayed = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + resp_del = delayed.get_object(Bucket=bucket, Key=key) + plaintext_del = resp_del["Body"].read() + + assert plaintext_buf == plaintext_del == data + + +# --------------------------------------------------------------------------- +# Empty body +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("delayed_auth", [False, True], ids=["buffered", "delayed-auth"]) +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_empty_body_roundtrip(algorithm_suite, commitment_policy, delayed_auth): + """Both modes handle empty plaintext correctly.""" + key = _unique_key("empty-stream-") + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=delayed_auth) + s3ec.put_object(Bucket=bucket, Key=key, Body=b"") + response = s3ec.get_object(Bucket=bucket, Key=key) + assert response["Body"].read() == b"" + + +# --------------------------------------------------------------------------- +# Large object streaming +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("algorithm_suite,commitment_policy", GCM_CONFIGS) +def test_delayed_auth_large_object(algorithm_suite, commitment_policy): + """Delayed-auth streams a 1 MB object correctly via chunked reads.""" + key = _unique_key("delayed-large-") + data = os.urandom(1024 * 1024) # 1 MB + + s3ec = _make_client(algorithm_suite, commitment_policy, delayed_auth=True) + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + response = s3ec.get_object(Bucket=bucket, Key=key) + + result = b"" + while chunk := response["Body"].read(65536): + result += chunk + assert result == data diff --git a/test/test_decryption.py b/test/test_decryption.py index 3c4e5ae1..4f8941c0 100644 --- a/test/test_decryption.py +++ b/test/test_decryption.py @@ -196,7 +196,7 @@ def test_cbc_decryption_fails_with_wrong_key(self): keyring_return=dec_mats, ) - with pytest.raises(S3EncryptionClientError, match="Decryption finalization failed"): + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): pipeline.decrypt( _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False ).read() diff --git a/test/test_stream.py b/test/test_stream.py index b824c80c..3775ed34 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -6,12 +6,14 @@ from io import BytesIO from unittest.mock import Mock +import pytest from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from s3_encryption.exceptions import S3EncryptionClientError from s3_encryption.materials import AlgorithmSuite from s3_encryption.stream import ( - BufferedDecryptingStream, + BufferedDecryptingGCMStream, DelayedAuthCBCDecryptingStream, DelayedAuthGCMDecryptingStream, ) @@ -86,7 +88,7 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = BufferedDecryptingStream( + stream = BufferedDecryptingGCMStream( body, decryptor, tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, @@ -119,35 +121,44 @@ def _make_cbc_decryptor(key, iv): return Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor() -class TestBufferedCBCDecryption: +class TestDelayedAuthCBCDecryption: def test_roundtrip(self): plaintext = b"hello world, this is a CBC test!!" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = BufferedDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, - unpadder=unpadder, +unpadder=unpadder, ) assert stream.read() == plaintext - def test_no_trailing_padding_bytes(self): - plaintext = b"short" + def test_chunked_read(self): + plaintext = b"A" * 256 ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = BufferedDecryptingStream( + stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), - tag_length=0, unpadder=unpadder, ) - assert stream.read() == plaintext - + result = b"" + while chunk := stream.read(64): + result += chunk + assert result == plaintext -class TestDelayedAuthCBCDecryption: + def test_finalize_called(self): + plaintext = b"finalize me" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthCBCDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + unpadder=unpadder, + ) + stream.read() + assert stream._finalized - def test_roundtrip(self): - plaintext = b"hello world, this is a CBC test!!" + def test_no_trailing_padding_bytes(self): + plaintext = b"short" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), @@ -156,36 +167,347 @@ def test_roundtrip(self): ) assert stream.read() == plaintext - def test_chunked_read(self): - plaintext = b"A" * 256 + def test_read_after_finalized_returns_empty(self): + plaintext = b"done" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, ) - result = b"" - while chunk := stream.read(64): - result += chunk - assert result == plaintext + stream.read() + assert stream.read() == b"" - def test_finalize_called(self): - plaintext = b"finalize me" + def test_readable_false_after_finalized(self): + plaintext = b"readable" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, ) + assert stream.readable() stream.read() - assert stream._finalized + assert not stream.readable() - def test_no_trailing_padding_bytes(self): - plaintext = b"short" + def test_close_delegates_to_body(self): + plaintext = b"close me" + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + body = _make_streaming_body(ciphertext) + stream = DelayedAuthCBCDecryptingStream( + body, _make_cbc_decryptor(key, iv), unpadder=unpadder + ) + stream.close() + body.close.assert_called_once() + + def test_enter_returns_self(self): + plaintext = b"ctx" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, ) + assert stream.__enter__() is stream + + def test_wrong_key_raises_error(self): + from cryptography.hazmat.primitives.padding import PKCS7 + + plaintext = b"wrong key test!!" + ciphertext, _key, iv, _ = _encrypt_cbc(plaintext) + wrong_key = os.urandom(32) + stream = DelayedAuthCBCDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(wrong_key, iv), + unpadder=PKCS7(128).unpadder(), + ) + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): + stream.read() + + def test_empty_ciphertext(self): + from cryptography.hazmat.primitives.padding import PKCS7 + + key = os.urandom(32) + iv = os.urandom(16) + stream = DelayedAuthCBCDecryptingStream( + _make_streaming_body(b""), + _make_cbc_decryptor(key, iv), + unpadder=PKCS7(128).unpadder(), + ) + # Empty stream finalize will fail because CBC expects at least one block + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): + stream.read() + + +class TestBufferedDecryptingGCMStream: + + def test_full_read(self): + plaintext = os.urandom(1024) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.read() == plaintext + + def test_partial_reads(self): + plaintext = os.urandom(512) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + result = b"" + while chunk := stream.read(100): + result += chunk + assert result == plaintext + + def test_read_triggers_full_decrypt(self): + plaintext = os.urandom(256) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + assert stream._plaintext is None + stream.read(1) + assert stream._plaintext is not None + # Entire ciphertext consumed + assert body._stream.tell() == len(ct) + + def test_tell(self): + plaintext = os.urandom(200) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + stream.read(50) + assert stream.tell() == 50 + + def test_readable(self): + plaintext = b"readable test" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.readable() + + def test_readinto(self): + """Asserts that readinto is implemented by botocore's StreamingBody""" + plaintext = os.urandom(64) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + buf = bytearray(64) + n = stream.readinto(buf) + assert n == 64 + assert bytes(buf) == plaintext + + def test_enter_returns_raw_stream(self): + plaintext = b"enter" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + inner = stream.__enter__() + assert inner.read() == plaintext + + def test_close_delegates(self): + """Asserts that close is implemented by botocore's StreamingBody""" + plaintext = b"close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + stream.close() + body.close.assert_called_once() + + def test_close_without_close_attr(self): + """Asserts that close is implemented by botocore's StreamingBody""" + plaintext = b"no close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = Mock() + del body.close + body.read = BytesIO(ct).read + stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + stream.close() # should not raise + + def test_wrong_key_raises_error(self): + plaintext = b"wrong key" + ct, _key, nonce = _encrypt_gcm(plaintext) + wrong_key = os.urandom(32) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(wrong_key, nonce), tag_length=16 + ) + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): + stream.read() + + def test_tampered_ciphertext_raises_error(self): + plaintext = b"tamper test" + ct, key, nonce = _encrypt_gcm(plaintext) + tampered = bytearray(ct) + tampered[0] ^= 0xFF + stream = BufferedDecryptingGCMStream( + _make_streaming_body(bytes(tampered)), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): + stream.read() + + def test_idempotent_decrypt(self): + plaintext = os.urandom(128) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + first = stream.read(63) + second = stream.read(65) + assert first + second == plaintext + + +class TestDelayedAuthGCMDecryption: + + def test_full_read(self): + plaintext = os.urandom(1024) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.read() == plaintext + + def test_chunked_read(self): + plaintext = os.urandom(512) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + result = b"" + while chunk := stream.read(64): + result += chunk + assert result == plaintext + + def test_read_after_finalized_returns_empty(self): + plaintext = os.urandom(128) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + actual = stream.read() + assert stream._finalized + assert stream.read() == b"" + assert actual == plaintext + + def test_readable_false_after_finalized(self): + plaintext = b"readable" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.readable() + stream.read() + assert not stream.readable() + + def test_close_delegates(self): + plaintext = b"close" + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) + stream = DelayedAuthGCMDecryptingStream( + body, _make_gcm_decryptor(key, nonce), tag_length=16 + ) + stream.close() + body.close.assert_called_once() + + def test_enter_returns_self(self): + plaintext = b"ctx" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.__enter__() is stream + + def test_wrong_key_raises_error(self): + plaintext = b"wrong key" + ct, _key, nonce = _encrypt_gcm(plaintext) + wrong_key = os.urandom(32) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(wrong_key, nonce), tag_length=16 + ) + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): + stream.read() + + def test_tampered_tag_raises_error(self): + plaintext = b"tamper tag" + ct, key, nonce = _encrypt_gcm(plaintext) + tampered = bytearray(ct) + tampered[-1] ^= 0xFF # flip last byte (part of tag) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(bytes(tampered)), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): + stream.read() + + def test_small_data_less_than_tag_length(self): + """Data exactly equal to tag length — only tag, no ciphertext.""" + plaintext = b"" + ct, key, nonce = _encrypt_gcm(plaintext) + # For empty plaintext, ct is just the 16-byte tag + assert len(ct) == 16 + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + assert stream.read() == b"" + + def test_large_data(self): + plaintext = os.urandom(1024 * 1024) # 1 MB + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + result = b"" + while chunk := stream.read(65536): + result += chunk + assert result == plaintext + + +# --------------------------------------------------------------------------- +# Parameterized edge-case plaintext lengths +# --------------------------------------------------------------------------- +# Lengths chosen around AES block size (16) and two-block (32) boundaries, +# plus zero and one byte, to exercise padding, tag-splitting, and empty-data paths. +EDGE_CASE_LENGTHS = [0, 1, 8, 15, 16, 17, 31, 32, 33, 47, 48, 49] + + +class TestEdgeCasePlaintextLengths: + + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) + def test_buffered_gcm(self, length): + plaintext = os.urandom(length) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = BufferedDecryptingGCMStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) assert stream.read() == plaintext + + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) + def test_delayed_auth_gcm(self, length): + plaintext = os.urandom(length) + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + result = b"" + while stream.readable(): + # odd read size to stress tag-splitting + chunk = stream.read(7) + result += chunk + assert result == plaintext + + @pytest.mark.parametrize("length", [l for l in EDGE_CASE_LENGTHS if l > 0]) + def test_delayed_auth_cbc(self, length): + plaintext = os.urandom(length) + ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + stream = DelayedAuthCBCDecryptingStream( + _make_streaming_body(ciphertext), + _make_cbc_decryptor(key, iv), + unpadder=unpadder, + ) + result = b"" + while stream.readable(): + # odd read size to stress tag-splitting/padding + result += stream.read(7) + assert result == plaintext From dad00c20ae84270ee9cb35f112c8863f6890caca Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 24 Mar 2026 09:53:38 -0700 Subject: [PATCH 19/31] fix(stream): enforce minimum read size on DelayedAuthGCMDecryptingStream MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Delayed-Auth Streams: Empty Read Behavior ## Problem DelayedAuthGCMDecryptingStream.read(amt) can return b"" mid-stream before the stream is exhausted. This happens when the read size is small relative to the GCM tag length (16 bytes) — the stream can't distinguish ciphertext from the trailing auth tag until it has accumulated more than tag_length bytes. Example with 20 bytes of ciphertext+tag, read(7): 1. read(7) → 7 bytes buffered, <= 16 → returns b"" 2. read(7) → 14 bytes buffered, <= 16 → returns b"" 3. read(7) → 20 bytes total, splits ciphertext/tag → returns plaintext In Python, read() returning b"" conventionally signals EOF. This breaks common patterns like: while chunk := stream.read(7) DelayedAuthCBCDecryptingStream does not have this issue — CBC cipher.update() always produces output when given input. ## Java Behavior Java's CipherSubscriber (the delayed-auth equivalent) does the same thing. When cipher.update() returns null/empty, it explicitly sends an empty ByteBuffer downstream. This is fine in Java's reactive streams model where empty emissions are normal signaling. In Python's read() API, it's surprising. ## Options Considered 1. Keep as-is, document it — match Java semantics. 2. Loop internally in read() — more Pythonic, but violates io.py Reader.read contract: "If size is specified, at most size items will be read." 3. Require minimum read size (chosen) — raise if amt < tag_length + 1. ## ESDK-Python Comparison ESDK-Python's StreamDecryptor never has this problem because it decrypts at the frame level. Each frame has its own IV and tag, so authentication is per-frame. S3EC operates on a single non-framed GCM ciphertext where the tag is simply appended — the stream must separate tag from ciphertext on the fly. --- src/s3_encryption/stream.py | 6 ++++++ test/test_stream.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 47ed62a8..b77f64ae 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -220,6 +220,12 @@ def readable(self): # noqa: D102 def read(self, amt=None): """Read and decrypt GCM ciphertext, holding back the trailing auth tag.""" + if amt is not None and 0 < amt < self._tag_length + 1: + raise S3EncryptionClientError( + f"read size {amt} is too small; must be at least {self._tag_length + 1} " + f"to distinguish ciphertext from the GCM auth tag" + ) + # Stream already fully consumed and finalized; nothing left to return. if self._finalized: return b"" diff --git a/test/test_stream.py b/test/test_stream.py index 3775ed34..7ee31688 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -463,6 +463,15 @@ def test_large_data(self): result += chunk assert result == plaintext + def test_read_too_small_raises_error(self): + plaintext = b"small read" + ct, key, nonce = _encrypt_gcm(plaintext) + stream = DelayedAuthGCMDecryptingStream( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + ) + with pytest.raises(S3EncryptionClientError, match="read size 7 is too small"): + stream.read(7) + # --------------------------------------------------------------------------- # Parameterized edge-case plaintext lengths @@ -492,8 +501,8 @@ def test_delayed_auth_gcm(self, length): ) result = b"" while stream.readable(): - # odd read size to stress tag-splitting - chunk = stream.read(7) + # minimum valid read size for tag_length=16 + chunk = stream.read(17) result += chunk assert result == plaintext From 4a35e881b24d75453d7596afea03992f87065043 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:21:32 -0700 Subject: [PATCH 20/31] test(stream): strengthen CBC test assertions and include empty plaintext edge case --- test/test_stream.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/test/test_stream.py b/test/test_stream.py index 7ee31688..8b8657d5 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -129,7 +129,7 @@ def test_roundtrip(self): stream = DelayedAuthCBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), -unpadder=unpadder, + unpadder=unpadder, ) assert stream.read() == plaintext @@ -154,8 +154,9 @@ def test_finalize_called(self): _make_cbc_decryptor(key, iv), unpadder=unpadder, ) - stream.read() + actual = stream.read() assert stream._finalized + assert actual == plaintext def test_no_trailing_padding_bytes(self): plaintext = b"short" @@ -187,8 +188,9 @@ def test_readable_false_after_finalized(self): unpadder=unpadder, ) assert stream.readable() - stream.read() + actual = stream.read() assert not stream.readable() + assert actual == plaintext def test_close_delegates_to_body(self): plaintext = b"close me" @@ -506,7 +508,7 @@ def test_delayed_auth_gcm(self, length): result += chunk assert result == plaintext - @pytest.mark.parametrize("length", [l for l in EDGE_CASE_LENGTHS if l > 0]) + @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) def test_delayed_auth_cbc(self, length): plaintext = os.urandom(length) ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) From c8bd8a28a56f9b2c075c70d01d03dc831ffe6bc0 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 25 Mar 2026 12:28:23 -0700 Subject: [PATCH 21/31] fix(pipelines): fix docstring param order and remove misplaced encryption annotation --- src/s3_encryption/pipelines.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 97a04a98..145cd220 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -237,11 +237,11 @@ def decrypt( Args: response (dict): The response from S3 containing the encrypted data and metadata - instruction_suffix(str): suffix for instruction file + instruction_suffix (str): suffix for instruction file + enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. encryption_context (dict, optional): Additional context for decryption bucket (str, optional): S3 bucket name (required for instruction file) key (str, optional): S3 object key (required for instruction file) - enable_delayed_authentication (bool): If True, release plaintext before GCM tag verification. Returns: A decrypting stream (BufferedDecryptingStream or DelayedAuthDecryptingStream). @@ -457,9 +457,7 @@ def _make_decrypting_gcm_stream( ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. - return BufferedDecryptingGCMStream( - streaming_body, decryptor, tag_length=tag_length - ) + return BufferedDecryptingGCMStream(streaming_body, decryptor, tag_length=tag_length) def _decrypt_kc_gcm_streaming( self, dec_materials, metadata, streaming_body, enable_delayed_authentication @@ -469,9 +467,6 @@ def _decrypt_kc_gcm_streaming( Performs HKDF key derivation, key commitment verification, then returns a streaming decryptor. """ - ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-hkdf-sha512-commit-key - ##= type=implementation - ##% The client MUST use HKDF to derive the key commitment value and the derived encrypting key as described in [Key Derivation](key-derivation.md). message_id = base64.b64decode(metadata.message_id_v3) stored_commitment = base64.b64decode(metadata.key_commitment_v3) From 6fc49103ca1c172f5b4114d043f72ae837f157db Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:08:23 -0700 Subject: [PATCH 22/31] docs: add Google-style docstring to S3EncryptionClientConfig --- src/s3_encryption/__init__.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 01207c4d..f8a039fd 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -34,7 +34,27 @@ @define class S3EncryptionClientConfig: - """Configuration object for the S3 Encryption Client.""" + """Configuration for the S3 Encryption Client. + + Attributes: + keyring: Keyring used for encrypting/decrypting data keys. + encryption_algorithm: Algorithm suite for encryption. Defaults to + ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY (V3 key-committing). + commitment_policy: Key commitment policy for encryption and decryption. + Defaults to REQUIRE_ENCRYPT_REQUIRE_DECRYPT. + enable_legacy_unauthenticated_modes: If True, allow decryption of objects + encrypted with legacy CBC algorithm suites. Defaults to False. + cmm: Crypto materials manager. Defaults to a DefaultCryptoMaterialsManager + wrapping the provided keyring. + instruction_file_suffix: Suffix appended to the S3 object key when + fetching instruction files. Defaults to ".instruction". + enable_delayed_authentication: If True, release plaintext from streams + before GCM tag verification. Defaults to False. + + Raises: + S3EncryptionClientError: If the encryption algorithm is legacy, or if + the algorithm suite is incompatible with the commitment policy. + """ keyring: AbstractKeyring encryption_algorithm: AlgorithmSuite = field( From 29bfb1c2e9e9b4c41db75e4b6ec6de7f7b68a83f Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 25 Mar 2026 13:11:11 -0700 Subject: [PATCH 23/31] docs: detail that CBC is always streamed --- src/s3_encryption/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index f8a039fd..4b20dd06 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -49,7 +49,9 @@ class S3EncryptionClientConfig: instruction_file_suffix: Suffix appended to the S3 object key when fetching instruction files. Defaults to ".instruction". enable_delayed_authentication: If True, release plaintext from streams - before GCM tag verification. Defaults to False. + before GCM tag verification. Defaults to False. Has no effect for + CBC encrypted ciphertext, which is always streamed as there is no + authentication tag. Raises: S3EncryptionClientError: If the encryption algorithm is legacy, or if From 0d5808000ea5d1ecaadd21695c181d1b72a18d44 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Wed, 25 Mar 2026 14:06:27 -0700 Subject: [PATCH 24/31] chore: address linting concern --- src/s3_encryption/stream.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index b77f64ae..52069969 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -170,8 +170,7 @@ def _finalize(self): self._finalized = True try: plaintext = self._decryptor.finalize() - plaintext = self._unpadder.update(plaintext) + self._unpadder.finalize() - return plaintext + return self._unpadder.update(plaintext) + self._unpadder.finalize() except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt CBC content: {e}") from e @@ -279,8 +278,7 @@ def _finalize(self, tag): self._finalized = True self._tag_buffer = b"" try: - plaintext = self._decryptor.finalize_with_tag(tag) - return plaintext + return self._decryptor.finalize_with_tag(tag) except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt GCM content: {e}") from e From ad228805dd5b57c2c1830a78498decbe749d09d7 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:03:58 -0700 Subject: [PATCH 25/31] refactor(stream): use ContentLength to eliminate rolling GCM tag buffer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Refactor DelayedAuthGCMDecryptingStream to use content_length instead of rolling tag buffer and peek-ahead - Add ContentLength validation in on_get_object_after_call - Pass content_length through pipeline to all stream constructors - Rename stream classes: BufferedDecryptingGCMStream → GCMBufferedDecryptingStream, DelayedAuthCBCDecryptingStream → CBCDecryptingStream, DelayedAuthGCMDecryptingStream → GCMDelayedAuthDecryptingStream - Track _amount_read for progress in all three streams - Remove minimum read size restriction --- src/s3_encryption/__init__.py | 8 + src/s3_encryption/pipelines.py | 29 ++- src/s3_encryption/stream.py | 128 +++++------ .../test_i_s3_encryption_streaming.py | 8 +- test/test_s3_encryption_client_plugin.py | 15 ++ test/test_stream.py | 217 ++++++++++++------ 6 files changed, 250 insertions(+), 155 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 4b20dd06..75f7fc1f 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -229,9 +229,17 @@ def on_get_object_after_call(self, parsed, **kwargs): # We need to read it, decrypt it, and replace it # Create a response dict that matches what the pipeline expects + content_length = parsed.get("ContentLength") + if content_length is None: + obj_key = getattr(self._context, _CTX_KEY, None) + raise S3EncryptionClientError( + f"S3 response is missing ContentLength and is invalid. Key: {obj_key}" + ) + response = { "Body": parsed.get("Body"), "Metadata": parsed.get("Metadata", {}), + "ContentLength": content_length, } # Create a pipeline and decrypt the data diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 145cd220..59428138 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -28,9 +28,9 @@ ) from .metadata import ObjectMetadata from .stream import ( - BufferedDecryptingGCMStream, - DelayedAuthCBCDecryptingStream, - DelayedAuthGCMDecryptingStream, + CBCDecryptingStream, + GCMBufferedDecryptingStream, + GCMDelayedAuthDecryptingStream, ) @@ -248,6 +248,7 @@ def decrypt( """ # Convert the metadata dictionary to an ObjectMetadata instance streaming_body = response.get("Body") + content_length = response.get("ContentLength") encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) @@ -416,7 +417,9 @@ def decrypt( decryptor = cipher.decryptor() # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() - return DelayedAuthCBCDecryptingStream(streaming_body, decryptor, unpadder=unpadder) + return CBCDecryptingStream( + streaming_body, decryptor, unpadder=unpadder, content_length=content_length + ) case AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf ##= type=implementation @@ -431,6 +434,7 @@ def decrypt( decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, enable_delayed_authentication=enable_delayed_authentication, + content_length=content_length, ) case AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY: return self._decrypt_kc_gcm_streaming( @@ -438,13 +442,14 @@ def decrypt( metadata, streaming_body, enable_delayed_authentication=enable_delayed_authentication, + content_length=content_length, ) case _: raise S3EncryptionClientError("Unknown algorithm suite!") @staticmethod def _make_decrypting_gcm_stream( - streaming_body, decryptor, tag_length, enable_delayed_authentication + streaming_body, decryptor, tag_length, enable_delayed_authentication, content_length ): """Return the appropriate decrypting stream. @@ -453,14 +458,21 @@ def _make_decrypting_gcm_stream( When delayed auth is enabled, the CBC or GCM specific stream is used. """ if enable_delayed_authentication: - return DelayedAuthGCMDecryptingStream(streaming_body, decryptor, tag_length=tag_length) + return GCMDelayedAuthDecryptingStream( + streaming_body, + decryptor, + tag_length=tag_length, + content_length=content_length, + ) ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. - return BufferedDecryptingGCMStream(streaming_body, decryptor, tag_length=tag_length) + return GCMBufferedDecryptingStream( + streaming_body, decryptor, tag_length=tag_length, content_length=content_length + ) def _decrypt_kc_gcm_streaming( - self, dec_materials, metadata, streaming_body, enable_delayed_authentication + self, dec_materials, metadata, streaming_body, enable_delayed_authentication, content_length ): """Decrypt content encrypted with ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY. @@ -509,6 +521,7 @@ def _decrypt_kc_gcm_streaming( decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, enable_delayed_authentication=enable_delayed_authentication, + content_length=content_length, ) def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 52069969..7473ae3c 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -25,7 +25,7 @@ # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class BufferedDecryptingGCMStream(StreamingBody): +class GCMBufferedDecryptingStream(StreamingBody): """A stream that buffers all ciphertext, decrypts, then releases plaintext. Extends botocore's StreamingBody so it can be used as a drop-in replacement @@ -35,12 +35,15 @@ class BufferedDecryptingGCMStream(StreamingBody): _body: object = field() _decryptor: object = field() _tag_length: int = field() + # _content_length intentionally collides with super's _content_length + _content_length: int = field() _plaintext: object = field(init=False, default=None) def __attrs_post_init__(self): # noqa: D105 - # Initialize StreamingBody with a placeholder; _raw_stream is replaced - # on first read after decryption. - super().__init__(io.BytesIO(), content_length=None) + # By passing in content_length, and updating _amount_read in read(), + # we support the super's normal progression. + # However, we do not support the super's _verify_content_length. + super().__init__(io.BytesIO(), content_length=self._content_length) def _decrypt(self): """Read all ciphertext, decrypt and verify, cache plaintext.""" @@ -77,9 +80,11 @@ def read(self, amt=None): bytes: Decrypted plaintext bytes. """ self._decrypt() - if amt is None: - return self._plaintext.read() - return self._plaintext.read(amt) + chunk = self._plaintext.read() if amt is None else self._plaintext.read(amt) + # super._amount_read can be used for progress tracking + # noinspection PyUnresolvedReferences + self._amount_read += len(chunk) + return chunk def readinto(self, b): # noqa: D102 self._decrypt() @@ -104,7 +109,7 @@ def close(self): ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class DelayedAuthCBCDecryptingStream(StreamingBody): +class CBCDecryptingStream(StreamingBody): """A delayed-auth stream for AES-CBC decryption. Extends botocore's StreamingBody so it can be used as a drop-in replacement @@ -118,13 +123,16 @@ class DelayedAuthCBCDecryptingStream(StreamingBody): _body: object = field() _decryptor: object = field() _unpadder: object = field() + # _content_length intentionally collides with super's _content_length + _content_length: int = field() _peek_buffer: bytes = field(init=False, default=b"") _finalized: bool = field(init=False, default=False) def __attrs_post_init__(self): # noqa: D105 - # Initialize StreamingBody; _raw_stream is unused since plaintext is - # produced incrementally via read(). - super().__init__(io.BytesIO(), content_length=None) + # By passing in content_length, and updating _amount_read in read(), + # we support the super's normal progression. + # However, we do not support the super's _verify_content_length. + super().__init__(io.BytesIO(), content_length=self._content_length) # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate # to self.read(). No override needed. @@ -163,6 +171,9 @@ def read(self, amt=None): # Stream exhausted; finalize to flush any remaining padding. plaintext += self._finalize() + # super._amount_read can be used for progress tracking + # noinspection PyUnresolvedReferences + self._amount_read += len(plaintext) return plaintext def _finalize(self): @@ -188,28 +199,32 @@ def close(self): ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class DelayedAuthGCMDecryptingStream(StreamingBody): +class GCMDelayedAuthDecryptingStream(StreamingBody): """A delayed-auth stream for AES-GCM decryption. Extends botocore's StreamingBody so it can be used as a drop-in replacement for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. - Plaintext is released incrementally via cipher.update(). The last - tag_length bytes of ciphertext are the GCM auth tag, held back in a - rolling buffer. The tag is only verified via finalize_with_tag() when - the stream is fully consumed. + Plaintext is released incrementally via cipher.update(). The content_length + from the S3 GetObject response tells us exactly how many bytes are ciphertext + vs. the trailing GCM auth tag. The tag is only verified via finalize_with_tag() when the ciphertext is + fully consumed. """ _body: object = field() _decryptor: object = field() _tag_length: int = field() - _tag_buffer: bytes = field(init=False, default=b"") + # _content_length intentionally collides with super's _content_length + _content_length: int = field() + _ciphertext_remaining: int = field(init=False) _finalized: bool = field(init=False, default=False) def __attrs_post_init__(self): # noqa: D105 - # Initialize StreamingBody; _raw_stream is unused since plaintext is - # produced incrementally via read(). - super().__init__(io.BytesIO(), content_length=None) + # By passing in content_length, and updating _amount_read in read(), + # we support the super's normal progression. + # However, we do not support the super's _verify_content_length. + super().__init__(io.BytesIO(), content_length=self._content_length) + self._ciphertext_remaining = self._content_length - self._tag_length # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate # to self.read(). No override needed. @@ -219,65 +234,42 @@ def readable(self): # noqa: D102 def read(self, amt=None): """Read and decrypt GCM ciphertext, holding back the trailing auth tag.""" - if amt is not None and 0 < amt < self._tag_length + 1: - raise S3EncryptionClientError( - f"read size {amt} is too small; must be at least {self._tag_length + 1} " - f"to distinguish ciphertext from the GCM auth tag" - ) - # Stream already fully consumed and finalized; nothing left to return. if self._finalized: return b"" - # Read the next chunk of raw ciphertext from the underlying stream. - raw = self._body.read(amt) + # No ciphertext left — read the tag and finalize. + if self._ciphertext_remaining <= 0: + return self._finalize() - # No new data and no held-back bytes; the stream is empty. - if not raw and not self._tag_buffer: - return b"" + # Read at most ciphertext_remaining bytes (never into the tag). + to_read = ( + self._ciphertext_remaining if amt is None else min(amt, self._ciphertext_remaining) + ) + raw = self._body.read(to_read) - # Combine any previously held-back bytes with the new data. - data = self._tag_buffer + raw - - # Not enough data to separate ciphertext from tag yet. - if len(data) <= self._tag_length: - if raw: - # More data may arrive; buffer everything and wait. - self._tag_buffer = data - return b"" - # No more data coming; everything buffered is the tag. - return self._finalize(tag=data) - - # Split: the last tag_length bytes are the candidate tag; - # everything before is ciphertext safe to decrypt now. - self._tag_buffer = data[-self._tag_length :] - ciphertext = data[: -self._tag_length] - plaintext = self._decryptor.update(ciphertext) - - # Peek 1 byte ahead to detect whether the underlying stream is - # exhausted. This determines if the current tag_buffer is truly - # the final GCM tag or just more ciphertext. - peek = self._body.read(1) - if peek: - # Stream continues; the peeked byte may shift what we thought - # was the tag back into ciphertext territory. - self._tag_buffer = self._tag_buffer + peek - if len(self._tag_buffer) > self._tag_length: - # Extra bytes beyond tag_length are ciphertext; decrypt them. - extra_ct = self._tag_buffer[: -self._tag_length] - self._tag_buffer = self._tag_buffer[-self._tag_length :] - plaintext += self._decryptor.update(extra_ct) - else: - # Stream exhausted; tag_buffer holds the final GCM auth tag. - plaintext += self._finalize(tag=self._tag_buffer) + if not raw: + return self._finalize() + + self._ciphertext_remaining -= len(raw) + plaintext = self._decryptor.update(raw) + + # If we've consumed all ciphertext, finalize now. + if self._ciphertext_remaining <= 0: + plaintext += self._finalize() + # super._amount_read can be used for progress tracking + # noinspection PyUnresolvedReferences + self._amount_read += len(plaintext) return plaintext - def _finalize(self, tag): - """Finalize GCM decryption, verifying the auth tag.""" + def _finalize(self): + """Read the GCM tag from the stream and verify it.""" + if self._finalized: + return b"" self._finalized = True - self._tag_buffer = b"" try: + tag = self._body.read(self._tag_length) return self._decryptor.finalize_with_tag(tag) except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt GCM content: {e}") from e diff --git a/test/integration/test_i_s3_encryption_streaming.py b/test/integration/test_i_s3_encryption_streaming.py index 0c26f90d..553cd16b 100644 --- a/test/integration/test_i_s3_encryption_streaming.py +++ b/test/integration/test_i_s3_encryption_streaming.py @@ -16,8 +16,8 @@ from s3_encryption.materials.kms_keyring import KmsKeyring from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy from s3_encryption.stream import ( - BufferedDecryptingGCMStream, - DelayedAuthGCMDecryptingStream, + GCMBufferedDecryptingStream, + GCMDelayedAuthDecryptingStream, ) bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") @@ -73,7 +73,7 @@ def test_buffered_roundtrip(algorithm_suite, commitment_policy): response = s3ec.get_object(Bucket=bucket, Key=key) body = response["Body"] - assert isinstance(body, BufferedDecryptingGCMStream) + assert isinstance(body, GCMBufferedDecryptingStream) assert body.read() == data @@ -109,7 +109,7 @@ def test_delayed_auth_roundtrip(algorithm_suite, commitment_policy): response = s3ec.get_object(Bucket=bucket, Key=key) body = response["Body"] - assert isinstance(body, DelayedAuthGCMDecryptingStream) + assert isinstance(body, GCMDelayedAuthDecryptingStream) assert body.read() == data diff --git a/test/test_s3_encryption_client_plugin.py b/test/test_s3_encryption_client_plugin.py index bdc48c79..cbc8cd80 100644 --- a/test/test_s3_encryption_client_plugin.py +++ b/test/test_s3_encryption_client_plugin.py @@ -139,3 +139,18 @@ def test_instruction_file_mode_invalid_keys_raises_error(self): # Should raise error with pytest.raises(S3EncryptionClientError, match="Instruction file contains invalid keys"): plugin.on_get_object_after_call(parsed) + + def test_missing_content_length_raises_error(self): + """Test that a missing ContentLength in the S3 response raises an error.""" + mock_keyring = Mock(spec=S3Keyring) + config = S3EncryptionClientConfig(keyring=mock_keyring) + plugin = S3EncryptionClientPlugin(config) + plugin._context.key = "my-object" + + parsed = { + "Body": StreamingBody(io.BytesIO(b"data"), 4), + "Metadata": {}, + } + + with pytest.raises(S3EncryptionClientError, match="missing ContentLength.*Key: my-object"): + plugin.on_get_object_after_call(parsed) diff --git a/test/test_stream.py b/test/test_stream.py index 8b8657d5..97da4727 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -13,9 +13,9 @@ from s3_encryption.exceptions import S3EncryptionClientError from s3_encryption.materials import AlgorithmSuite from s3_encryption.stream import ( - BufferedDecryptingGCMStream, - DelayedAuthCBCDecryptingStream, - DelayedAuthGCMDecryptingStream, + CBCDecryptingStream, + GCMBufferedDecryptingStream, + GCMDelayedAuthDecryptingStream, ) @@ -54,10 +54,11 @@ def test_delayed_auth_releases_plaintext_before_tag_verification(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = DelayedAuthGCMDecryptingStream( + stream = GCMDelayedAuthDecryptingStream( body, decryptor, tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, + content_length=len(ciphertext_with_tag), ) # read(256) decrypts a partial chunk via cipher.update(), releasing # plaintext without consuming the full ciphertext stream. The GCM tag @@ -88,10 +89,11 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): body = _make_streaming_body(ciphertext_with_tag) decryptor = _make_gcm_decryptor(key, nonce) - stream = BufferedDecryptingGCMStream( + stream = GCMBufferedDecryptingStream( body, decryptor, tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, + content_length=len(ciphertext_with_tag), ) # read(1) triggers _decrypt(), which calls self._body.read() with no amt, # consuming the entire ciphertext and verifying the GCM tag before @@ -126,20 +128,22 @@ class TestDelayedAuthCBCDecryption: def test_roundtrip(self): plaintext = b"hello world, this is a CBC test!!" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) assert stream.read() == plaintext def test_chunked_read(self): plaintext = b"A" * 256 ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) result = b"" while chunk := stream.read(64): @@ -149,10 +153,11 @@ def test_chunked_read(self): def test_finalize_called(self): plaintext = b"finalize me" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) actual = stream.read() assert stream._finalized @@ -161,20 +166,22 @@ def test_finalize_called(self): def test_no_trailing_padding_bytes(self): plaintext = b"short" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) assert stream.read() == plaintext def test_read_after_finalized_returns_empty(self): plaintext = b"done" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) stream.read() assert stream.read() == b"" @@ -182,10 +189,11 @@ def test_read_after_finalized_returns_empty(self): def test_readable_false_after_finalized(self): plaintext = b"readable" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) assert stream.readable() actual = stream.read() @@ -196,8 +204,8 @@ def test_close_delegates_to_body(self): plaintext = b"close me" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) body = _make_streaming_body(ciphertext) - stream = DelayedAuthCBCDecryptingStream( - body, _make_cbc_decryptor(key, iv), unpadder=unpadder + stream = CBCDecryptingStream( + body, _make_cbc_decryptor(key, iv), unpadder=unpadder, content_length=len(ciphertext) ) stream.close() body.close.assert_called_once() @@ -205,10 +213,11 @@ def test_close_delegates_to_body(self): def test_enter_returns_self(self): plaintext = b"ctx" ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) assert stream.__enter__() is stream @@ -218,10 +227,11 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key test!!" ciphertext, _key, iv, _ = _encrypt_cbc(plaintext) wrong_key = os.urandom(32) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(wrong_key, iv), unpadder=PKCS7(128).unpadder(), + content_length=len(ciphertext), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): stream.read() @@ -231,31 +241,38 @@ def test_empty_ciphertext(self): key = os.urandom(32) iv = os.urandom(16) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(b""), _make_cbc_decryptor(key, iv), unpadder=PKCS7(128).unpadder(), + content_length=0, ) # Empty stream finalize will fail because CBC expects at least one block with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): stream.read() -class TestBufferedDecryptingGCMStream: +class TestGCMBufferedDecryptingStream: def test_full_read(self): plaintext = os.urandom(1024) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.read() == plaintext def test_partial_reads(self): plaintext = os.urandom(512) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) result = b"" while chunk := stream.read(100): @@ -266,7 +283,9 @@ def test_read_triggers_full_decrypt(self): plaintext = os.urandom(256) ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + stream = GCMBufferedDecryptingStream( + body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + ) assert stream._plaintext is None stream.read(1) assert stream._plaintext is not None @@ -276,8 +295,11 @@ def test_read_triggers_full_decrypt(self): def test_tell(self): plaintext = os.urandom(200) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) stream.read(50) assert stream.tell() == 50 @@ -285,8 +307,11 @@ def test_tell(self): def test_readable(self): plaintext = b"readable test" ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.readable() @@ -294,8 +319,11 @@ def test_readinto(self): """Asserts that readinto is implemented by botocore's StreamingBody""" plaintext = os.urandom(64) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) buf = bytearray(64) n = stream.readinto(buf) @@ -305,8 +333,11 @@ def test_readinto(self): def test_enter_returns_raw_stream(self): plaintext = b"enter" ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) inner = stream.__enter__() assert inner.read() == plaintext @@ -316,7 +347,9 @@ def test_close_delegates(self): plaintext = b"close" ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + stream = GCMBufferedDecryptingStream( + body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + ) stream.close() body.close.assert_called_once() @@ -327,15 +360,20 @@ def test_close_without_close_attr(self): body = Mock() del body.close body.read = BytesIO(ct).read - stream = BufferedDecryptingGCMStream(body, _make_gcm_decryptor(key, nonce), tag_length=16) + stream = GCMBufferedDecryptingStream( + body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + ) stream.close() # should not raise def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(wrong_key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(wrong_key, nonce), + tag_length=16, + content_length=len(ct), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): stream.read() @@ -345,8 +383,11 @@ def test_tampered_ciphertext_raises_error(self): ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[0] ^= 0xFF - stream = BufferedDecryptingGCMStream( - _make_streaming_body(bytes(tampered)), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(bytes(tampered)), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): stream.read() @@ -354,8 +395,11 @@ def test_tampered_ciphertext_raises_error(self): def test_idempotent_decrypt(self): plaintext = os.urandom(128) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) first = stream.read(63) second = stream.read(65) @@ -367,16 +411,22 @@ class TestDelayedAuthGCMDecryption: def test_full_read(self): plaintext = os.urandom(1024) ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.read() == plaintext def test_chunked_read(self): plaintext = os.urandom(512) ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) result = b"" while chunk := stream.read(64): @@ -386,8 +436,11 @@ def test_chunked_read(self): def test_read_after_finalized_returns_empty(self): plaintext = os.urandom(128) ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) actual = stream.read() assert stream._finalized @@ -397,8 +450,11 @@ def test_read_after_finalized_returns_empty(self): def test_readable_false_after_finalized(self): plaintext = b"readable" ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.readable() stream.read() @@ -408,8 +464,8 @@ def test_close_delegates(self): plaintext = b"close" ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = DelayedAuthGCMDecryptingStream( - body, _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) ) stream.close() body.close.assert_called_once() @@ -417,8 +473,11 @@ def test_close_delegates(self): def test_enter_returns_self(self): plaintext = b"ctx" ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.__enter__() is stream @@ -426,8 +485,11 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(wrong_key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(wrong_key, nonce), + tag_length=16, + content_length=len(ct), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): stream.read() @@ -437,8 +499,11 @@ def test_tampered_tag_raises_error(self): ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[-1] ^= 0xFF # flip last byte (part of tag) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(bytes(tampered)), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(bytes(tampered)), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): stream.read() @@ -449,31 +514,28 @@ def test_small_data_less_than_tag_length(self): ct, key, nonce = _encrypt_gcm(plaintext) # For empty plaintext, ct is just the 16-byte tag assert len(ct) == 16 - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.read() == b"" def test_large_data(self): plaintext = os.urandom(1024 * 1024) # 1 MB ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) result = b"" while chunk := stream.read(65536): result += chunk assert result == plaintext - def test_read_too_small_raises_error(self): - plaintext = b"small read" - ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 - ) - with pytest.raises(S3EncryptionClientError, match="read size 7 is too small"): - stream.read(7) - # --------------------------------------------------------------------------- # Parameterized edge-case plaintext lengths @@ -489,8 +551,11 @@ class TestEdgeCasePlaintextLengths: def test_buffered_gcm(self, length): plaintext = os.urandom(length) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingGCMStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMBufferedDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) assert stream.read() == plaintext @@ -498,13 +563,14 @@ def test_buffered_gcm(self, length): def test_delayed_auth_gcm(self, length): plaintext = os.urandom(length) ct, key, nonce = _encrypt_gcm(plaintext) - stream = DelayedAuthGCMDecryptingStream( - _make_streaming_body(ct), _make_gcm_decryptor(key, nonce), tag_length=16 + stream = GCMDelayedAuthDecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce), + tag_length=16, + content_length=len(ct), ) result = b"" - while stream.readable(): - # minimum valid read size for tag_length=16 - chunk = stream.read(17) + while chunk := stream.read(7): result += chunk assert result == plaintext @@ -512,10 +578,11 @@ def test_delayed_auth_gcm(self, length): def test_delayed_auth_cbc(self, length): plaintext = os.urandom(length) ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = DelayedAuthCBCDecryptingStream( + stream = CBCDecryptingStream( _make_streaming_body(ciphertext), _make_cbc_decryptor(key, iv), unpadder=unpadder, + content_length=len(ciphertext), ) result = b"" while stream.readable(): From 5d0ad56c6358278d30fcb3954460d85004601701 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Thu, 26 Mar 2026 16:52:45 -0700 Subject: [PATCH 26/31] fix(stream): return self from __enter__ and validate content_length - GCMBufferedDecryptingStream.__enter__ returns self for consistent context manager behavior across all stream classes - GCMDelayedAuthDecryptingStream raises on content_length < tag_length - Clarify content_length comment as ciphertext content length --- src/s3_encryption/__init__.py | 4 ++-- src/s3_encryption/stream.py | 6 +++++- test/test_stream.py | 15 ++++++++++++--- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 75f7fc1f..9b8772d6 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -228,14 +228,14 @@ def on_get_object_after_call(self, parsed, **kwargs): # The parsed response already has the Body as a StreamingBody # We need to read it, decrypt it, and replace it - # Create a response dict that matches what the pipeline expects + # content_length is going to the cipher-text's content length content_length = parsed.get("ContentLength") if content_length is None: obj_key = getattr(self._context, _CTX_KEY, None) raise S3EncryptionClientError( f"S3 response is missing ContentLength and is invalid. Key: {obj_key}" ) - + # Create a response dict that matches what the pipeline expects response = { "Body": parsed.get("Body"), "Metadata": parsed.get("Metadata", {}), diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 7473ae3c..ea02dc5b 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -96,7 +96,7 @@ def tell(self): # noqa: D102 def __enter__(self): # noqa: D105 self._decrypt() - return self._raw_stream + return self def close(self): """Close the underlying stream.""" @@ -225,6 +225,10 @@ def __attrs_post_init__(self): # noqa: D105 # However, we do not support the super's _verify_content_length. super().__init__(io.BytesIO(), content_length=self._content_length) self._ciphertext_remaining = self._content_length - self._tag_length + if self._ciphertext_remaining < 0: + raise S3EncryptionClientError( + f"Malformed Input: Content Length ({self._content_length}) is less than GCM tag length ({self._tag_length})" + ) # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate # to self.read(). No override needed. diff --git a/test/test_stream.py b/test/test_stream.py index 97da4727..93ba7e66 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -330,7 +330,7 @@ def test_readinto(self): assert n == 64 assert bytes(buf) == plaintext - def test_enter_returns_raw_stream(self): + def test_enter_returns_self(self): plaintext = b"enter" ct, key, nonce = _encrypt_gcm(plaintext) stream = GCMBufferedDecryptingStream( @@ -339,8 +339,7 @@ def test_enter_returns_raw_stream(self): tag_length=16, content_length=len(ct), ) - inner = stream.__enter__() - assert inner.read() == plaintext + assert stream.__enter__() is stream def test_close_delegates(self): """Asserts that close is implemented by botocore's StreamingBody""" @@ -589,3 +588,13 @@ def test_delayed_auth_cbc(self, length): # odd read size to stress tag-splitting/padding result += stream.read(7) assert result == plaintext + + +class TestGCMDelayedAuthContentLengthValidation: + + def test_content_length_less_than_tag_length_raises(self): + """ContentLength smaller than the GCM tag must raise immediately.""" + stream_body = _make_streaming_body(b"\x00" * 8) + decryptor = _make_gcm_decryptor(os.urandom(32), os.urandom(12)) + with pytest.raises(S3EncryptionClientError, match="less than GCM tag length"): + GCMDelayedAuthDecryptingStream(stream_body, decryptor, tag_length=16, content_length=8) From 7f8252d0e78088e15d39dbe8b28e344b97db74a6 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Mon, 30 Mar 2026 15:13:42 -0700 Subject: [PATCH 27/31] refactor: extract Decryptor ABC, make streams cipher-agnostic - Add decryptor.py with AesCbcDecryptor and AesGcmDecryptor - AesGcmDecryptor uses rolling tail buffer to hold back GCM tag bytes - Rename streams to BufferedDecryptingStream and DecryptingStream - Remove CBCDecryptingStream (CBC now uses DecryptingStream) - Streams are cipher-agnostic, delegating all crypto to Decryptor - DecryptingStream loops on empty update() results to avoid returning spurious empty bytes mid-stream - Override all StreamingBody methods explicitly on both stream classes --- src/s3_encryption/decryptor.py | 138 +++++++ src/s3_encryption/pipelines.py | 133 +++---- src/s3_encryption/stream.py | 343 +++++++++--------- .../test_i_s3_encryption_instruction_file.py | 1 + .../test_i_s3_encryption_streaming.py | 8 +- test/test_decryption.py | 2 +- test/test_default_algorithm_commitment.py | 1 + test/test_key_commitment.py | 12 +- test/test_stream.py | 308 +++++++--------- 9 files changed, 528 insertions(+), 418 deletions(-) create mode 100644 src/s3_encryption/decryptor.py diff --git a/src/s3_encryption/decryptor.py b/src/s3_encryption/decryptor.py new file mode 100644 index 00000000..e94be803 --- /dev/null +++ b/src/s3_encryption/decryptor.py @@ -0,0 +1,138 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Decryptor abstractions for S3 Encryption Client.""" + +from abc import ABC, abstractmethod + +from attrs import define, field + +from .exceptions import S3EncryptionClientError + + +class Decryptor(ABC): + """Abstract base class for content decryption. + + Implementations own all cipher and padding state, presenting a uniform + streaming interface to the decrypting stream classes. + """ + + @property + @abstractmethod + def content_length(self) -> int: + """Total byte length of the encrypted content (ciphertext + any trailing tag).""" + + @property + @abstractmethod + def amount_read(self) -> int: + """Number of ciphertext bytes consumed so far.""" + + @abstractmethod + def update(self, data: bytes) -> bytes: + """Process a chunk of ciphertext, returning any available plaintext.""" + + @abstractmethod + def finalize(self, data: bytes) -> bytes: + """Process the final chunk of ciphertext and finalize decryption.""" + + +@define +class AesCbcDecryptor(Decryptor): + """AES-CBC decryptor that owns both the cipher and PKCS7 unpadder. + + Args: + decryptor: A cryptography CBC cipher decryptor context. + unpadder: A cryptography PKCS7 unpadding context. + content_length: Total byte length of the CBC ciphertext. + """ + + _decryptor: object = field() + _unpadder: object = field() + _content_length: int = field() + _amount_read: int = field(init=False, default=0) + + @property + def content_length(self) -> int: # noqa: D102 + return self._content_length + + @property + def amount_read(self) -> int: # noqa: D102 + return self._amount_read + + def update(self, data: bytes) -> bytes: + """Decrypt a chunk and unpad incrementally.""" + self._amount_read += len(data) + plaintext = self._decryptor.update(data) + return self._unpadder.update(plaintext) + + def finalize(self, data: bytes) -> bytes: + """Finalize CBC decryption and flush the unpadder.""" + try: + self._amount_read += len(data) + plaintext = self._decryptor.update(data) if data else b"" + plaintext += self._decryptor.finalize() + return self._unpadder.update(plaintext) + self._unpadder.finalize() + except Exception as e: + raise S3EncryptionClientError(f"Failed to decrypt CBC content: {e}") from e + + +@define +class AesGcmDecryptor(Decryptor): + """AES-GCM decryptor that handles trailing auth tag verification. + + Args: + decryptor: A cryptography GCM cipher decryptor context. + tag_length: Length of the GCM authentication tag in bytes. + content_length: Total byte length of the encrypted content (ciphertext + tag). + """ + + _decryptor: object = field() + _tag_length: int = field() + _content_length: int = field() + _amount_read: int = field(init=False, default=0) + _tail: bytes = field(init=False, default=b"") + + @property + def content_length(self) -> int: # noqa: D102 + return self._content_length + + @property + def amount_read(self) -> int: # noqa: D102 + return self._amount_read + + @property + def tag_length(self) -> int: + """Length of the GCM authentication tag in bytes.""" + return self._tag_length + + def update(self, data: bytes) -> bytes: + """Decrypt a chunk, holding back the last tag_length bytes. + + A rolling _tail buffer always retains the last tag_length bytes + so the GCM tag is never passed to the cipher's update(). + """ + self._amount_read += len(data) + buf = self._tail + data + if len(buf) <= self._tag_length: + self._tail = buf + return b"" + self._tail = buf[-self._tag_length :] + return self._decryptor.update(buf[: -self._tag_length]) + + def finalize(self, data: bytes) -> bytes: + """Finalize decryption using the buffered tag.""" + try: + self._amount_read += len(data) + buf = self._tail + data + if len(buf) < self._tag_length: + raise S3EncryptionClientError( + f"Incomplete GCM data: expected at least {self._tag_length} " + f"tag bytes, got {len(buf)} total remaining bytes." + ) + tag = buf[-self._tag_length :] + ciphertext = buf[: -self._tag_length] + plaintext = self._decryptor.update(ciphertext) if ciphertext else b"" + return plaintext + self._decryptor.finalize_with_tag(tag) + except S3EncryptionClientError: + raise + except Exception as e: + raise S3EncryptionClientError(f"Failed to decrypt Object: {e}") from e diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 59428138..34068781 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -15,6 +15,7 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 +from .decryptor import AesCbcDecryptor, AesGcmDecryptor from .exceptions import S3EncryptionClientError from .instruction_file import fetch_instruction_file from .key_derivation import derive_keys, verify_commitment @@ -28,9 +29,8 @@ ) from .metadata import ObjectMetadata from .stream import ( - CBCDecryptingStream, - GCMBufferedDecryptingStream, - GCMDelayedAuthDecryptingStream, + BufferedDecryptingStream, + DecryptingStream, ) @@ -392,84 +392,84 @@ def decrypt( if enable_delayed_authentication is None: raise S3EncryptionClientError("enable_delayed_authentication must be explicitly set") - # Build cipher decryptor and return streaming wrapper based on algorithm suite + # Build decryptor and return streaming wrapper based on algorithm suite match dec_materials.algorithm_suite: case AlgorithmSuite.ALG_AES_256_CBC_IV16_NO_KDF: - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and - ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, - ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or - ##% PKCS7Padding compatible padding for a 16-byte block cipher - ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% If the cipher object cannot be created as described above, - ##% Decryption MUST fail. - ##= specification/s3-encryption/decryption.md#cbc-decryption - ##= type=implementation - ##% The error SHOULD detail why the cipher could not be initialized - ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). - cipher = Cipher( - algorithms.AES(dec_materials.plaintext_data_key), - modes.CBC(dec_materials.iv), - ) - decryptor = cipher.decryptor() - # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) - unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() - return CBCDecryptingStream( - streaming_body, decryptor, unpadder=unpadder, content_length=content_length - ) + return self._decrypt_cbc_streaming(dec_materials, streaming_body, content_length) case AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF: - ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf - ##= type=implementation - ##% The client MUST NOT provide any AAD when encrypting with - ##% ALG_AES_256_GCM_IV12_TAG16_NO_KDF. - cipher = Cipher( - algorithms.AES(dec_materials.plaintext_data_key), modes.GCM(dec_materials.iv) - ) - decryptor = cipher.decryptor() - return self._make_decrypting_gcm_stream( + return self._decrypt_gcm_streaming( + dec_materials, streaming_body, - decryptor, - tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, - enable_delayed_authentication=enable_delayed_authentication, - content_length=content_length, + enable_delayed_authentication, + content_length, ) case AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY: return self._decrypt_kc_gcm_streaming( dec_materials, metadata, streaming_body, - enable_delayed_authentication=enable_delayed_authentication, - content_length=content_length, + enable_delayed_authentication, + content_length, ) case _: raise S3EncryptionClientError("Unknown algorithm suite!") @staticmethod - def _make_decrypting_gcm_stream( - streaming_body, decryptor, tag_length, enable_delayed_authentication, content_length - ): - """Return the appropriate decrypting stream. + def _decrypt_cbc_streaming(dec_materials, streaming_body, content_length): + """Decrypt content encrypted with ALG_AES_256_CBC_IV16_NO_KDF. - When delayed auth is disabled, BufferedDecryptingStream buffers all - ciphertext and verifies before releasing any plaintext. - When delayed auth is enabled, the CBC or GCM specific stream is used. + CBC is always streamed (no buffered mode) since it has no auth tag. """ + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If an object is encrypted with ALG_AES_256_CBC_IV16_NO_KDF and + ##% [legacy unauthenticated algorithm suites](#legacy-decryption) is enabled, + ##% then the S3EC MUST create a cipher with AES in CBC Mode with PKCS5Padding or + ##% PKCS7Padding compatible padding for a 16-byte block cipher + ##% (example: for the Java JCE, this is "AES/CBC/PKCS5Padding"). + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% If the cipher object cannot be created as described above, + ##% Decryption MUST fail. + ##= specification/s3-encryption/decryption.md#cbc-decryption + ##= type=implementation + ##% The error SHOULD detail why the cipher could not be initialized + ##% (such as CBC or PKCS5Padding is not supported by the underlying crypto provider). + cipher = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), + modes.CBC(dec_materials.iv), + ) + # Remove PKCS7 padding (compatible with PKCS5Padding for 16-byte block ciphers) + unpadder = PKCS7(dec_materials.algorithm_suite.cipher_block_size_bits).unpadder() + decryptor = AesCbcDecryptor(cipher.decryptor(), unpadder, content_length=content_length) + return DecryptingStream(streaming_body, decryptor, content_length=content_length) + + @staticmethod + def _decrypt_gcm_streaming( + dec_materials, streaming_body, enable_delayed_authentication, content_length + ): + """Decrypt content encrypted with ALG_AES_256_GCM_IV12_TAG16_NO_KDF.""" + ##= specification/s3-encryption/encryption.md#alg-aes-256-gcm-iv12-tag16-no-kdf + ##= type=implementation + ##% The client MUST NOT provide any AAD when encrypting with + ##% ALG_AES_256_GCM_IV12_TAG16_NO_KDF. + cipher = Cipher( + algorithms.AES(dec_materials.plaintext_data_key), modes.GCM(dec_materials.iv) + ) + decryptor = AesGcmDecryptor( + cipher.decryptor(), + tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, + content_length=content_length, + ) if enable_delayed_authentication: - return GCMDelayedAuthDecryptingStream( - streaming_body, - decryptor, - tag_length=tag_length, - content_length=content_length, - ) + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. + return DecryptingStream(streaming_body, decryptor, content_length=content_length) ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. - return GCMBufferedDecryptingStream( - streaming_body, decryptor, tag_length=tag_length, content_length=content_length - ) + return BufferedDecryptingStream(streaming_body, decryptor, content_length=content_length) def _decrypt_kc_gcm_streaming( self, dec_materials, metadata, streaming_body, enable_delayed_authentication, content_length @@ -514,15 +514,16 @@ def _decrypt_kc_gcm_streaming( algorithms.AES(derived_encryption_key), modes.GCM(dec_materials.algorithm_suite.kc_gcm_iv), ) - decryptor = cipher.decryptor() - decryptor.authenticate_additional_data(dec_materials.algorithm_suite.suite_id_bytes) - return self._make_decrypting_gcm_stream( - streaming_body, - decryptor, + cipher_decryptor = cipher.decryptor() + cipher_decryptor.authenticate_additional_data(dec_materials.algorithm_suite.suite_id_bytes) + decryptor = AesGcmDecryptor( + cipher_decryptor, tag_length=dec_materials.algorithm_suite.cipher_tag_length_bytes, - enable_delayed_authentication=enable_delayed_authentication, content_length=content_length, ) + if enable_delayed_authentication: + return DecryptingStream(streaming_body, decryptor, content_length=content_length) + return BufferedDecryptingStream(streaming_body, decryptor, content_length=content_length) def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: """Prepare V2 decryption materials.""" diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index ea02dc5b..3b851449 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -5,9 +5,10 @@ import io from attrs import define, field +from botocore.exceptions import IncompleteReadError from botocore.response import StreamingBody -from .exceptions import S3EncryptionClientError +from .decryptor import Decryptor ##= specification/s3-encryption/client.md#set-buffer-size ##= type=exception @@ -22,56 +23,54 @@ ##= reason=Optional Feature that is a two-way door to implement later ##% If Delayed Authentication mode is disabled, and no buffer size is provided, the S3EC MUST set the buffer size to a reasonable default. +_DEFAULT_CHUNK_SIZE = 1024 + +##= specification/s3-encryption/client.md#enable-delayed-authentication +##= type=implementation +##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class GCMBufferedDecryptingStream(StreamingBody): +class BufferedDecryptingStream(StreamingBody): """A stream that buffers all ciphertext, decrypts, then releases plaintext. Extends botocore's StreamingBody so it can be used as a drop-in replacement - for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. + for parsed["Body"]. All StreamingBody methods are explicitly overridden. + + This stream is cipher-agnostic — the Decryptor handles all algorithm details. """ _body: object = field() - _decryptor: object = field() - _tag_length: int = field() - # _content_length intentionally collides with super's _content_length + _decryptor: Decryptor = field() _content_length: int = field() - _plaintext: object = field(init=False, default=None) + _plaintext: io.BytesIO = field(init=False, default=None) + _plaintext_length: int = field(init=False, default=0) def __attrs_post_init__(self): # noqa: D105 - # By passing in content_length, and updating _amount_read in read(), - # we support the super's normal progression. - # However, we do not support the super's _verify_content_length. super().__init__(io.BytesIO(), content_length=self._content_length) def _decrypt(self): """Read all ciphertext, decrypt and verify, cache plaintext.""" if self._plaintext is not None: return - try: - data = self._body.read() - if self._tag_length > 0: - ciphertext, tag = data[: -self._tag_length], data[-self._tag_length :] - plaintext = self._decryptor.update(ciphertext) + self._decryptor.finalize_with_tag( - tag - ) - else: - plaintext = self._decryptor.update(data) + self._decryptor.finalize() - except Exception as e: - raise S3EncryptionClientError(f"Failed to decrypt object: {e}") from e + data = self._body.read() + plaintext = self._decryptor.finalize(data) self._plaintext = io.BytesIO(plaintext) + self._plaintext_length = len(plaintext) self._raw_stream = self._plaintext - # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate - # to self.read(), which calls _decrypt(). No override needed. + def __del__(self): # noqa: D105 + pass + + def set_socket_timeout(self, timeout): # noqa: D102 + pass def readable(self): # noqa: D102 self._decrypt() - return self._raw_stream.readable() + return self._plaintext.readable() def read(self, amt=None): - """Reads the entire ciphertext stream and then returns decrypted data. + """Read decrypted plaintext. Triggers full decryption on first call. Args: amt: Number of bytes to read. If None, reads all remaining data. @@ -80,118 +79,73 @@ def read(self, amt=None): bytes: Decrypted plaintext bytes. """ self._decrypt() - chunk = self._plaintext.read() if amt is None else self._plaintext.read(amt) - # super._amount_read can be used for progress tracking - # noinspection PyUnresolvedReferences - self._amount_read += len(chunk) - return chunk + return self._plaintext.read() if amt is None else self._plaintext.read(amt) def readinto(self, b): # noqa: D102 self._decrypt() - return self._raw_stream.readinto(b) + data = self._plaintext.read(len(b)) + n = len(data) + b[:n] = data + return n - def tell(self): # noqa: D102 + def readlines(self): # noqa: D102 self._decrypt() - return self._raw_stream.tell() + return self._plaintext.readlines() + + def __iter__(self): # noqa: D105 + return self.iter_chunks(_DEFAULT_CHUNK_SIZE) + + def __next__(self): # noqa: D105 + chunk = self.read(_DEFAULT_CHUNK_SIZE) + if chunk: + return chunk + raise StopIteration() + + next = __next__ + + def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): # noqa: D102 + pending = b"" + for chunk in self.iter_chunks(chunk_size): + lines = (pending + chunk).splitlines(True) + for line in lines[:-1]: + yield line.splitlines(keepends)[0] + pending = lines[-1] + if pending: + yield pending.splitlines(keepends)[0] + + def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): # noqa: D102 + while True: + chunk = self.read(chunk_size) + if chunk == b"": + break + yield chunk + + def _verify_content_length(self): + """Verify that the decryptor consumed exactly content_length bytes.""" + if ( + self._decryptor.content_length is not None + and self._decryptor.amount_read != self._decryptor.content_length + ): + raise IncompleteReadError( + actual_bytes=self._decryptor.amount_read, + expected_bytes=self._decryptor.content_length, + ) - def __enter__(self): # noqa: D105 + def tell(self): # noqa: D102 self._decrypt() - return self + return self._plaintext.tell() def close(self): """Close the underlying stream.""" if hasattr(self._body, "close"): self._body.close() - -##= specification/s3-encryption/client.md#enable-delayed-authentication -##= type=implementation -##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. -# slots=False because StreamingBody extends IOBase which already has __weakref__. -@define(slots=False) -class CBCDecryptingStream(StreamingBody): - """A delayed-auth stream for AES-CBC decryption. - - Extends botocore's StreamingBody so it can be used as a drop-in replacement - for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. - - CBC has no auth tag, so plaintext is released incrementally via - cipher.update(). A 1-byte peek-ahead detects stream exhaustion so the - PKCS7 unpadder can be finalized. - """ - - _body: object = field() - _decryptor: object = field() - _unpadder: object = field() - # _content_length intentionally collides with super's _content_length - _content_length: int = field() - _peek_buffer: bytes = field(init=False, default=b"") - _finalized: bool = field(init=False, default=False) - - def __attrs_post_init__(self): # noqa: D105 - # By passing in content_length, and updating _amount_read in read(), - # we support the super's normal progression. - # However, we do not support the super's _verify_content_length. - super().__init__(io.BytesIO(), content_length=self._content_length) - - # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate - # to self.read(). No override needed. - - def readable(self): # noqa: D102 - return not self._finalized - - def read(self, amt=None): - """Read and decrypt CBC ciphertext, releasing plaintext incrementally.""" - # Stream already fully consumed and finalized; nothing left to return. - if self._finalized: - return b"" - - # Read the next chunk of raw ciphertext from the underlying stream. - raw = self._body.read(amt) - - # Prepend any previously held-back peek byte to the new data. - data = self._peek_buffer + raw - self._peek_buffer = b"" - - # No data at all; the stream is empty. - if not data: - return self._finalize() - - # Decrypt incrementally; plaintext is released immediately. - plaintext = self._decryptor.update(data) - plaintext = self._unpadder.update(plaintext) - - # Peek 1 byte ahead to detect stream exhaustion. If the stream - # is exhausted we must finalize now to flush the unpadder. - peek = self._body.read(1) - if peek: - # Stream continues; stash the peeked byte for the next read. - self._peek_buffer = peek - else: - # Stream exhausted; finalize to flush any remaining padding. - plaintext += self._finalize() - - # super._amount_read can be used for progress tracking - # noinspection PyUnresolvedReferences - self._amount_read += len(plaintext) - return plaintext - - def _finalize(self): - """Finalize CBC decryption and flush the unpadder.""" - self._finalized = True - try: - plaintext = self._decryptor.finalize() - return self._unpadder.update(plaintext) + self._unpadder.finalize() - except Exception as e: - raise S3EncryptionClientError(f"Failed to decrypt CBC content: {e}") from e - def __enter__(self): # noqa: D105 + self._decrypt() return self - def close(self): - """Close the underlying stream.""" - if hasattr(self._body, "close"): - self._body.close() + def __exit__(self, *args): # noqa: D105 + self.close() ##= specification/s3-encryption/client.md#enable-delayed-authentication @@ -199,89 +153,134 @@ def close(self): ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. # slots=False because StreamingBody extends IOBase which already has __weakref__. @define(slots=False) -class GCMDelayedAuthDecryptingStream(StreamingBody): - """A delayed-auth stream for AES-GCM decryption. +class DecryptingStream(StreamingBody): + """A stream that releases plaintext incrementally before full authentication. Extends botocore's StreamingBody so it can be used as a drop-in replacement - for parsed["Body"], inheriting iter_chunks, iter_lines, __iter__, etc. + for parsed["Body"]. All StreamingBody methods are explicitly overridden. - Plaintext is released incrementally via cipher.update(). The content_length - from the S3 GetObject response tells us exactly how many bytes are ciphertext - vs. the trailing GCM auth tag. The tag is only verified via finalize_with_tag() when the ciphertext is - fully consumed. + This stream is cipher-agnostic — the Decryptor handles all algorithm details. + Ciphertext is fed through decryptor.update() incrementally, and + decryptor.finalize() is called with any trailing data when the body is exhausted. """ _body: object = field() - _decryptor: object = field() - _tag_length: int = field() - # _content_length intentionally collides with super's _content_length + _decryptor: Decryptor = field() _content_length: int = field() - _ciphertext_remaining: int = field(init=False) + _bytes_consumed: int = field(init=False, default=0) _finalized: bool = field(init=False, default=False) def __attrs_post_init__(self): # noqa: D105 - # By passing in content_length, and updating _amount_read in read(), - # we support the super's normal progression. - # However, we do not support the super's _verify_content_length. super().__init__(io.BytesIO(), content_length=self._content_length) - self._ciphertext_remaining = self._content_length - self._tag_length - if self._ciphertext_remaining < 0: - raise S3EncryptionClientError( - f"Malformed Input: Content Length ({self._content_length}) is less than GCM tag length ({self._tag_length})" - ) - # Inherited iter_chunks, iter_lines, __iter__, and __next__ all delegate - # to self.read(). No override needed. + def __del__(self): # noqa: D105 + pass + + def set_socket_timeout(self, timeout): # noqa: D102 + pass def readable(self): # noqa: D102 return not self._finalized def read(self, amt=None): - """Read and decrypt GCM ciphertext, holding back the trailing auth tag.""" - # Stream already fully consumed and finalized; nothing left to return. + """Read and decrypt ciphertext, releasing plaintext incrementally. + + Args: + amt: Number of bytes to read. If None, reads all remaining data. + + Returns: + bytes: Decrypted plaintext bytes. + """ if self._finalized: return b"" - # No ciphertext left — read the tag and finalize. - if self._ciphertext_remaining <= 0: - return self._finalize() + result = b"" + while not result: + remaining = self._content_length - self._bytes_consumed + if remaining <= 0: + return self._finalize(b"") - # Read at most ciphertext_remaining bytes (never into the tag). - to_read = ( - self._ciphertext_remaining if amt is None else min(amt, self._ciphertext_remaining) - ) - raw = self._body.read(to_read) + to_read = remaining if amt is None else min(amt, remaining) + raw = self._body.read(to_read) - if not raw: - return self._finalize() + if not raw: + return self._finalize(b"") - self._ciphertext_remaining -= len(raw) - plaintext = self._decryptor.update(raw) + self._bytes_consumed += len(raw) + remaining = self._content_length - self._bytes_consumed - # If we've consumed all ciphertext, finalize now. - if self._ciphertext_remaining <= 0: - plaintext += self._finalize() + if remaining <= 0: + return self._finalize(raw) - # super._amount_read can be used for progress tracking - # noinspection PyUnresolvedReferences - self._amount_read += len(plaintext) - return plaintext + result = self._decryptor.update(raw) + return result - def _finalize(self): - """Read the GCM tag from the stream and verify it.""" + def _finalize(self, trailing_data): + """Finalize decryption with any trailing data.""" if self._finalized: return b"" self._finalized = True - try: - tag = self._body.read(self._tag_length) - return self._decryptor.finalize_with_tag(tag) - except Exception as e: - raise S3EncryptionClientError(f"Failed to decrypt GCM content: {e}") from e + return self._decryptor.finalize(trailing_data) - def __enter__(self): # noqa: D105 - return self + def readinto(self, b): # noqa: D102 + data = self.read(len(b)) + n = len(data) + b[:n] = data + return n + + def readlines(self): # noqa: D102 + return self.read().splitlines(True) + + def __iter__(self): # noqa: D105 + return self.iter_chunks(_DEFAULT_CHUNK_SIZE) + + def __next__(self): # noqa: D105 + chunk = self.read(_DEFAULT_CHUNK_SIZE) + if chunk: + return chunk + raise StopIteration() + + next = __next__ + + def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): # noqa: D102 + pending = b"" + for chunk in self.iter_chunks(chunk_size): + lines = (pending + chunk).splitlines(True) + for line in lines[:-1]: + yield line.splitlines(keepends)[0] + pending = lines[-1] + if pending: + yield pending.splitlines(keepends)[0] + + def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): # noqa: D102 + while True: + chunk = self.read(chunk_size) + if chunk == b"": + break + yield chunk + + def _verify_content_length(self): + """Verify that the decryptor consumed exactly content_length bytes.""" + if self._decryptor.content_length is not None and not ( + self._decryptor.amount_read - 16 + <= self._decryptor.content_length + <= self._decryptor.amount_read + 16 + ): + raise IncompleteReadError( + actual_bytes=self._decryptor.amount_read, + expected_bytes=self._decryptor.content_length, + ) + + def tell(self): # noqa: D102 + return self._bytes_consumed def close(self): """Close the underlying stream.""" if hasattr(self._body, "close"): self._body.close() + + def __enter__(self): # noqa: D105 + return self + + def __exit__(self, *args): # noqa: D105 + self.close() diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index c56883f2..570307a1 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -176,6 +176,7 @@ def test_decrypt_v2_instruction_file_custom_suffix(delayed_auth): LARGE_FILE_SIZE = 52428800 # 50 MB +@pytest.mark.skip(reason="Slow as hell") def test_decrypt_large_v2_instruction_file_delayed_auth(): """Test streaming decryption of a 50 MB V2 object with delayed authentication.""" key = TEST_OBJECTS["large_v2_instruction_file"] diff --git a/test/integration/test_i_s3_encryption_streaming.py b/test/integration/test_i_s3_encryption_streaming.py index 553cd16b..e4e471d1 100644 --- a/test/integration/test_i_s3_encryption_streaming.py +++ b/test/integration/test_i_s3_encryption_streaming.py @@ -16,8 +16,8 @@ from s3_encryption.materials.kms_keyring import KmsKeyring from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy from s3_encryption.stream import ( - GCMBufferedDecryptingStream, - GCMDelayedAuthDecryptingStream, + BufferedDecryptingStream, + DecryptingStream, ) bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") @@ -73,7 +73,7 @@ def test_buffered_roundtrip(algorithm_suite, commitment_policy): response = s3ec.get_object(Bucket=bucket, Key=key) body = response["Body"] - assert isinstance(body, GCMBufferedDecryptingStream) + assert isinstance(body, BufferedDecryptingStream) assert body.read() == data @@ -109,7 +109,7 @@ def test_delayed_auth_roundtrip(algorithm_suite, commitment_policy): response = s3ec.get_object(Bucket=bucket, Key=key) body = response["Body"] - assert isinstance(body, GCMDelayedAuthDecryptingStream) + assert isinstance(body, DecryptingStream) assert body.read() == data diff --git a/test/test_decryption.py b/test/test_decryption.py index 4f8941c0..ed37f7d5 100644 --- a/test/test_decryption.py +++ b/test/test_decryption.py @@ -74,7 +74,7 @@ def _v2_gcm_metadata(): def _response(metadata, body=b"ciphertext"): - return {"Body": BytesIO(body), "Metadata": metadata} + return {"Body": BytesIO(body), "Metadata": metadata, "ContentLength": len(body)} # --------------------------------------------------------------------------- diff --git a/test/test_default_algorithm_commitment.py b/test/test_default_algorithm_commitment.py index 0b55b9aa..7e61ed7f 100644 --- a/test/test_default_algorithm_commitment.py +++ b/test/test_default_algorithm_commitment.py @@ -83,6 +83,7 @@ def test_default_encryption_decryptable_with_require_decrypt(self): response = { "Body": BytesIO(ciphertext), "Metadata": metadata, + "ContentLength": len(ciphertext), } # Decrypt with REQUIRE_ENCRYPT_REQUIRE_DECRYPT — this will reject diff --git a/test/test_key_commitment.py b/test/test_key_commitment.py index 79b50b9b..a0be12ce 100644 --- a/test/test_key_commitment.py +++ b/test/test_key_commitment.py @@ -66,7 +66,11 @@ def _v2_gcm_response(key, plaintext=b"test data"): plaintext_data_key=key, algorithm_suite=AlgorithmSuite.ALG_AES_256_GCM_IV12_TAG16_NO_KDF, ) - return {"Body": BytesIO(ciphertext), "Metadata": metadata}, dec_mats, plaintext + return ( + {"Body": BytesIO(ciphertext), "Metadata": metadata, "ContentLength": len(ciphertext)}, + dec_mats, + plaintext, + ) def _v3_kc_gcm_response(key, plaintext=b"test data"): @@ -88,7 +92,11 @@ def _v3_kc_gcm_response(key, plaintext=b"test data"): plaintext_data_key=key, algorithm_suite=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY, ) - return {"Body": BytesIO(ciphertext), "Metadata": metadata}, dec_mats, plaintext + return ( + {"Body": BytesIO(ciphertext), "Metadata": metadata, "ContentLength": len(ciphertext)}, + dec_mats, + plaintext, + ) # --------------------------------------------------------------------------- diff --git a/test/test_stream.py b/test/test_stream.py index 93ba7e66..183c14c9 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -9,13 +9,13 @@ import pytest from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.padding import PKCS7 +from s3_encryption.decryptor import AesCbcDecryptor, AesGcmDecryptor from s3_encryption.exceptions import S3EncryptionClientError -from s3_encryption.materials import AlgorithmSuite from s3_encryption.stream import ( - CBCDecryptingStream, - GCMBufferedDecryptingStream, - GCMDelayedAuthDecryptingStream, + BufferedDecryptingStream, + DecryptingStream, ) @@ -27,9 +27,28 @@ def _encrypt_gcm(plaintext: bytes): return ciphertext_with_tag, key, nonce -def _make_gcm_decryptor(key, nonce): - """Create a GCM decryptor object.""" - return Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() +def _make_gcm_decryptor(key, nonce, content_length): + """Create an AesGcmDecryptor.""" + cipher_decryptor = Cipher(algorithms.AES(key), modes.GCM(nonce)).decryptor() + return AesGcmDecryptor(cipher_decryptor, tag_length=16, content_length=content_length) + + +def _encrypt_cbc(plaintext: bytes): + """Encrypt plaintext with AES-CBC + PKCS7 padding, return (ciphertext, key, iv).""" + key = os.urandom(32) + iv = os.urandom(16) + padder = PKCS7(128).padder() + padded = padder.update(plaintext) + padder.finalize() + encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor() + ciphertext = encryptor.update(padded) + encryptor.finalize() + return ciphertext, key, iv + + +def _make_cbc_decryptor(key, iv, content_length): + """Create an AesCbcDecryptor.""" + cipher_decryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor() + unpadder = PKCS7(128).unpadder() + return AesCbcDecryptor(cipher_decryptor, unpadder, content_length=content_length) def _make_streaming_body(data: bytes): @@ -50,15 +69,13 @@ class TestDelayedAuthReleasesBeforeVerification: ##% When enabled, the S3EC MAY release plaintext from a stream which has not been authenticated. def test_delayed_auth_releases_plaintext_before_tag_verification(self): plaintext = os.urandom(4096) - ciphertext_with_tag, key, nonce = _encrypt_gcm(plaintext) - body = _make_streaming_body(ciphertext_with_tag) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) - decryptor = _make_gcm_decryptor(key, nonce) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( body, - decryptor, - tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, - content_length=len(ciphertext_with_tag), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) # read(256) decrypts a partial chunk via cipher.update(), releasing # plaintext without consuming the full ciphertext stream. The GCM tag @@ -70,7 +87,7 @@ def test_delayed_auth_releases_plaintext_before_tag_verification(self): # _finalized is False: the GCM tag has NOT been verified yet assert not stream._finalized # Ciphertext remains unread in the underlying stream - assert body._stream.tell() < len(ciphertext_with_tag) + assert body._stream.tell() < len(ct) # Finish reading the stream and verify full plaintext matches remaining = stream.read() @@ -85,15 +102,13 @@ class TestBufferedWithholdsUntilVerification: ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. def test_buffered_verifies_tag_before_releasing_any_plaintext(self): plaintext = os.urandom(4096) - ciphertext_with_tag, key, nonce = _encrypt_gcm(plaintext) - body = _make_streaming_body(ciphertext_with_tag) + ct, key, nonce = _encrypt_gcm(plaintext) + body = _make_streaming_body(ct) - decryptor = _make_gcm_decryptor(key, nonce) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( body, - decryptor, - tag_length=AlgorithmSuite.ALG_AES_256_GCM_HKDF_SHA512_COMMIT_KEY.cipher_tag_length_bytes, - content_length=len(ciphertext_with_tag), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) # read(1) triggers _decrypt(), which calls self._body.read() with no amt, # consuming the entire ciphertext and verifying the GCM tag before @@ -105,44 +120,24 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): assert stream._plaintext is not None -def _encrypt_cbc(plaintext: bytes): - """Encrypt plaintext with AES-CBC + PKCS7 padding, return (ciphertext, key, iv, unpadder).""" - from cryptography.hazmat.primitives.padding import PKCS7 - - key = os.urandom(32) - iv = os.urandom(16) - padder = PKCS7(128).padder() - padded = padder.update(plaintext) + padder.finalize() - encryptor = Cipher(algorithms.AES(key), modes.CBC(iv)).encryptor() - ciphertext = encryptor.update(padded) + encryptor.finalize() - unpadder = PKCS7(128).unpadder() - return ciphertext, key, iv, unpadder - - -def _make_cbc_decryptor(key, iv): - return Cipher(algorithms.AES(key), modes.CBC(iv)).decryptor() - - class TestDelayedAuthCBCDecryption: def test_roundtrip(self): plaintext = b"hello world, this is a CBC test!!" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) assert stream.read() == plaintext def test_chunked_read(self): plaintext = b"A" * 256 - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) result = b"" @@ -152,11 +147,10 @@ def test_chunked_read(self): def test_finalize_called(self): plaintext = b"finalize me" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) actual = stream.read() @@ -165,22 +159,20 @@ def test_finalize_called(self): def test_no_trailing_padding_bytes(self): plaintext = b"short" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) assert stream.read() == plaintext def test_read_after_finalized_returns_empty(self): plaintext = b"done" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) stream.read() @@ -188,11 +180,10 @@ def test_read_after_finalized_returns_empty(self): def test_readable_false_after_finalized(self): plaintext = b"readable" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) assert stream.readable() @@ -202,49 +193,44 @@ def test_readable_false_after_finalized(self): def test_close_delegates_to_body(self): plaintext = b"close me" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) + ciphertext, key, iv = _encrypt_cbc(plaintext) body = _make_streaming_body(ciphertext) - stream = CBCDecryptingStream( - body, _make_cbc_decryptor(key, iv), unpadder=unpadder, content_length=len(ciphertext) + stream = DecryptingStream( + body, + _make_cbc_decryptor(key, iv, len(ciphertext)), + content_length=len(ciphertext), ) stream.close() body.close.assert_called_once() def test_enter_returns_self(self): plaintext = b"ctx" - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) assert stream.__enter__() is stream def test_wrong_key_raises_error(self): - from cryptography.hazmat.primitives.padding import PKCS7 - plaintext = b"wrong key test!!" - ciphertext, _key, iv, _ = _encrypt_cbc(plaintext) + ciphertext, _key, iv = _encrypt_cbc(plaintext) wrong_key = os.urandom(32) - stream = CBCDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(wrong_key, iv), - unpadder=PKCS7(128).unpadder(), + _make_cbc_decryptor(wrong_key, iv, len(ciphertext)), content_length=len(ciphertext), ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): stream.read() def test_empty_ciphertext(self): - from cryptography.hazmat.primitives.padding import PKCS7 - key = os.urandom(32) iv = os.urandom(16) - stream = CBCDecryptingStream( + stream = DecryptingStream( _make_streaming_body(b""), - _make_cbc_decryptor(key, iv), - unpadder=PKCS7(128).unpadder(), + _make_cbc_decryptor(key, iv, 0), content_length=0, ) # Empty stream finalize will fail because CBC expects at least one block @@ -252,15 +238,14 @@ def test_empty_ciphertext(self): stream.read() -class TestGCMBufferedDecryptingStream: +class TestBufferedDecryptingStream: def test_full_read(self): plaintext = os.urandom(1024) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.read() == plaintext @@ -268,10 +253,9 @@ def test_full_read(self): def test_partial_reads(self): plaintext = os.urandom(512) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) result = b"" @@ -283,8 +267,10 @@ def test_read_triggers_full_decrypt(self): plaintext = os.urandom(256) ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = GCMBufferedDecryptingStream( - body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + stream = BufferedDecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) assert stream._plaintext is None stream.read(1) @@ -295,10 +281,9 @@ def test_read_triggers_full_decrypt(self): def test_tell(self): plaintext = os.urandom(200) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) stream.read(50) @@ -307,22 +292,20 @@ def test_tell(self): def test_readable(self): plaintext = b"readable test" ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.readable() def test_readinto(self): - """Asserts that readinto is implemented by botocore's StreamingBody""" + """Asserts that readinto is implemented.""" plaintext = os.urandom(64) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) buf = bytearray(64) @@ -333,34 +316,37 @@ def test_readinto(self): def test_enter_returns_self(self): plaintext = b"enter" ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.__enter__() is stream def test_close_delegates(self): - """Asserts that close is implemented by botocore's StreamingBody""" + """Asserts that close delegates to the body.""" plaintext = b"close" ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = GCMBufferedDecryptingStream( - body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + stream = BufferedDecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) stream.close() body.close.assert_called_once() def test_close_without_close_attr(self): - """Asserts that close is implemented by botocore's StreamingBody""" + """Asserts that close handles bodies without close.""" plaintext = b"no close" ct, key, nonce = _encrypt_gcm(plaintext) body = Mock() del body.close body.read = BytesIO(ct).read - stream = GCMBufferedDecryptingStream( - body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + stream = BufferedDecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) stream.close() # should not raise @@ -368,13 +354,12 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(wrong_key, nonce), - tag_length=16, + _make_gcm_decryptor(wrong_key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): stream.read() def test_tampered_ciphertext_raises_error(self): @@ -382,22 +367,20 @@ def test_tampered_ciphertext_raises_error(self): ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[0] ^= 0xFF - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(bytes(tampered)), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt object"): + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): stream.read() def test_idempotent_decrypt(self): plaintext = os.urandom(128) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) first = stream.read(63) @@ -410,10 +393,9 @@ class TestDelayedAuthGCMDecryption: def test_full_read(self): plaintext = os.urandom(1024) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.read() == plaintext @@ -421,10 +403,9 @@ def test_full_read(self): def test_chunked_read(self): plaintext = os.urandom(512) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) result = b"" @@ -435,10 +416,9 @@ def test_chunked_read(self): def test_read_after_finalized_returns_empty(self): plaintext = os.urandom(128) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) actual = stream.read() @@ -449,10 +429,9 @@ def test_read_after_finalized_returns_empty(self): def test_readable_false_after_finalized(self): plaintext = b"readable" ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.readable() @@ -463,8 +442,10 @@ def test_close_delegates(self): plaintext = b"close" ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = GCMDelayedAuthDecryptingStream( - body, _make_gcm_decryptor(key, nonce), tag_length=16, content_length=len(ct) + stream = DecryptingStream( + body, + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), ) stream.close() body.close.assert_called_once() @@ -472,10 +453,9 @@ def test_close_delegates(self): def test_enter_returns_self(self): plaintext = b"ctx" ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.__enter__() is stream @@ -484,13 +464,12 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(wrong_key, nonce), - tag_length=16, + _make_gcm_decryptor(wrong_key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): stream.read() def test_tampered_tag_raises_error(self): @@ -498,13 +477,12 @@ def test_tampered_tag_raises_error(self): ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[-1] ^= 0xFF # flip last byte (part of tag) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(bytes(tampered)), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt GCM content"): + with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): stream.read() def test_small_data_less_than_tag_length(self): @@ -513,10 +491,9 @@ def test_small_data_less_than_tag_length(self): ct, key, nonce = _encrypt_gcm(plaintext) # For empty plaintext, ct is just the 16-byte tag assert len(ct) == 16 - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.read() == b"" @@ -524,10 +501,9 @@ def test_small_data_less_than_tag_length(self): def test_large_data(self): plaintext = os.urandom(1024 * 1024) # 1 MB ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) result = b"" @@ -550,10 +526,9 @@ class TestEdgeCasePlaintextLengths: def test_buffered_gcm(self, length): plaintext = os.urandom(length) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMBufferedDecryptingStream( + stream = BufferedDecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) assert stream.read() == plaintext @@ -562,10 +537,9 @@ def test_buffered_gcm(self, length): def test_delayed_auth_gcm(self, length): plaintext = os.urandom(length) ct, key, nonce = _encrypt_gcm(plaintext) - stream = GCMDelayedAuthDecryptingStream( + stream = DecryptingStream( _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce), - tag_length=16, + _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) result = b"" @@ -576,25 +550,13 @@ def test_delayed_auth_gcm(self, length): @pytest.mark.parametrize("length", EDGE_CASE_LENGTHS) def test_delayed_auth_cbc(self, length): plaintext = os.urandom(length) - ciphertext, key, iv, unpadder = _encrypt_cbc(plaintext) - stream = CBCDecryptingStream( + ciphertext, key, iv = _encrypt_cbc(plaintext) + stream = DecryptingStream( _make_streaming_body(ciphertext), - _make_cbc_decryptor(key, iv), - unpadder=unpadder, + _make_cbc_decryptor(key, iv, len(ciphertext)), content_length=len(ciphertext), ) result = b"" while stream.readable(): - # odd read size to stress tag-splitting/padding result += stream.read(7) assert result == plaintext - - -class TestGCMDelayedAuthContentLengthValidation: - - def test_content_length_less_than_tag_length_raises(self): - """ContentLength smaller than the GCM tag must raise immediately.""" - stream_body = _make_streaming_body(b"\x00" * 8) - decryptor = _make_gcm_decryptor(os.urandom(32), os.urandom(12)) - with pytest.raises(S3EncryptionClientError, match="less than GCM tag length"): - GCMDelayedAuthDecryptingStream(stream_body, decryptor, tag_length=16, content_length=8) From 395505d1c02cf258b39bd12cf381191f36699a40 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 31 Mar 2026 08:59:04 -0700 Subject: [PATCH 28/31] refactor: replace BufferedDecryptingStream with one_shot_decrypt, add iterator tests - Add buffered_decrypt.py with one_shot_decrypt() returning plain StreamingBody - Remove BufferedDecryptingStream class from stream.py - Add docstrings to DecryptingStream iterator methods - Wire _verify_content_length into _finalize to catch truncated HTTP responses - Add unit tests for iter_chunks, iter_lines, __iter__, __next__, readinto, readlines --- src/s3_encryption/buffered_decrypt.py | 20 ++ src/s3_encryption/pipelines.py | 18 +- src/s3_encryption/stream.py | 152 ++---------- .../test_i_s3_encryption_streaming.py | 8 +- test/test_stream.py | 223 +++++++++++++----- 5 files changed, 214 insertions(+), 207 deletions(-) create mode 100644 src/s3_encryption/buffered_decrypt.py diff --git a/src/s3_encryption/buffered_decrypt.py b/src/s3_encryption/buffered_decrypt.py new file mode 100644 index 00000000..65bb8aa9 --- /dev/null +++ b/src/s3_encryption/buffered_decrypt.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""One Shot decryption into a buffer.""" + +from io import BytesIO + +from botocore.response import StreamingBody + +from s3_encryption.decryptor import Decryptor + + +def one_shot_decrypt(streaming_body: object, decryptor: Decryptor): + """Decrypt a streaming object. + + Args: + streaming_body (object): A streaming object. + decryptor (Decryptor): Decryptor object. + """ + plaintext = decryptor.finalize(streaming_body.read()) + return StreamingBody(BytesIO(plaintext), len(plaintext)) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 34068781..3e1e8ef3 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -11,10 +11,12 @@ import os from attrs import define, field +from botocore.response import StreamingBody from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 +from .buffered_decrypt import one_shot_decrypt from .decryptor import AesCbcDecryptor, AesGcmDecryptor from .exceptions import S3EncryptionClientError from .instruction_file import fetch_instruction_file @@ -28,10 +30,7 @@ EncryptionMaterials, ) from .metadata import ObjectMetadata -from .stream import ( - BufferedDecryptingStream, - DecryptingStream, -) +from .stream import DecryptingStream @define @@ -232,7 +231,7 @@ def decrypt( encryption_context=None, bucket=None, key=None, - ): + ) -> StreamingBody: """Decrypt the data after it is retrieved from S3. Args: @@ -244,7 +243,7 @@ def decrypt( key (str, optional): S3 object key (required for instruction file) Returns: - A decrypting stream (BufferedDecryptingStream or DelayedAuthDecryptingStream). + A botocore.response.StreamingBody of plain-text """ # Convert the metadata dictionary to an ObjectMetadata instance streaming_body = response.get("Body") @@ -469,7 +468,7 @@ def _decrypt_gcm_streaming( ##= specification/s3-encryption/client.md#enable-delayed-authentication ##= type=implementation ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. - return BufferedDecryptingStream(streaming_body, decryptor, content_length=content_length) + return one_shot_decrypt(streaming_body, decryptor) def _decrypt_kc_gcm_streaming( self, dec_materials, metadata, streaming_body, enable_delayed_authentication, content_length @@ -523,7 +522,10 @@ def _decrypt_kc_gcm_streaming( ) if enable_delayed_authentication: return DecryptingStream(streaming_body, decryptor, content_length=content_length) - return BufferedDecryptingStream(streaming_body, decryptor, content_length=content_length) + ##= specification/s3-encryption/client.md#enable-delayed-authentication + ##= type=implementation + ##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. + return one_shot_decrypt(streaming_body, decryptor) def _decrypt_v2(self, metadata, encryption_context) -> DecryptionMaterials: """Prepare V2 decryption materials.""" diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 3b851449..bf588f77 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -23,129 +23,8 @@ ##= reason=Optional Feature that is a two-way door to implement later ##% If Delayed Authentication mode is disabled, and no buffer size is provided, the S3EC MUST set the buffer size to a reasonable default. -_DEFAULT_CHUNK_SIZE = 1024 - - -##= specification/s3-encryption/client.md#enable-delayed-authentication -##= type=implementation -##% When disabled the S3EC MUST NOT release plaintext from a stream which has not been authenticated. -# slots=False because StreamingBody extends IOBase which already has __weakref__. -@define(slots=False) -class BufferedDecryptingStream(StreamingBody): - """A stream that buffers all ciphertext, decrypts, then releases plaintext. - - Extends botocore's StreamingBody so it can be used as a drop-in replacement - for parsed["Body"]. All StreamingBody methods are explicitly overridden. - - This stream is cipher-agnostic — the Decryptor handles all algorithm details. - """ - - _body: object = field() - _decryptor: Decryptor = field() - _content_length: int = field() - _plaintext: io.BytesIO = field(init=False, default=None) - _plaintext_length: int = field(init=False, default=0) - - def __attrs_post_init__(self): # noqa: D105 - super().__init__(io.BytesIO(), content_length=self._content_length) - - def _decrypt(self): - """Read all ciphertext, decrypt and verify, cache plaintext.""" - if self._plaintext is not None: - return - data = self._body.read() - plaintext = self._decryptor.finalize(data) - self._plaintext = io.BytesIO(plaintext) - self._plaintext_length = len(plaintext) - self._raw_stream = self._plaintext - - def __del__(self): # noqa: D105 - pass - - def set_socket_timeout(self, timeout): # noqa: D102 - pass - - def readable(self): # noqa: D102 - self._decrypt() - return self._plaintext.readable() - - def read(self, amt=None): - """Read decrypted plaintext. Triggers full decryption on first call. - - Args: - amt: Number of bytes to read. If None, reads all remaining data. - - Returns: - bytes: Decrypted plaintext bytes. - """ - self._decrypt() - return self._plaintext.read() if amt is None else self._plaintext.read(amt) - - def readinto(self, b): # noqa: D102 - self._decrypt() - data = self._plaintext.read(len(b)) - n = len(data) - b[:n] = data - return n - - def readlines(self): # noqa: D102 - self._decrypt() - return self._plaintext.readlines() - - def __iter__(self): # noqa: D105 - return self.iter_chunks(_DEFAULT_CHUNK_SIZE) - def __next__(self): # noqa: D105 - chunk = self.read(_DEFAULT_CHUNK_SIZE) - if chunk: - return chunk - raise StopIteration() - - next = __next__ - - def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): # noqa: D102 - pending = b"" - for chunk in self.iter_chunks(chunk_size): - lines = (pending + chunk).splitlines(True) - for line in lines[:-1]: - yield line.splitlines(keepends)[0] - pending = lines[-1] - if pending: - yield pending.splitlines(keepends)[0] - - def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): # noqa: D102 - while True: - chunk = self.read(chunk_size) - if chunk == b"": - break - yield chunk - - def _verify_content_length(self): - """Verify that the decryptor consumed exactly content_length bytes.""" - if ( - self._decryptor.content_length is not None - and self._decryptor.amount_read != self._decryptor.content_length - ): - raise IncompleteReadError( - actual_bytes=self._decryptor.amount_read, - expected_bytes=self._decryptor.content_length, - ) - - def tell(self): # noqa: D102 - self._decrypt() - return self._plaintext.tell() - - def close(self): - """Close the underlying stream.""" - if hasattr(self._body, "close"): - self._body.close() - - def __enter__(self): # noqa: D105 - self._decrypt() - return self - - def __exit__(self, *args): # noqa: D105 - self.close() +_DEFAULT_CHUNK_SIZE = 1024 ##= specification/s3-encryption/client.md#enable-delayed-authentication @@ -220,9 +99,16 @@ def _finalize(self, trailing_data): if self._finalized: return b"" self._finalized = True - return self._decryptor.finalize(trailing_data) + plaintext = self._decryptor.finalize(trailing_data) + self._verify_content_length() + return plaintext + + def readinto(self, b): + """Read bytes into a pre-allocated, writable bytes-like object b. - def readinto(self, b): # noqa: D102 + Returns the number of bytes decrypted. + Note: CBC Padding and GCM tag will be removed, so bytes read MAYBE greater than bytes decrypted. + """ data = self.read(len(b)) n = len(data) b[:n] = data @@ -231,10 +117,12 @@ def readinto(self, b): # noqa: D102 def readlines(self): # noqa: D102 return self.read().splitlines(True) - def __iter__(self): # noqa: D105 + def __iter__(self): + """Return an iterator to yield 1k chunks from the decryption stream.""" return self.iter_chunks(_DEFAULT_CHUNK_SIZE) - def __next__(self): # noqa: D105 + def __next__(self): + """Return the next 1k chunk from the decryption stream.""" chunk = self.read(_DEFAULT_CHUNK_SIZE) if chunk: return chunk @@ -242,7 +130,12 @@ def __next__(self): # noqa: D105 next = __next__ - def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): # noqa: D102 + def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): + """Return an iterator to yield lines from the decryption stream. + + This is achieved by reading chunk of bytes (of size chunk_size) at a + time from the chipher-text stream, decrypting them, and then yielding lines from there. + """ pending = b"" for chunk in self.iter_chunks(chunk_size): lines = (pending + chunk).splitlines(True) @@ -252,7 +145,8 @@ def iter_lines(self, chunk_size=_DEFAULT_CHUNK_SIZE, keepends=False): # noqa: D if pending: yield pending.splitlines(keepends)[0] - def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): # noqa: D102 + def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): + """Return an iterator to yield chunks of chunk_size bytes from the raw stream.""" while True: chunk = self.read(chunk_size) if chunk == b"": @@ -275,7 +169,7 @@ def tell(self): # noqa: D102 return self._bytes_consumed def close(self): - """Close the underlying stream.""" + """Close the underlying cipher-text stream.""" if hasattr(self._body, "close"): self._body.close() diff --git a/test/integration/test_i_s3_encryption_streaming.py b/test/integration/test_i_s3_encryption_streaming.py index e4e471d1..530959bb 100644 --- a/test/integration/test_i_s3_encryption_streaming.py +++ b/test/integration/test_i_s3_encryption_streaming.py @@ -11,14 +11,12 @@ import boto3 import pytest +from botocore.response import StreamingBody from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig from s3_encryption.materials.kms_keyring import KmsKeyring from s3_encryption.materials.materials import AlgorithmSuite, CommitmentPolicy -from s3_encryption.stream import ( - BufferedDecryptingStream, - DecryptingStream, -) +from s3_encryption.stream import DecryptingStream bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") region = os.environ.get("CI_AWS_REGION", "us-west-2") @@ -73,7 +71,7 @@ def test_buffered_roundtrip(algorithm_suite, commitment_policy): response = s3ec.get_object(Bucket=bucket, Key=key) body = response["Body"] - assert isinstance(body, BufferedDecryptingStream) + assert isinstance(body, StreamingBody) assert body.read() == data diff --git a/test/test_stream.py b/test/test_stream.py index 183c14c9..b6753824 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -11,12 +11,10 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 +from s3_encryption.buffered_decrypt import one_shot_decrypt from s3_encryption.decryptor import AesCbcDecryptor, AesGcmDecryptor from s3_encryption.exceptions import S3EncryptionClientError -from s3_encryption.stream import ( - BufferedDecryptingStream, - DecryptingStream, -) +from s3_encryption.stream import DecryptingStream def _encrypt_gcm(plaintext: bytes): @@ -105,19 +103,24 @@ def test_buffered_verifies_tag_before_releasing_any_plaintext(self): ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = BufferedDecryptingStream( - body, - _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), - ) - # read(1) triggers _decrypt(), which calls self._body.read() with no amt, - # consuming the entire ciphertext and verifying the GCM tag before - # returning even 1 byte of plaintext. - chunk = stream.read(1) + decryptor = _make_gcm_decryptor(key, nonce, len(ct)) + original_finalize = decryptor.finalize + finalize_called = [] + def spy_finalize(data): + result = original_finalize(data) + finalize_called.append(True) + return result + + decryptor.finalize = spy_finalize + + stream = one_shot_decrypt(body, decryptor) + + # one_shot_decrypt calls finalize() eagerly — tag is verified + # before any read() call on the returned stream. + assert finalize_called, "finalize (tag verification) must happen before read()" + chunk = stream.read(1) assert chunk == plaintext[:1] - # _plaintext being set confirms full decrypt+verify already happened - assert stream._plaintext is not None class TestDelayedAuthCBCDecryption: @@ -243,20 +246,17 @@ class TestBufferedDecryptingStream: def test_full_read(self): plaintext = os.urandom(1024) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( - _make_streaming_body(ct), - _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), + stream = one_shot_decrypt( + _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)) ) assert stream.read() == plaintext def test_partial_reads(self): plaintext = os.urandom(512) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) result = b"" while chunk := stream.read(100): @@ -267,24 +267,24 @@ def test_read_triggers_full_decrypt(self): plaintext = os.urandom(256) ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = BufferedDecryptingStream( - body, - _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), - ) - assert stream._plaintext is None - stream.read(1) - assert stream._plaintext is not None - # Entire ciphertext consumed + decryptor = _make_gcm_decryptor(key, nonce, len(ct)) + finalize_called = [] + original_finalize = decryptor.finalize + decryptor.finalize = lambda data: (finalize_called.append(True), original_finalize(data))[1] + + stream = one_shot_decrypt(body, decryptor) + # one_shot_decrypt eagerly decrypts — finalize called at construction + assert finalize_called + # Entire ciphertext consumed from the body assert body._stream.tell() == len(ct) + assert stream.read(1) == plaintext[:1] def test_tell(self): plaintext = os.urandom(200) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) stream.read(50) assert stream.tell() == 50 @@ -292,10 +292,9 @@ def test_tell(self): def test_readable(self): plaintext = b"readable test" ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) assert stream.readable() @@ -303,38 +302,35 @@ def test_readinto(self): """Asserts that readinto is implemented.""" plaintext = os.urandom(64) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) buf = bytearray(64) n = stream.readinto(buf) assert n == 64 assert bytes(buf) == plaintext - def test_enter_returns_self(self): + def test_enter_returns_stream(self): plaintext = b"enter" ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) - assert stream.__enter__() is stream + with stream as s: + assert s.read() == plaintext - def test_close_delegates(self): - """Asserts that close delegates to the body.""" + def test_close(self): + """Asserts that close does not raise.""" plaintext = b"close" ct, key, nonce = _encrypt_gcm(plaintext) body = _make_streaming_body(ct) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( body, _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) - stream.close() - body.close.assert_called_once() + stream.close() # should not raise def test_close_without_close_attr(self): """Asserts that close handles bodies without close.""" @@ -343,10 +339,9 @@ def test_close_without_close_attr(self): body = Mock() del body.close body.read = BytesIO(ct).read - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( body, _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) stream.close() # should not raise @@ -354,34 +349,29 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - stream = BufferedDecryptingStream( - _make_streaming_body(ct), - _make_gcm_decryptor(wrong_key, nonce, len(ct)), - content_length=len(ct), - ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): - stream.read() + one_shot_decrypt( + _make_streaming_body(ct), + _make_gcm_decryptor(wrong_key, nonce, len(ct)), + ) def test_tampered_ciphertext_raises_error(self): plaintext = b"tamper test" ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[0] ^= 0xFF - stream = BufferedDecryptingStream( - _make_streaming_body(bytes(tampered)), - _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), - ) with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): - stream.read() + one_shot_decrypt( + _make_streaming_body(bytes(tampered)), + _make_gcm_decryptor(key, nonce, len(ct)), + ) def test_idempotent_decrypt(self): plaintext = os.urandom(128) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) first = stream.read(63) second = stream.read(65) @@ -517,7 +507,7 @@ def test_large_data(self): # --------------------------------------------------------------------------- # Lengths chosen around AES block size (16) and two-block (32) boundaries, # plus zero and one byte, to exercise padding, tag-splitting, and empty-data paths. -EDGE_CASE_LENGTHS = [0, 1, 8, 15, 16, 17, 31, 32, 33, 47, 48, 49] +EDGE_CASE_LENGTHS = [0, 1, 8, 15, 16, 17, 31, 32, 33, 47, 48, 49, 300] class TestEdgeCasePlaintextLengths: @@ -526,10 +516,9 @@ class TestEdgeCasePlaintextLengths: def test_buffered_gcm(self, length): plaintext = os.urandom(length) ct, key, nonce = _encrypt_gcm(plaintext) - stream = BufferedDecryptingStream( + stream = one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(key, nonce, len(ct)), - content_length=len(ct), ) assert stream.read() == plaintext @@ -560,3 +549,107 @@ def test_delayed_auth_cbc(self, length): while stream.readable(): result += stream.read(7) assert result == plaintext + + +class TestDecryptingStreamIterators: + """Tests for iter_chunks, iter_lines, __iter__, __next__, readinto, and readlines.""" + + def _make_gcm_stream(self, plaintext): + ct, key, nonce = _encrypt_gcm(plaintext) + return DecryptingStream( + _make_streaming_body(ct), + _make_gcm_decryptor(key, nonce, len(ct)), + content_length=len(ct), + ) + + @pytest.mark.parametrize("chunk_size", EDGE_CASE_LENGTHS[1:]) + def test_iter_chunks(self, chunk_size): + plaintext = os.urandom(300) + stream = self._make_gcm_stream(plaintext) + result = b"" + for chunk in stream.iter_chunks(chunk_size): + assert ( + len(chunk) <= chunk_size or not result + ) # first chunk may vary due to GCM buffering + result += chunk + assert result == plaintext + + def test_iter_chunks_default_size(self): + plaintext = os.urandom(2048) + stream = self._make_gcm_stream(plaintext) + result = b"".join(stream.iter_chunks()) + assert result == plaintext + + def test_iter_chunks_empty(self): + stream = self._make_gcm_stream(b"") + assert list(stream.iter_chunks()) == [] + + def test_iter(self): + plaintext = os.urandom(2048) + stream = self._make_gcm_stream(plaintext) + result = b"".join(stream) + assert result == plaintext + + def test_next(self): + plaintext = os.urandom(100) + stream = self._make_gcm_stream(plaintext) + first = next(stream) + assert len(first) > 0 + # drain the rest + rest = b"" + for chunk in stream: + rest += chunk + assert first + rest == plaintext + + def test_next_raises_stop_iteration(self): + stream = self._make_gcm_stream(b"") + with pytest.raises(StopIteration): + next(stream) + + def test_iter_lines(self): + plaintext = b"line1\nline2\nline3\n" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines()) + assert lines == [b"line1", b"line2", b"line3"] + + def test_iter_lines_keepends(self): + plaintext = b"line1\nline2\nline3\n" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines(keepends=True)) + assert lines == [b"line1\n", b"line2\n", b"line3\n"] + + def test_iter_lines_no_trailing_newline(self): + plaintext = b"first\nsecond" + stream = self._make_gcm_stream(plaintext) + lines = list(stream.iter_lines()) + assert lines == [b"first", b"second"] + + def test_iter_lines_empty(self): + stream = self._make_gcm_stream(b"") + assert list(stream.iter_lines()) == [] + + def test_readinto(self): + plaintext = os.urandom(64) + stream = self._make_gcm_stream(plaintext) + buf = bytearray(64) + n = stream.readinto(buf) + assert bytes(buf[:n]) == plaintext[:n] + + def test_readinto_partial(self): + plaintext = os.urandom(200) + stream = self._make_gcm_stream(plaintext) + buf = bytearray(50) + result = b"" + while n := stream.readinto(buf): + result += bytes(buf[:n]) + assert result == plaintext + + def test_readlines(self): + plaintext = b"aaa\nbbb\nccc\n" + stream = self._make_gcm_stream(plaintext) + assert stream.readlines() == [b"aaa\n", b"bbb\n", b"ccc\n"] + + def test_readlines_no_trailing_newline(self): + plaintext = b"aaa\nbbb" + stream = self._make_gcm_stream(plaintext) + assert stream.readlines() == [b"aaa\n", b"bbb"] From e19b1bfa99b964736046e4cca3065ff2e3a0803f Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:04:30 -0700 Subject: [PATCH 29/31] docs: add inline comments to DecryptingStream.read explaining control flow --- src/s3_encryption/stream.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index bf588f77..3cd38f11 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -73,24 +73,37 @@ def read(self, amt=None): if self._finalized: return b"" + # Loop until the decryptor produces non-empty plaintext. + # The GCM decryptor's tail buffer may absorb small reads entirely + # (returning b"" from update) while it holds back the trailing auth + # tag. Looping prevents callers from seeing spurious empty bytes + # mid-stream, which would break `while chunk := stream.read(amt)`. result = b"" while not result: remaining = self._content_length - self._bytes_consumed if remaining <= 0: + # All content_length bytes consumed — finalize with no extra data. return self._finalize(b"") + # Never read past content_length; cap at amt if provided. to_read = remaining if amt is None else min(amt, remaining) raw = self._body.read(to_read) if not raw: + # Underlying stream exhausted early — finalize with what we have. return self._finalize(b"") self._bytes_consumed += len(raw) remaining = self._content_length - self._bytes_consumed if remaining <= 0: + # This is the last chunk — pass it to finalize so the decryptor + # can split off the GCM tag (or flush CBC padding) and verify. return self._finalize(raw) + # Feed ciphertext to the decryptor. For GCM, the tail buffer holds + # back the last tag_length bytes, so update() may return b"" if + # the chunk was entirely absorbed into the buffer. result = self._decryptor.update(raw) return result From 3e4017380e17004d680acc30a4e9d509475b68de Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 31 Mar 2026 09:17:08 -0700 Subject: [PATCH 30/31] chore: document how stream and decryptor are a little leaky --- src/s3_encryption/stream.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 3cd38f11..7c8ccbe0 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -37,12 +37,16 @@ class DecryptingStream(StreamingBody): Extends botocore's StreamingBody so it can be used as a drop-in replacement for parsed["Body"]. All StreamingBody methods are explicitly overridden. - - This stream is cipher-agnostic — the Decryptor handles all algorithm details. - Ciphertext is fed through decryptor.update() incrementally, and - decryptor.finalize() is called with any trailing data when the body is exhausted. """ + # This stream is ALMOST cipher-agnostic — the Decryptor handles ALMOST all algorithm details. + # Ciphertext is fed through decryptor.update() incrementally, and + # decryptor.finalize() is called with any trailing data when the body is exhausted. + # + # ALMOST :: The AES-GCM tag is problematic when combined with iterators that can split + # the tag over two reads. To accommodate this, read() has a while loop with 3 return conditions. + # See inline comments of read for more details. + _body: object = field() _decryptor: Decryptor = field() _content_length: int = field() From e7ecd7b7c6b10fc4d86b5230818b275ddccf0ac4 Mon Sep 17 00:00:00 2001 From: texastony <5892063+texastony@users.noreply.github.com> Date: Tue, 31 Mar 2026 12:26:11 -0700 Subject: [PATCH 31/31] fix: address self-code review findings for streaming decryption - raise S3EncryptionClientSecurityError for CBC/GCM decryption failures - remove magic number tolerance in _verify_content_length - fix type annotations (object -> StreamingBody) - fix __iter__/__next__ inconsistency in DecryptingStream - Remove unnecessary __del__ and set_socket_timeout overrides - Update test assertions to expect SecurityError --- src/s3_encryption/buffered_decrypt.py | 2 +- src/s3_encryption/decryptor.py | 7 +++++-- src/s3_encryption/pipelines.py | 2 +- src/s3_encryption/stream.py | 12 ++---------- .../test_i_s3_encryption_instruction_file.py | 4 +--- test/test_decryption.py | 2 +- test/test_stream.py | 18 +++++++++--------- 7 files changed, 20 insertions(+), 27 deletions(-) diff --git a/src/s3_encryption/buffered_decrypt.py b/src/s3_encryption/buffered_decrypt.py index 65bb8aa9..6c305751 100644 --- a/src/s3_encryption/buffered_decrypt.py +++ b/src/s3_encryption/buffered_decrypt.py @@ -9,7 +9,7 @@ from s3_encryption.decryptor import Decryptor -def one_shot_decrypt(streaming_body: object, decryptor: Decryptor): +def one_shot_decrypt(streaming_body: StreamingBody, decryptor: Decryptor): """Decrypt a streaming object. Args: diff --git a/src/s3_encryption/decryptor.py b/src/s3_encryption/decryptor.py index e94be803..e3d4eece 100644 --- a/src/s3_encryption/decryptor.py +++ b/src/s3_encryption/decryptor.py @@ -5,8 +5,9 @@ from abc import ABC, abstractmethod from attrs import define, field +from cryptography.exceptions import InvalidTag -from .exceptions import S3EncryptionClientError +from .exceptions import S3EncryptionClientError, S3EncryptionClientSecurityError class Decryptor(ABC): @@ -72,7 +73,7 @@ def finalize(self, data: bytes) -> bytes: plaintext += self._decryptor.finalize() return self._unpadder.update(plaintext) + self._unpadder.finalize() except Exception as e: - raise S3EncryptionClientError(f"Failed to decrypt CBC content: {e}") from e + raise S3EncryptionClientSecurityError(f"Failed to decrypt CBC content: {e}") from e @define @@ -134,5 +135,7 @@ def finalize(self, data: bytes) -> bytes: return plaintext + self._decryptor.finalize_with_tag(tag) except S3EncryptionClientError: raise + except InvalidTag as e: + raise S3EncryptionClientSecurityError(f"Failed to decrypt Object: {e}") from e except Exception as e: raise S3EncryptionClientError(f"Failed to decrypt Object: {e}") from e diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 3e1e8ef3..5561047a 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -246,7 +246,7 @@ def decrypt( A botocore.response.StreamingBody of plain-text """ # Convert the metadata dictionary to an ObjectMetadata instance - streaming_body = response.get("Body") + streaming_body: StreamingBody = response.get("Body") content_length = response.get("ContentLength") encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) diff --git a/src/s3_encryption/stream.py b/src/s3_encryption/stream.py index 7c8ccbe0..a4e85a74 100644 --- a/src/s3_encryption/stream.py +++ b/src/s3_encryption/stream.py @@ -56,12 +56,6 @@ class DecryptingStream(StreamingBody): def __attrs_post_init__(self): # noqa: D105 super().__init__(io.BytesIO(), content_length=self._content_length) - def __del__(self): # noqa: D105 - pass - - def set_socket_timeout(self, timeout): # noqa: D102 - pass - def readable(self): # noqa: D102 return not self._finalized @@ -136,7 +130,7 @@ def readlines(self): # noqa: D102 def __iter__(self): """Return an iterator to yield 1k chunks from the decryption stream.""" - return self.iter_chunks(_DEFAULT_CHUNK_SIZE) + return self def __next__(self): """Return the next 1k chunk from the decryption stream.""" @@ -173,9 +167,7 @@ def iter_chunks(self, chunk_size=_DEFAULT_CHUNK_SIZE): def _verify_content_length(self): """Verify that the decryptor consumed exactly content_length bytes.""" if self._decryptor.content_length is not None and not ( - self._decryptor.amount_read - 16 - <= self._decryptor.content_length - <= self._decryptor.amount_read + 16 + self._decryptor.amount_read == self._content_length ): raise IncompleteReadError( actual_bytes=self._decryptor.amount_read, diff --git a/test/integration/test_i_s3_encryption_instruction_file.py b/test/integration/test_i_s3_encryption_instruction_file.py index 570307a1..6c93d832 100644 --- a/test/integration/test_i_s3_encryption_instruction_file.py +++ b/test/integration/test_i_s3_encryption_instruction_file.py @@ -176,7 +176,6 @@ def test_decrypt_v2_instruction_file_custom_suffix(delayed_auth): LARGE_FILE_SIZE = 52428800 # 50 MB -@pytest.mark.skip(reason="Slow as hell") def test_decrypt_large_v2_instruction_file_delayed_auth(): """Test streaming decryption of a 50 MB V2 object with delayed authentication.""" key = TEST_OBJECTS["large_v2_instruction_file"] @@ -200,8 +199,7 @@ def test_decrypt_large_v2_instruction_file_delayed_auth(): assert total == LARGE_FILE_SIZE -# TODO(v3): enable once V3 decryption is implemented -@pytest.mark.skip(reason="V3 decryption not yet implemented") +@pytest.mark.skip(reason="V3 large file not yet written to static bucket") def test_decrypt_large_v3_instruction_file_delayed_auth(): """Test streaming decryption of a 50 MB V3 object with delayed authentication.""" key = TEST_OBJECTS["large_v3_instruction_file"] diff --git a/test/test_decryption.py b/test/test_decryption.py index ed37f7d5..7f9a51d1 100644 --- a/test/test_decryption.py +++ b/test/test_decryption.py @@ -196,7 +196,7 @@ def test_cbc_decryption_fails_with_wrong_key(self): keyring_return=dec_mats, ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): pipeline.decrypt( _response(metadata, ciphertext), ".instruction", enable_delayed_authentication=False ).read() diff --git a/test/test_stream.py b/test/test_stream.py index b6753824..dc0f2eb9 100644 --- a/test/test_stream.py +++ b/test/test_stream.py @@ -13,7 +13,7 @@ from s3_encryption.buffered_decrypt import one_shot_decrypt from s3_encryption.decryptor import AesCbcDecryptor, AesGcmDecryptor -from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.exceptions import S3EncryptionClientSecurityError from s3_encryption.stream import DecryptingStream @@ -225,7 +225,7 @@ def test_wrong_key_raises_error(self): _make_cbc_decryptor(wrong_key, iv, len(ciphertext)), content_length=len(ciphertext), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): stream.read() def test_empty_ciphertext(self): @@ -237,7 +237,7 @@ def test_empty_ciphertext(self): content_length=0, ) # Empty stream finalize will fail because CBC expects at least one block - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt CBC content"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt CBC content"): stream.read() @@ -349,7 +349,7 @@ def test_wrong_key_raises_error(self): plaintext = b"wrong key" ct, _key, nonce = _encrypt_gcm(plaintext) wrong_key = os.urandom(32) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): one_shot_decrypt( _make_streaming_body(ct), _make_gcm_decryptor(wrong_key, nonce, len(ct)), @@ -360,7 +360,7 @@ def test_tampered_ciphertext_raises_error(self): ct, key, nonce = _encrypt_gcm(plaintext) tampered = bytearray(ct) tampered[0] ^= 0xFF - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): one_shot_decrypt( _make_streaming_body(bytes(tampered)), _make_gcm_decryptor(key, nonce, len(ct)), @@ -459,7 +459,7 @@ def test_wrong_key_raises_error(self): _make_gcm_decryptor(wrong_key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): stream.read() def test_tampered_tag_raises_error(self): @@ -472,7 +472,7 @@ def test_tampered_tag_raises_error(self): _make_gcm_decryptor(key, nonce, len(ct)), content_length=len(ct), ) - with pytest.raises(S3EncryptionClientError, match="Failed to decrypt"): + with pytest.raises(S3EncryptionClientSecurityError, match="Failed to decrypt"): stream.read() def test_small_data_less_than_tag_length(self): @@ -546,8 +546,8 @@ def test_delayed_auth_cbc(self, length): content_length=len(ciphertext), ) result = b"" - while stream.readable(): - result += stream.read(7) + while chunk := stream.read(7): + result += chunk assert result == plaintext