From 439fd492c642cd8b5a5319e56dc6951e81cce4d2 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Wed, 11 Feb 2026 15:02:41 -0800 Subject: [PATCH 01/10] first draft --- src/s3_encryption/__init__.py | 209 ++++++++++++++++++++------------- src/s3_encryption/pipelines.py | 1 + 2 files changed, 129 insertions(+), 81 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 46cdbdd1..8e670bff 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -5,7 +5,7 @@ import io from attrs import define, field -from botocore import serialize +from botocore.exceptions import ParamValidationError from botocore.response import StreamingBody from .exceptions import S3EncryptionClientError @@ -31,31 +31,136 @@ def _default_cmm_for_keyring(self): return DefaultCryptoMaterialsManager(self.keyring) +class S3EncryptionClientPlugin: + """Plugin that adds encryption/decryption capabilities to a boto3 S3 client. + + This plugin uses boto3's event system to intercept put_object and get_object + calls to provide transparent encryption and decryption of S3 objects. + """ + + def __init__(self, config: S3EncryptionClientConfig): + """Initialize the plugin with encryption configuration. + + Args: + config: S3EncryptionClientConfig containing keyring and CMM + """ + self.config = config + + def on_put_object_before_call(self, params, **kwargs): + """Event handler for before-call.s3.PutObject. + + This handler encrypts the body after serialization but before the request is sent. + + Args: + params: Dictionary of parameters for the PutObject call (after serialization) + **kwargs: Additional event arguments + """ + # At this point, boto3 has already serialized the Body + # Extract the serialized body from the request + body = params.get("body") + if body is None: + body_bytes = b"" + elif isinstance(body, bytes): + body_bytes = body + elif hasattr(body, "read"): + # It's a file-like object (BytesIO, etc.) + # TODO: Stream Encryption + body_bytes = body.read() + else: + body_bytes = b"" + + # Extract encryption context from headers if present + headers = params.get("headers", {}) + encryption_context = None + + # Check if EncryptionContext was passed (it would be in a custom header) + # For now, we'll handle it through metadata + + # Get metadata from headers + metadata = {} + for key, value in headers.items(): + if key.lower().startswith("x-amz-meta-"): + # Extract the metadata key (remove x-amz-meta- prefix) + meta_key = key[11:] # len("x-amz-meta-") = 11 + metadata[meta_key] = value + + # Create a pipeline and encrypt the data + pipeline = PutEncryptedObjectPipeline(self.config.cmm) + encrypted_data, encryption_metadata = pipeline.encrypt( + body_bytes, encryption_context=encryption_context + ) + + # Update the body with encrypted data + params["body"] = encrypted_data + + # Add encryption metadata to headers + if encryption_metadata: + for key, value in encryption_metadata.items(): + # Add as S3 metadata headers + header_key = f"x-amz-meta-{key}" + headers[header_key] = value + + def on_get_object_after_call(self, parsed, **kwargs): + """Event handler for after-call.s3.GetObject. + + This handler decrypts the body after the response is received from S3. + + Args: + parsed: Dictionary containing the parsed response + **kwargs: Additional event arguments (includes 'params' with request parameters) + """ + # Extract encryption context from original request params if available + request_params = kwargs.get("params", {}) + encryption_context = request_params.pop("EncryptionContext", None) + + # The parsed response already has the Body as a StreamingBody + # We need to read it, decrypt it, and replace it + + # Create a response dict that matches what the pipeline expects + response = { + "Body": parsed.get("Body"), + "Metadata": parsed.get("Metadata", {}), + } + + # Create a pipeline and decrypt the data + pipeline = GetEncryptedObjectPipeline(self.config.cmm) + decrypted_data = pipeline.decrypt(response, encryption_context) + + # Replace body with decrypted data + stream = io.BytesIO(decrypted_data) + streaming_body = StreamingBody(stream, len(decrypted_data)) + parsed["Body"] = streaming_body + + @define class S3EncryptionClient: """Client for encrypting and decrypting S3 objects. This client wraps a boto3 S3 client and provides encryption and decryption capabilities for S3 objects using the configured keyring and crypto materials manager. + + The encryption/decryption is implemented using boto3's event system, registering + handlers for before-call and after-call events. """ wrapped_s3_client = field() config: S3EncryptionClientConfig = field() + _plugin: S3EncryptionClientPlugin = field(init=False) def __attrs_post_init__(self): - """Validate serialization encoding after initialization. + """Install the encryption plugin on the wrapped client using boto3 events.""" + # Create the plugin + object.__setattr__(self, "_plugin", S3EncryptionClientPlugin(self.config)) - Ensures boto3 serializers are using the expected default encoding. - """ - # Sanity check that boto3 serialization are ONLY using the default encoding (utf-8) - # This should always be the case, but changes in encoding would break the assumption that - # the decrypted plaintext adheres to the non-utf8 encoding scheme. So we avoid that. - for sz_name, sz in serialize.SERIALIZERS.items(): - if sz.DEFAULT_ENCODING != DEFAULT_ENCODING: - raise S3EncryptionClientError( - f"All Serializers MUST only support utf-8 encoding, but {sz_name} is using " - f"{sz.DEFAULT_ENCODING}!" - ) + # Register event handlers using boto3's event system + event_system = self.wrapped_s3_client.meta.events + + # Register before-call handler for PutObject to encrypt data + # This happens after serialization, so Body is already bytes + event_system.register("before-call.s3.PutObject", self._plugin.on_put_object_before_call) + + # Register after-call handler for GetObject to decrypt data + event_system.register("after-call.s3.GetObject", self._plugin.on_get_object_after_call) def put_object(self, **kwargs): """Encrypt and upload an object to S3. @@ -71,52 +176,18 @@ def put_object(self, **kwargs): Returns: The response from the S3 client's put_object method. - """ - # Extract required parameters from kwargs - bucket = kwargs.pop("Bucket") - key = kwargs.pop("Key") - body = kwargs.pop("Body", b"") # Default to empty bytes when Body is not provided - encryption_context = kwargs.pop("EncryptionContext", None) - - # Create a pipeline for this operation - pipeline = PutEncryptedObjectPipeline(self.config.cmm) - # The documentation for boto3 asks for bytes or a file-like object, - # but in reality, it is possible to pass strings. - # Strings will be encoded using DEFAULT_ENCODING, - # which MUST match the default encoding defined int the Serializer class in botocore. - if isinstance(body, str): - data_bytes = body.encode(DEFAULT_ENCODING) - elif isinstance(body, bytes): - data_bytes = body - elif isinstance(body, io.IOBase): - # TODO: Streaming support - raise S3EncryptionClientError( - f"Body parameter of type {type(body)} is not an acceptable type! " - f"Streaming operations are not yet supported." - ) - else: + Raises: + S3EncryptionClientError: If the Body parameter has an invalid type. + """ + try: + return self.wrapped_s3_client.put_object(**kwargs) + except ParamValidationError as e: + # Wrap boto3's ParamValidationError with our custom error raise S3EncryptionClientError( - f"Body parameter of type {type(body)} is not an acceptable type! " + f"Body parameter of type {type(kwargs.get('Body'))} is not an acceptable type! " f"Use bytes or a file-like object." - ) - - # Now encrypt the bytes/file-like IOBase object - encrypted_data, encryption_metadata = pipeline.encrypt( - data_bytes, encryption_context=encryption_context - ) - - # Add encryption metadata to the request parameters - params = {"Bucket": bucket, "Key": key, "Body": encrypted_data, **kwargs} - - # Add encryption metadata to the parameters - if encryption_metadata: - # Merge any existing metadata with our encryption metadata - metadata = params.get("Metadata", {}) - metadata.update(encryption_metadata) - params["Metadata"] = metadata - - return self.wrapped_s3_client.put_object(**params) + ) from e def get_object(self, **kwargs): """Download and decrypt an object from S3. @@ -132,28 +203,4 @@ def get_object(self, **kwargs): The response from the S3 client's get_object method with the Body replaced with a StreamingBody containing the decrypted data. """ - # Extract encryption context if provided - encryption_context = kwargs.pop("EncryptionContext", None) - - # Create params for the S3 client - params = {**kwargs} - - # Get the encrypted object from S3 - response = self.wrapped_s3_client.get_object(**params) - - # Create a pipeline for this operation - pipeline = GetEncryptedObjectPipeline(self.config.cmm) - - # Decrypt the data using the pipeline - decrypted_data = pipeline.decrypt( - response, encryption_context - ) # encrypted_data, encryption_metadata) - - # Create a new streaming body with the decrypted data - stream = io.BytesIO(decrypted_data) - streaming_body = StreamingBody(stream, len(decrypted_data)) - - # Update the response with the decrypted data - response["Body"] = streaming_body - - return response + return self.wrapped_s3_client.get_object(**kwargs) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 37093803..4b2c0e48 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -102,6 +102,7 @@ def decrypt(self, response, encryption_context=None): bytes: The decrypted data """ # Convert the metadata dictionary to an ObjectMetadata instance + # TODO: Stream + Buffered Decryption encrypted_data = response.get("Body").read() encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) From 910ab04e882cd78086882e4315a4e137e5d85821 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Wed, 11 Feb 2026 16:54:39 -0800 Subject: [PATCH 02/10] second draft --- src/s3_encryption/__init__.py | 88 +++++---- src/s3_encryption/materials/kms_keyring.py | 10 +- test/integration/test_i_s3_encryption.py | 203 ++++++++++++++++++++- 3 files changed, 258 insertions(+), 43 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 8e670bff..52ec3e08 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -3,9 +3,9 @@ """Top-level S3 Encryption Client v3 for Python package.""" import io +import threading from attrs import define, field -from botocore.exceptions import ParamValidationError from botocore.response import StreamingBody from .exceptions import S3EncryptionClientError @@ -16,7 +16,7 @@ from .materials.keyring import AbstractKeyring from .pipelines import GetEncryptedObjectPipeline, PutEncryptedObjectPipeline -DEFAULT_ENCODING = "utf-8" +S3_METADATA_PREFIX = "x-amz-meta-" @define @@ -45,6 +45,7 @@ def __init__(self, config: S3EncryptionClientConfig): config: S3EncryptionClientConfig containing keyring and CMM """ self.config = config + self._context = threading.local() def on_put_object_before_call(self, params, **kwargs): """Event handler for before-call.s3.PutObject. @@ -64,42 +65,32 @@ def on_put_object_before_call(self, params, **kwargs): body_bytes = body elif hasattr(body, "read"): # It's a file-like object (BytesIO, etc.) - # TODO: Stream Encryption + # TODO(streaming): Add support for streaming encryption without reading entire body into memory body_bytes = body.read() else: - body_bytes = b"" - - # Extract encryption context from headers if present - headers = params.get("headers", {}) - encryption_context = None - - # Check if EncryptionContext was passed (it would be in a custom header) - # For now, we'll handle it through metadata + # Unexpected body type - should not happen as boto3 validates before this point + raise S3EncryptionClientError("Unexpected type of body parameter!") - # Get metadata from headers - metadata = {} - for key, value in headers.items(): - if key.lower().startswith("x-amz-meta-"): - # Extract the metadata key (remove x-amz-meta- prefix) - meta_key = key[11:] # len("x-amz-meta-") = 11 - metadata[meta_key] = value + encryption_context = getattr(self._context, "encryption_context", None) - # Create a pipeline and encrypt the data pipeline = PutEncryptedObjectPipeline(self.config.cmm) encrypted_data, encryption_metadata = pipeline.encrypt( body_bytes, encryption_context=encryption_context ) - # Update the body with encrypted data params["body"] = encrypted_data + headers = params.get("headers", {}) + # Add encryption metadata to headers if encryption_metadata: for key, value in encryption_metadata.items(): # Add as S3 metadata headers - header_key = f"x-amz-meta-{key}" + header_key = f"{S3_METADATA_PREFIX}{key}" headers[header_key] = value + params["headers"] = headers + def on_get_object_after_call(self, parsed, **kwargs): """Event handler for after-call.s3.GetObject. @@ -109,9 +100,8 @@ def on_get_object_after_call(self, parsed, **kwargs): parsed: Dictionary containing the parsed response **kwargs: Additional event arguments (includes 'params' with request parameters) """ - # Extract encryption context from original request params if available - request_params = kwargs.get("params", {}) - encryption_context = request_params.pop("EncryptionContext", None) + # Get encryption context from thread-local storage (set by get_object wrapper) + encryption_context = getattr(self._context, "encryption_context", None) # The parsed response already has the Body as a StreamingBody # We need to read it, decrypt it, and replace it @@ -154,12 +144,7 @@ def __attrs_post_init__(self): # Register event handlers using boto3's event system event_system = self.wrapped_s3_client.meta.events - - # Register before-call handler for PutObject to encrypt data - # This happens after serialization, so Body is already bytes event_system.register("before-call.s3.PutObject", self._plugin.on_put_object_before_call) - - # Register after-call handler for GetObject to decrypt data event_system.register("after-call.s3.GetObject", self._plugin.on_get_object_after_call) def put_object(self, **kwargs): @@ -178,16 +163,27 @@ def put_object(self, **kwargs): The response from the S3 client's put_object method. Raises: - S3EncryptionClientError: If the Body parameter has an invalid type. + S3EncryptionClientError: Any problem with encryption, including if the Body parameter has an invalid type. """ + # Extract EncryptionContext if provided (not a standard S3 parameter) + encryption_context = kwargs.pop("EncryptionContext", None) + + # Store encryption context in thread-local storage for the event handler + self._plugin._context.encryption_context = encryption_context + try: return self.wrapped_s3_client.put_object(**kwargs) - except ParamValidationError as e: - # Wrap boto3's ParamValidationError with our custom error + except S3EncryptionClientError: + # Re-raise our own exceptions without wrapping + raise + except Exception as e: raise S3EncryptionClientError( - f"Body parameter of type {type(kwargs.get('Body'))} is not an acceptable type! " - f"Use bytes or a file-like object." + f"Failed to encryption object: {str(e)}" ) from e + finally: + # Clean up thread-local storage + if hasattr(self._plugin._context, "encryption_context"): + delattr(self._plugin._context, "encryption_context") def get_object(self, **kwargs): """Download and decrypt an object from S3. @@ -202,5 +198,27 @@ def get_object(self, **kwargs): Returns: The response from the S3 client's get_object method with the Body replaced with a StreamingBody containing the decrypted data. + + Raises: + S3EncryptionClientError: If decryption fails or the object is not properly encrypted. """ - return self.wrapped_s3_client.get_object(**kwargs) + # Extract EncryptionContext if provided (not a standard S3 parameter) + encryption_context = kwargs.pop("EncryptionContext", None) + + # Store encryption context in thread-local storage for the event handler + self._plugin._context.encryption_context = encryption_context + + try: + return self.wrapped_s3_client.get_object(**kwargs) + except S3EncryptionClientError: + # Re-raise our own exceptions without wrapping + raise + except Exception as e: + # Wrap any unexpected errors during decryption + raise S3EncryptionClientError( + f"Failed to decrypt object: {str(e)}" + ) from e + finally: + # Clean up thread-local storage + if hasattr(self._plugin._context, "encryption_context"): + delattr(self._plugin._context, "encryption_context") diff --git a/src/s3_encryption/materials/kms_keyring.py b/src/s3_encryption/materials/kms_keyring.py index 7bc8f7bd..835816e3 100644 --- a/src/s3_encryption/materials/kms_keyring.py +++ b/src/s3_encryption/materials/kms_keyring.py @@ -46,8 +46,9 @@ def on_encrypt(self, enc_materials): # Call parent class validation enc_materials = super().on_encrypt(enc_materials) + # Copy encryption context to avoid modifying the original + encryption_context = enc_materials.encryption_context.copy() # Add default encryption context - encryption_context = enc_materials.encryption_context encryption_context["aws:x-amz-cek-alg"] = "AES/GCM/NoPadding" response = self.kms_client.generate_data_key( @@ -61,6 +62,8 @@ def on_encrypt(self, enc_materials): ) enc_materials.encrypted_data_key = encrypted_data_key enc_materials.plaintext_data_key = response["Plaintext"] + # Update enc_materials with the modified encryption context (with default added) + enc_materials.encryption_context = encryption_context return enc_materials except Exception: raise @@ -108,10 +111,13 @@ def on_decrypt(self, dec_materials, encrypted_data_keys=None): ) # The stored EC, minus default key/values, MUST match provided EC + # If no EC is provided from request (empty dict), use stored EC encryption_context_stored_copy = encryption_context_stored.copy() encryption_context_stored_copy.pop(KMS_V1_DEFAULT_KEY, None) encryption_context_stored_copy.pop(KMS_CONTEXT_DEFAULT_KEY, None) - if encryption_context_stored_copy != encryption_context_from_request: + + # Only validate if encryption context was explicitly provided in request + if encryption_context_from_request and encryption_context_stored_copy != encryption_context_from_request: # TODO: modeled error raise S3EncryptionClientError( "Provided encryption context does not match information " diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 2c8ea73a..282b08e3 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -259,31 +259,222 @@ def test_invalid_body_types(): # Test with integer with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=42) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with float with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=3.14) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with list with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=[1, 2, 3]) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with dictionary with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body={"key": "value"}) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with boolean with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=True) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with None (also raises an exception) with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=None) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) print("Success! All invalid body types correctly raised exceptions.") + + +def test_user_metadata_preservation(): + """Test that user-provided metadata is preserved during encryption.""" + key = "metadata-preservation-rt" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with user metadata" + + # User metadata to include + user_metadata = { + "author": "test-user", + "version": "1.0", + "description": "Test object with custom metadata" + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with user metadata + s3ec.put_object(Bucket=bucket, Key=key, Body=data, Metadata=user_metadata) + + # Get the object back + get_req = {"Bucket": bucket, "Key": key} + response = s3ec.get_object(**get_req) + + # Verify the data decrypts correctly + output = response["Body"].read().decode("utf-8") + if output != data: + print("Uh oh! Input and output don't match!") + print("Input:") + print(repr(data)) + print("Output:") + print(repr(output)) + raise RuntimeError + + # Verify user metadata is preserved + returned_metadata = response.get("Metadata", {}) + + for key_name, expected_value in user_metadata.items(): + if key_name not in returned_metadata: + print(f"Uh oh! User metadata key '{key_name}' is missing!") + print("Expected metadata:") + print(user_metadata) + print("Returned metadata:") + print(returned_metadata) + raise RuntimeError + + if returned_metadata[key_name] != expected_value: + print(f"Uh oh! User metadata value for '{key_name}' doesn't match!") + print(f"Expected: {expected_value}") + print(f"Got: {returned_metadata[key_name]}") + raise RuntimeError + + print("Success! User metadata preserved correctly during encryption/decryption.") + print(f"User metadata: {user_metadata}") + print(f"Returned metadata keys: {list(returned_metadata.keys())}") + + +def test_encryption_context_roundtrip(): + """Test that EncryptionContext is properly used during encryption and required for decryption.""" + key = "encryption-context-rt" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Encryption context to use for additional authenticated data + encryption_context = { + "department": "engineering", + "project": "s3-encryption", + "environment": "test" + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Get the object back WITH the same encryption context + get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": encryption_context} + response = s3ec.get_object(**get_req) + + # Verify the data decrypts correctly + output = response["Body"].read().decode("utf-8") + if output != data: + print("Uh oh! Input and output don't match!") + print("Input:") + print(repr(data)) + print("Output:") + print(repr(output)) + raise RuntimeError + + print("Success! Encryption context used correctly during encryption/decryption.") + print(f"Encryption context: {encryption_context}") + + +def test_encryption_context_mismatch(): + """Test that decryption fails when EncryptionContext doesn't match.""" + key = "encryption-context-mismatch" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Original encryption context + encryption_context = { + "department": "engineering", + "project": "s3-encryption" + } + + # Wrong encryption context for decryption + wrong_encryption_context = { + "department": "marketing", + "project": "s3-encryption" + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Try to get the object back with WRONG encryption context - should fail + get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": wrong_encryption_context} + + try: + response = s3ec.get_object(**get_req) + # If we get here, the test failed - decryption should have failed + print("Uh oh! Decryption succeeded with wrong encryption context!") + print(f"Original context: {encryption_context}") + print(f"Wrong context used: {wrong_encryption_context}") + raise RuntimeError("Expected decryption to fail with mismatched encryption context") + except S3EncryptionClientError as e: + # This is expected - decryption should fail + print("Success! Decryption correctly failed with mismatched encryption context.") + print(f"Error message: {str(e)}") + except Exception as e: + # Some other error occurred + print(f"Unexpected error type: {type(e).__name__}") + print(f"Error message: {str(e)}") + raise + + +def test_encryption_context_missing_on_decrypt(): + """Test that decryption succeeds using stored encryption context when not provided in request.""" + key = "encryption-context-missing" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Encryption context used during encryption + encryption_context = { + "department": "engineering", + "project": "s3-encryption" + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Get the object back WITHOUT encryption context - should succeed using stored context + get_req = {"Bucket": bucket, "Key": key} + + response = s3ec.get_object(**get_req) + + # Verify the data decrypts correctly + output = response["Body"].read().decode("utf-8") + if output != data: + print("Uh oh! Input and output don't match!") + print("Input:") + print(repr(data)) + print("Output:") + print(repr(output)) + raise RuntimeError + + print("Success! Decryption correctly used stored encryption context when not provided in request.") + print(f"Original encryption context: {encryption_context}") From f0f65a7092738fc0ece5058b7a3d80eb9f3594b6 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Wed, 11 Feb 2026 17:02:40 -0800 Subject: [PATCH 03/10] format --- src/s3_encryption/__init__.py | 16 ++++++---------- src/s3_encryption/materials/kms_keyring.py | 7 +++++-- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 52ec3e08..c96df68d 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -167,19 +167,17 @@ def put_object(self, **kwargs): """ # Extract EncryptionContext if provided (not a standard S3 parameter) encryption_context = kwargs.pop("EncryptionContext", None) - + # Store encryption context in thread-local storage for the event handler self._plugin._context.encryption_context = encryption_context - + try: return self.wrapped_s3_client.put_object(**kwargs) except S3EncryptionClientError: # Re-raise our own exceptions without wrapping raise except Exception as e: - raise S3EncryptionClientError( - f"Failed to encryption object: {str(e)}" - ) from e + raise S3EncryptionClientError(f"Failed to encryption object: {str(e)}") from e finally: # Clean up thread-local storage if hasattr(self._plugin._context, "encryption_context"): @@ -204,10 +202,10 @@ def get_object(self, **kwargs): """ # Extract EncryptionContext if provided (not a standard S3 parameter) encryption_context = kwargs.pop("EncryptionContext", None) - + # Store encryption context in thread-local storage for the event handler self._plugin._context.encryption_context = encryption_context - + try: return self.wrapped_s3_client.get_object(**kwargs) except S3EncryptionClientError: @@ -215,9 +213,7 @@ def get_object(self, **kwargs): raise except Exception as e: # Wrap any unexpected errors during decryption - raise S3EncryptionClientError( - f"Failed to decrypt object: {str(e)}" - ) from e + raise S3EncryptionClientError(f"Failed to decrypt object: {str(e)}") from e finally: # Clean up thread-local storage if hasattr(self._plugin._context, "encryption_context"): diff --git a/src/s3_encryption/materials/kms_keyring.py b/src/s3_encryption/materials/kms_keyring.py index 835816e3..12fdd68e 100644 --- a/src/s3_encryption/materials/kms_keyring.py +++ b/src/s3_encryption/materials/kms_keyring.py @@ -115,9 +115,12 @@ def on_decrypt(self, dec_materials, encrypted_data_keys=None): encryption_context_stored_copy = encryption_context_stored.copy() encryption_context_stored_copy.pop(KMS_V1_DEFAULT_KEY, None) encryption_context_stored_copy.pop(KMS_CONTEXT_DEFAULT_KEY, None) - + # Only validate if encryption context was explicitly provided in request - if encryption_context_from_request and encryption_context_stored_copy != encryption_context_from_request: + if ( + encryption_context_from_request + and encryption_context_stored_copy != encryption_context_from_request + ): # TODO: modeled error raise S3EncryptionClientError( "Provided encryption context does not match information " From 15372744aad8b9fa53d35496e65ce4501c65d1fe Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Wed, 11 Feb 2026 17:04:44 -0800 Subject: [PATCH 04/10] scope format/lint --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 814ab334..7db980c3 100644 --- a/Makefile +++ b/Makefile @@ -10,14 +10,14 @@ install: # Run linting checks lint: - uv run black --check . + uv run black --check src/ test/ # Enforce ruff checks on src/ but allow test/ to fail uv run ruff check src/ uv run ruff check test/ || true # Format code with Black and Ruff format: - uv run black . + uv run black src/ test/ uv run ruff check --fix src/ test/ # Run all tests From 4753ca4af5f5728818b99bdb7bf4d1eed06715ef Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Thu, 12 Feb 2026 15:47:55 -0800 Subject: [PATCH 05/10] fix lint, EC behavior --- pyproject.toml | 1 + src/s3_encryption/__init__.py | 6 +- src/s3_encryption/materials/kms_keyring.py | 7 +- test/integration/test_i_s3_encryption.py | 75 +++++++++---------- test/test_decryption_materials_integration.py | 37 +++++---- 5 files changed, 57 insertions(+), 69 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b05f22aa..b7489a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ exclude = [".git", "__pycache__", "build", "dist"] [tool.ruff.lint] # Enable all rules by default, then configure specific rule settings below select = ["E", "F", "W", "I", "N", "D", "UP", "B", "A", "C4", "PT", "RET", "SIM", "ARG", "ERA"] +ignore = ["ARG002"] # Allow unused method arguments (e.g., **kwargs for API compatibility) [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index c96df68d..8bb1fb62 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -65,7 +65,8 @@ def on_put_object_before_call(self, params, **kwargs): body_bytes = body elif hasattr(body, "read"): # It's a file-like object (BytesIO, etc.) - # TODO(streaming): Add support for streaming encryption without reading entire body into memory + # TODO(streaming): Add support for streaming encryption without reading entire body + # into memory body_bytes = body.read() else: # Unexpected body type - should not happen as boto3 validates before this point @@ -163,7 +164,8 @@ def put_object(self, **kwargs): The response from the S3 client's put_object method. Raises: - S3EncryptionClientError: Any problem with encryption, including if the Body parameter has an invalid type. + S3EncryptionClientError: Any problem with encryption, including if the Body parameter + has an invalid type. """ # Extract EncryptionContext if provided (not a standard S3 parameter) encryption_context = kwargs.pop("EncryptionContext", None) diff --git a/src/s3_encryption/materials/kms_keyring.py b/src/s3_encryption/materials/kms_keyring.py index 12fdd68e..73e3b013 100644 --- a/src/s3_encryption/materials/kms_keyring.py +++ b/src/s3_encryption/materials/kms_keyring.py @@ -111,16 +111,11 @@ def on_decrypt(self, dec_materials, encrypted_data_keys=None): ) # The stored EC, minus default key/values, MUST match provided EC - # If no EC is provided from request (empty dict), use stored EC encryption_context_stored_copy = encryption_context_stored.copy() encryption_context_stored_copy.pop(KMS_V1_DEFAULT_KEY, None) encryption_context_stored_copy.pop(KMS_CONTEXT_DEFAULT_KEY, None) - # Only validate if encryption context was explicitly provided in request - if ( - encryption_context_from_request - and encryption_context_stored_copy != encryption_context_from_request - ): + if encryption_context_stored_copy != encryption_context_from_request: # TODO: modeled error raise S3EncryptionClientError( "Provided encryption context does not match information " diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 282b08e3..4d77ba9c 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -216,7 +216,7 @@ def test_binary_data_roundtrip(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") # Create some binary data (not valid in any particular encoding) - data = bytes([i for i in range(256)]) + data = bytes(range(256)) kms_client = boto3.client("kms", region_name=region) @@ -295,12 +295,12 @@ def test_user_metadata_preservation(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") data = "Test data with user metadata" - + # User metadata to include user_metadata = { "author": "test-user", "version": "1.0", - "description": "Test object with custom metadata" + "description": "Test object with custom metadata", } kms_client = boto3.client("kms", region_name=region) @@ -308,14 +308,14 @@ def test_user_metadata_preservation(): wrapped_client = boto3.client("s3") config = S3EncryptionClientConfig(keyring) s3ec = S3EncryptionClient(wrapped_client, config) - + # Put object with user metadata s3ec.put_object(Bucket=bucket, Key=key, Body=data, Metadata=user_metadata) - + # Get the object back get_req = {"Bucket": bucket, "Key": key} response = s3ec.get_object(**get_req) - + # Verify the data decrypts correctly output = response["Body"].read().decode("utf-8") if output != data: @@ -325,10 +325,10 @@ def test_user_metadata_preservation(): print("Output:") print(repr(output)) raise RuntimeError - + # Verify user metadata is preserved returned_metadata = response.get("Metadata", {}) - + for key_name, expected_value in user_metadata.items(): if key_name not in returned_metadata: print(f"Uh oh! User metadata key '{key_name}' is missing!") @@ -337,13 +337,13 @@ def test_user_metadata_preservation(): print("Returned metadata:") print(returned_metadata) raise RuntimeError - + if returned_metadata[key_name] != expected_value: print(f"Uh oh! User metadata value for '{key_name}' doesn't match!") print(f"Expected: {expected_value}") print(f"Got: {returned_metadata[key_name]}") raise RuntimeError - + print("Success! User metadata preserved correctly during encryption/decryption.") print(f"User metadata: {user_metadata}") print(f"Returned metadata keys: {list(returned_metadata.keys())}") @@ -355,12 +355,12 @@ def test_encryption_context_roundtrip(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") data = "Test data with encryption context" - + # Encryption context to use for additional authenticated data encryption_context = { "department": "engineering", "project": "s3-encryption", - "environment": "test" + "environment": "test", } kms_client = boto3.client("kms", region_name=region) @@ -368,14 +368,14 @@ def test_encryption_context_roundtrip(): wrapped_client = boto3.client("s3") config = S3EncryptionClientConfig(keyring) s3ec = S3EncryptionClient(wrapped_client, config) - + # Put object with encryption context s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) - + # Get the object back WITH the same encryption context get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": encryption_context} response = s3ec.get_object(**get_req) - + # Verify the data decrypts correctly output = response["Body"].read().decode("utf-8") if output != data: @@ -385,7 +385,7 @@ def test_encryption_context_roundtrip(): print("Output:") print(repr(output)) raise RuntimeError - + print("Success! Encryption context used correctly during encryption/decryption.") print(f"Encryption context: {encryption_context}") @@ -396,33 +396,27 @@ def test_encryption_context_mismatch(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") data = "Test data with encryption context" - + # Original encryption context - encryption_context = { - "department": "engineering", - "project": "s3-encryption" - } - + encryption_context = {"department": "engineering", "project": "s3-encryption"} + # Wrong encryption context for decryption - wrong_encryption_context = { - "department": "marketing", - "project": "s3-encryption" - } + wrong_encryption_context = {"department": "marketing", "project": "s3-encryption"} kms_client = boto3.client("kms", region_name=region) keyring = KmsKeyring(kms_client, kms_key_id) wrapped_client = boto3.client("s3") config = S3EncryptionClientConfig(keyring) s3ec = S3EncryptionClient(wrapped_client, config) - + # Put object with encryption context s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) - + # Try to get the object back with WRONG encryption context - should fail get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": wrong_encryption_context} - + try: - response = s3ec.get_object(**get_req) + s3ec.get_object(**get_req) # If we get here, the test failed - decryption should have failed print("Uh oh! Decryption succeeded with wrong encryption context!") print(f"Original context: {encryption_context}") @@ -445,27 +439,24 @@ def test_encryption_context_missing_on_decrypt(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") data = "Test data with encryption context" - + # Encryption context used during encryption - encryption_context = { - "department": "engineering", - "project": "s3-encryption" - } + encryption_context = {"department": "engineering", "project": "s3-encryption"} kms_client = boto3.client("kms", region_name=region) keyring = KmsKeyring(kms_client, kms_key_id) wrapped_client = boto3.client("s3") config = S3EncryptionClientConfig(keyring) s3ec = S3EncryptionClient(wrapped_client, config) - + # Put object with encryption context s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) - + # Get the object back WITHOUT encryption context - should succeed using stored context get_req = {"Bucket": bucket, "Key": key} - + response = s3ec.get_object(**get_req) - + # Verify the data decrypts correctly output = response["Body"].read().decode("utf-8") if output != data: @@ -475,6 +466,8 @@ def test_encryption_context_missing_on_decrypt(): print("Output:") print(repr(output)) raise RuntimeError - - print("Success! Decryption correctly used stored encryption context when not provided in request.") + + print( + "Success! Decryption correctly used stored encryption context when not provided in request." + ) print(f"Original encryption context: {encryption_context}") diff --git a/test/test_decryption_materials_integration.py b/test/test_decryption_materials_integration.py index 1cfab083..06729897 100644 --- a/test/test_decryption_materials_integration.py +++ b/test/test_decryption_materials_integration.py @@ -30,17 +30,15 @@ def test_keyring_on_decrypt(self): encryption_context_from_request={"key2": "value2"}, ) - # Mock the validation method to return the materials - with patch.object(S3Keyring, "on_decrypt", return_value=materials) as mock_on_decrypt: - # Call on_decrypt - result = keyring.on_decrypt(materials, [edk]) + # Call on_decrypt + result = keyring.on_decrypt(materials, [edk]) - # Verify the result is a DecryptionMaterials instance - assert isinstance(result, DecryptionMaterials) - assert result.iv == b"initialization-vector" - assert result.encrypted_data_keys == [edk] - assert result.encryption_context_stored == {"key1": "value1"} - assert result.encryption_context_from_request == {"key2": "value2"} + # Verify the result is a DecryptionMaterials instance + assert isinstance(result, DecryptionMaterials) + assert result.iv == b"initialization-vector" + assert result.encrypted_data_keys == [edk] + assert result.encryption_context_stored == {"key1": "value1"} + assert result.encryption_context_from_request == {"key2": "value2"} def test_keyring_on_decrypt_default_enc_ctx(self): """Test that S3Keyring.on_decrypt properly handles DecryptionMaterials.""" @@ -63,16 +61,15 @@ def test_keyring_on_decrypt_default_enc_ctx(self): ) # Mock the validation method to return the materials - with patch.object(S3Keyring, "on_decrypt", return_value=materials) as mock_on_decrypt: - # Call on_decrypt - result = keyring.on_decrypt(materials, [edk]) - - # Verify the result is a DecryptionMaterials instance - assert isinstance(result, DecryptionMaterials) - assert result.iv == b"initialization-vector" - assert result.encrypted_data_keys == [edk] - assert result.encryption_context_stored == {} - assert result.encryption_context_from_request == {} + # Call on_decrypt + result = keyring.on_decrypt(materials, [edk]) + + # Verify the result is a DecryptionMaterials instance + assert isinstance(result, DecryptionMaterials) + assert result.iv == b"initialization-vector" + assert result.encrypted_data_keys == [edk] + assert result.encryption_context_stored == {} + assert result.encryption_context_from_request == {} def test_cmm_decrypt_materials_with_dict(self): """Test that DefaultCryptoMaterialsManager.decrypt_materials properly handles dictionary input.""" From d2d43cbb88227e238c4d6280879ac6a10dc09265 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Thu, 12 Feb 2026 15:53:33 -0800 Subject: [PATCH 06/10] fix integ test --- test/integration/test_i_s3_encryption.py | 35 ++++++++++++------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 4d77ba9c..616f8da4 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -434,7 +434,7 @@ def test_encryption_context_mismatch(): def test_encryption_context_missing_on_decrypt(): - """Test that decryption succeeds using stored encryption context when not provided in request.""" + """Test that decryption fails when encryption context is not provided for an object encrypted with context.""" key = "encryption-context-missing" key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") @@ -452,22 +452,21 @@ def test_encryption_context_missing_on_decrypt(): # Put object with encryption context s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) - # Get the object back WITHOUT encryption context - should succeed using stored context + # Try to get the object back WITHOUT encryption context - should fail get_req = {"Bucket": bucket, "Key": key} - response = s3ec.get_object(**get_req) - - # Verify the data decrypts correctly - output = response["Body"].read().decode("utf-8") - if output != data: - print("Uh oh! Input and output don't match!") - print("Input:") - print(repr(data)) - print("Output:") - print(repr(output)) - raise RuntimeError - - print( - "Success! Decryption correctly used stored encryption context when not provided in request." - ) - print(f"Original encryption context: {encryption_context}") + try: + s3ec.get_object(**get_req) + # If we get here, the test failed - decryption should have failed + print("Uh oh! Decryption succeeded without providing required encryption context!") + print(f"Original context: {encryption_context}") + raise RuntimeError("Expected decryption to fail when encryption context not provided") + except S3EncryptionClientError as e: + # This is expected - decryption should fail + print("Success! Decryption correctly failed when encryption context was not provided.") + print(f"Error message: {str(e)}") + except Exception as e: + # Some other error occurred + print(f"Unexpected error type: {type(e).__name__}") + print(f"Error message: {str(e)}") + raise From c82d3eba637d911191acb6ab12d185b303633409 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Thu, 12 Feb 2026 15:59:49 -0800 Subject: [PATCH 07/10] better code --- src/s3_encryption/materials/kms_keyring.py | 6 +----- src/s3_encryption/pipelines.py | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/s3_encryption/materials/kms_keyring.py b/src/s3_encryption/materials/kms_keyring.py index 73e3b013..f8bc4997 100644 --- a/src/s3_encryption/materials/kms_keyring.py +++ b/src/s3_encryption/materials/kms_keyring.py @@ -46,9 +46,7 @@ def on_encrypt(self, enc_materials): # Call parent class validation enc_materials = super().on_encrypt(enc_materials) - # Copy encryption context to avoid modifying the original - encryption_context = enc_materials.encryption_context.copy() - # Add default encryption context + encryption_context = enc_materials.encryption_context encryption_context["aws:x-amz-cek-alg"] = "AES/GCM/NoPadding" response = self.kms_client.generate_data_key( @@ -62,8 +60,6 @@ def on_encrypt(self, enc_materials): ) enc_materials.encrypted_data_key = encrypted_data_key enc_materials.plaintext_data_key = response["Plaintext"] - # Update enc_materials with the modified encryption context (with default added) - enc_materials.encryption_context = encryption_context return enc_materials except Exception: raise diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 4b2c0e48..6867ed7c 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -39,9 +39,9 @@ def encrypt(self, plaintext, encryption_context=None): bytes: The encrypted data dict: Metadata about the encryption to be stored with the object """ - # Create encryption materials request with encryption context + # Create encryption materials request with encryption context copy enc_mats_request = EncryptionMaterials( - encryption_context={} if encryption_context is None else encryption_context + encryption_context={} if encryption_context is None else encryption_context.copy() ) # Get encryption materials from the crypto materials manager From 8d814878a4edb4faa3cd8b5ddb572a6dd0dd026a Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Fri, 13 Feb 2026 09:32:29 -0800 Subject: [PATCH 08/10] add multithreaded EncCtx integ tests --- .../test_i_s3_encryption_multithreaded.py | 305 ++++++++++++++++++ 1 file changed, 305 insertions(+) create mode 100644 test/integration/test_i_s3_encryption_multithreaded.py diff --git a/test/integration/test_i_s3_encryption_multithreaded.py b/test/integration/test_i_s3_encryption_multithreaded.py new file mode 100644 index 00000000..1365e7e7 --- /dev/null +++ b/test/integration/test_i_s3_encryption_multithreaded.py @@ -0,0 +1,305 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-threaded integration tests for S3 Encryption Client. + +These tests verify that the thread-local storage of encryption context +is properly isolated between threads when using a single S3EncryptionClient +instance across multiple threads. +""" +import os +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime + +import boto3 + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.materials.kms_keyring import KmsKeyring + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + + +def test_multithreaded_encryption_context_isolation(): + """Test that encryption context is properly isolated between threads. + + This test creates a single S3EncryptionClient instance and uses it + from multiple threads simultaneously, each with a different encryption + context. It verifies that: + 1. Each thread can encrypt with its own encryption context + 2. Each thread can decrypt only with the correct encryption context + 3. Thread-local storage doesn't leak between threads + """ + # Create a single S3EncryptionClient instance to be shared across threads + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Number of threads to test with + num_threads = 10 + results = {} + errors = [] + + def thread_worker(thread_id): + """Worker function for each thread.""" + try: + # Each thread has its own unique encryption context + encryption_context = { + "thread_id": str(thread_id), + "department": f"dept-{thread_id}", + "project": f"project-{thread_id}", + } + + # Unique key for this thread + key = f"multithread-test-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} test data with unique encryption context" + + # Encrypt with thread-specific encryption context + s3ec.put_object( + Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context + ) + + # Decrypt with the SAME encryption context - should succeed + response = s3ec.get_object( + Bucket=bucket, Key=key, EncryptionContext=encryption_context + ) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return { + "thread_id": thread_id, + "success": False, + "error": f"Data mismatch: expected '{data}', got '{decrypted_data}'", + } + + # Try to decrypt with a DIFFERENT encryption context - should fail + wrong_context = { + "thread_id": str(thread_id + 1000), + "department": f"wrong-dept", + "project": f"wrong-project", + } + + try: + s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=wrong_context) + return { + "thread_id": thread_id, + "success": False, + "error": "Decryption succeeded with wrong encryption context!", + } + except S3EncryptionClientError: + # Expected - decryption should fail with wrong context + pass + + # Try to decrypt with NO encryption context - should also fail + try: + s3ec.get_object(Bucket=bucket, Key=key) + return { + "thread_id": thread_id, + "success": False, + "error": "Decryption succeeded without encryption context!", + } + except S3EncryptionClientError: + # Expected - decryption should fail without context + pass + + return { + "thread_id": thread_id, + "success": True, + "key": key, + "encryption_context": encryption_context, + } + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Execute threads concurrently + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(thread_worker, i) for i in range(num_threads)] + + for future in as_completed(futures): + result = future.result() + thread_id = result["thread_id"] + results[thread_id] = result + + if not result["success"]: + errors.append(f"Thread {thread_id}: {result['error']}") + + # Verify all threads succeeded + if errors: + print("Errors occurred during multi-threaded test:") + for error in errors: + print(f" - {error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print(f"Success! All {num_threads} threads completed successfully.") + print("Each thread:") + print(" - Encrypted with its own unique encryption context") + print(" - Decrypted successfully with the correct context") + print(" - Failed to decrypt with wrong context (as expected)") + print(" - Failed to decrypt without context (as expected)") + + +def test_multithreaded_rapid_context_switching(): + """Test rapid switching of encryption contexts in the same thread. + + This test verifies that encryption context is properly cleaned up + between operations within the same thread. + """ + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + num_iterations = 20 + errors = [] + + def rapid_context_worker(thread_id): + """Worker that rapidly switches between different encryption contexts.""" + try: + for i in range(num_iterations): + # Alternate between different encryption contexts + if i % 3 == 0: + encryption_context = {"iteration": str(i), "type": "typeA"} + elif i % 3 == 1: + encryption_context = {"iteration": str(i), "type": "typeB"} + else: + encryption_context = {"iteration": str(i), "type": "typeC"} + + key = f"rapid-switch-t{thread_id}-i{i}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} iteration {i}" + + # Encrypt + s3ec.put_object( + Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context + ) + + # Decrypt with correct context + response = s3ec.get_object( + Bucket=bucket, Key=key, EncryptionContext=encryption_context + ) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return { + "thread_id": thread_id, + "iteration": i, + "success": False, + "error": f"Data mismatch at iteration {i}", + } + + return {"thread_id": thread_id, "success": True} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Run multiple threads doing rapid context switching + num_threads = 5 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_context_worker, i) for i in range(num_threads)] + + for future in as_completed(futures): + result = future.result() + if not result["success"]: + errors.append(f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}") + + if errors: + print("Errors occurred during rapid context switching test:") + for error in errors: + print(f" - {error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print(f"Success! {num_threads} threads completed {num_iterations} iterations each.") + print("Encryption context was properly isolated across rapid context switches.") + + +def test_multithreaded_mixed_with_and_without_context(): + """Test threads using encryption context mixed with threads not using it. + + This verifies that threads without encryption context don't interfere + with threads that use encryption context. + """ + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + errors = [] + + def worker_with_context(thread_id): + """Worker that uses encryption context.""" + try: + encryption_context = {"thread_id": str(thread_id), "has_context": "true"} + key = f"mixed-with-ctx-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} WITH context" + + s3ec.put_object( + Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context + ) + + response = s3ec.get_object( + Bucket=bucket, Key=key, EncryptionContext=encryption_context + ) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return {"thread_id": thread_id, "success": False, "error": "Data mismatch"} + + return {"thread_id": thread_id, "success": True, "type": "with_context"} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + def worker_without_context(thread_id): + """Worker that does NOT use encryption context.""" + try: + key = f"mixed-no-ctx-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} WITHOUT context" + + # No encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + + # No encryption context on decrypt either + response = s3ec.get_object(Bucket=bucket, Key=key) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return {"thread_id": thread_id, "success": False, "error": "Data mismatch"} + + return {"thread_id": thread_id, "success": True, "type": "without_context"} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Mix threads with and without encryption context + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + + # Submit 5 threads with context + for i in range(5): + futures.append(executor.submit(worker_with_context, i)) + + # Submit 5 threads without context + for i in range(5, 10): + futures.append(executor.submit(worker_without_context, i)) + + for future in as_completed(futures): + result = future.result() + if not result["success"]: + errors.append(f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}") + + if errors: + print("Errors occurred during mixed context test:") + for error in errors: + print(f" - {error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print("Success! Mixed threads (with and without encryption context) completed successfully.") + print("Thread-local storage properly isolated context between threads.") From 3546929fe8e2907b67ef1d4129b04073027b2310 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Fri, 13 Feb 2026 09:57:54 -0800 Subject: [PATCH 09/10] format --- .../test_i_s3_encryption_multithreaded.py | 34 +++++++++---------- test/test_decryption_materials_integration.py | 2 +- 2 files changed, 17 insertions(+), 19 deletions(-) diff --git a/test/integration/test_i_s3_encryption_multithreaded.py b/test/integration/test_i_s3_encryption_multithreaded.py index 1365e7e7..419ca7ea 100644 --- a/test/integration/test_i_s3_encryption_multithreaded.py +++ b/test/integration/test_i_s3_encryption_multithreaded.py @@ -6,8 +6,8 @@ is properly isolated between threads when using a single S3EncryptionClient instance across multiple threads. """ + import os -import threading from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime @@ -61,14 +61,10 @@ def thread_worker(thread_id): data = f"Thread {thread_id} test data with unique encryption context" # Encrypt with thread-specific encryption context - s3ec.put_object( - Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context - ) + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) # Decrypt with the SAME encryption context - should succeed - response = s3ec.get_object( - Bucket=bucket, Key=key, EncryptionContext=encryption_context - ) + response = s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=encryption_context) decrypted_data = response["Body"].read().decode("utf-8") if decrypted_data != data: @@ -81,8 +77,8 @@ def thread_worker(thread_id): # Try to decrypt with a DIFFERENT encryption context - should fail wrong_context = { "thread_id": str(thread_id + 1000), - "department": f"wrong-dept", - "project": f"wrong-project", + "department": "wrong-dept", + "project": "wrong-project", } try: @@ -172,7 +168,9 @@ def rapid_context_worker(thread_id): else: encryption_context = {"iteration": str(i), "type": "typeC"} - key = f"rapid-switch-t{thread_id}-i{i}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + key = ( + f"rapid-switch-t{thread_id}-i{i}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + ) data = f"Thread {thread_id} iteration {i}" # Encrypt @@ -207,7 +205,9 @@ def rapid_context_worker(thread_id): for future in as_completed(futures): result = future.result() if not result["success"]: - errors.append(f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}") + errors.append( + f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}" + ) if errors: print("Errors occurred during rapid context switching test:") @@ -240,13 +240,9 @@ def worker_with_context(thread_id): key = f"mixed-with-ctx-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" data = f"Thread {thread_id} WITH context" - s3ec.put_object( - Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context - ) + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) - response = s3ec.get_object( - Bucket=bucket, Key=key, EncryptionContext=encryption_context - ) + response = s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=encryption_context) decrypted_data = response["Body"].read().decode("utf-8") if decrypted_data != data: @@ -293,7 +289,9 @@ def worker_without_context(thread_id): for future in as_completed(futures): result = future.result() if not result["success"]: - errors.append(f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}") + errors.append( + f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}" + ) if errors: print("Errors occurred during mixed context test:") diff --git a/test/test_decryption_materials_integration.py b/test/test_decryption_materials_integration.py index 06729897..35b7d9e8 100644 --- a/test/test_decryption_materials_integration.py +++ b/test/test_decryption_materials_integration.py @@ -1,7 +1,7 @@ # Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from src.s3_encryption.materials.crypto_materials_manager import DefaultCryptoMaterialsManager from src.s3_encryption.materials.encrypted_data_key import EncryptedDataKey From 21a68fcfbee7d0fd7955a5d81405ba2066a0c4c9 Mon Sep 17 00:00:00 2001 From: Kess Plasmeier Date: Mon, 16 Feb 2026 16:26:59 -0800 Subject: [PATCH 10/10] PR feedback --- src/s3_encryption/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 8bb1fb62..064096bb 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -179,7 +179,7 @@ def put_object(self, **kwargs): # Re-raise our own exceptions without wrapping raise except Exception as e: - raise S3EncryptionClientError(f"Failed to encryption object: {str(e)}") from e + raise S3EncryptionClientError(f"Failed to encrypt object: {str(e)}") from e finally: # Clean up thread-local storage if hasattr(self._plugin._context, "encryption_context"):