diff --git a/Makefile b/Makefile index 814ab334..7db980c3 100644 --- a/Makefile +++ b/Makefile @@ -10,14 +10,14 @@ install: # Run linting checks lint: - uv run black --check . + uv run black --check src/ test/ # Enforce ruff checks on src/ but allow test/ to fail uv run ruff check src/ uv run ruff check test/ || true # Format code with Black and Ruff format: - uv run black . + uv run black src/ test/ uv run ruff check --fix src/ test/ # Run all tests diff --git a/pyproject.toml b/pyproject.toml index b05f22aa..b7489a03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ exclude = [".git", "__pycache__", "build", "dist"] [tool.ruff.lint] # Enable all rules by default, then configure specific rule settings below select = ["E", "F", "W", "I", "N", "D", "UP", "B", "A", "C4", "PT", "RET", "SIM", "ARG", "ERA"] +ignore = ["ARG002"] # Allow unused method arguments (e.g., **kwargs for API compatibility) [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/src/s3_encryption/__init__.py b/src/s3_encryption/__init__.py index 46cdbdd1..064096bb 100644 --- a/src/s3_encryption/__init__.py +++ b/src/s3_encryption/__init__.py @@ -3,9 +3,9 @@ """Top-level S3 Encryption Client v3 for Python package.""" import io +import threading from attrs import define, field -from botocore import serialize from botocore.response import StreamingBody from .exceptions import S3EncryptionClientError @@ -16,7 +16,7 @@ from .materials.keyring import AbstractKeyring from .pipelines import GetEncryptedObjectPipeline, PutEncryptedObjectPipeline -DEFAULT_ENCODING = "utf-8" +S3_METADATA_PREFIX = "x-amz-meta-" @define @@ -31,31 +31,122 @@ def _default_cmm_for_keyring(self): return DefaultCryptoMaterialsManager(self.keyring) +class S3EncryptionClientPlugin: + """Plugin that adds encryption/decryption capabilities to a boto3 S3 client. + + This plugin uses boto3's event system to intercept put_object and get_object + calls to provide transparent encryption and decryption of S3 objects. 
+ """ + + def __init__(self, config: S3EncryptionClientConfig): + """Initialize the plugin with encryption configuration. + + Args: + config: S3EncryptionClientConfig containing keyring and CMM + """ + self.config = config + self._context = threading.local() + + def on_put_object_before_call(self, params, **kwargs): + """Event handler for before-call.s3.PutObject. + + This handler encrypts the body after serialization but before the request is sent. + + Args: + params: Dictionary of parameters for the PutObject call (after serialization) + **kwargs: Additional event arguments + """ + # At this point, boto3 has already serialized the Body + # Extract the serialized body from the request + body = params.get("body") + if body is None: + body_bytes = b"" + elif isinstance(body, bytes): + body_bytes = body + elif hasattr(body, "read"): + # It's a file-like object (BytesIO, etc.) + # TODO(streaming): Add support for streaming encryption without reading entire body + # into memory + body_bytes = body.read() + else: + # Unexpected body type - should not happen as boto3 validates before this point + raise S3EncryptionClientError("Unexpected type of body parameter!") + + encryption_context = getattr(self._context, "encryption_context", None) + + pipeline = PutEncryptedObjectPipeline(self.config.cmm) + encrypted_data, encryption_metadata = pipeline.encrypt( + body_bytes, encryption_context=encryption_context + ) + + params["body"] = encrypted_data + + headers = params.get("headers", {}) + + # Add encryption metadata to headers + if encryption_metadata: + for key, value in encryption_metadata.items(): + # Add as S3 metadata headers + header_key = f"{S3_METADATA_PREFIX}{key}" + headers[header_key] = value + + params["headers"] = headers + + def on_get_object_after_call(self, parsed, **kwargs): + """Event handler for after-call.s3.GetObject. + + This handler decrypts the body after the response is received from S3. 
+ + Args: + parsed: Dictionary containing the parsed response + **kwargs: Additional event arguments (botocore passes http_response, model, context) + """ + # Get encryption context from thread-local storage (set by get_object wrapper) + encryption_context = getattr(self._context, "encryption_context", None) + + # The parsed response already has the Body as a StreamingBody + # We need to read it, decrypt it, and replace it + + # Create a response dict that matches what the pipeline expects + response = { + "Body": parsed.get("Body"), + "Metadata": parsed.get("Metadata", {}), + } + + # Create a pipeline and decrypt the data + pipeline = GetEncryptedObjectPipeline(self.config.cmm) + decrypted_data = pipeline.decrypt(response, encryption_context) + + # Replace body with decrypted data + stream = io.BytesIO(decrypted_data) + streaming_body = StreamingBody(stream, len(decrypted_data)) + parsed["Body"] = streaming_body + + @define class S3EncryptionClient: """Client for encrypting and decrypting S3 objects. This client wraps a boto3 S3 client and provides encryption and decryption capabilities for S3 objects using the configured keyring and crypto materials manager. + + The encryption/decryption is implemented using boto3's event system, registering + handlers for before-call and after-call events. """ wrapped_s3_client = field() config: S3EncryptionClientConfig = field() + _plugin: S3EncryptionClientPlugin = field(init=False) def __attrs_post_init__(self): - """Validate serialization encoding after initialization. + """Install the encryption plugin on the wrapped client using boto3 events.""" + # Create the plugin + object.__setattr__(self, "_plugin", S3EncryptionClientPlugin(self.config)) - Ensures boto3 serializers are using the expected default encoding. 
- """ - # Sanity check that boto3 serialization are ONLY using the default encoding (utf-8) - # This should always be the case, but changes in encoding would break the assumption that - # the decrypted plaintext adheres to the non-utf8 encoding scheme. So we avoid that. - for sz_name, sz in serialize.SERIALIZERS.items(): - if sz.DEFAULT_ENCODING != DEFAULT_ENCODING: - raise S3EncryptionClientError( - f"All Serializers MUST only support utf-8 encoding, but {sz_name} is using " - f"{sz.DEFAULT_ENCODING}!" - ) + # Register event handlers using boto3's event system + event_system = self.wrapped_s3_client.meta.events + event_system.register("before-call.s3.PutObject", self._plugin.on_put_object_before_call) + event_system.register("after-call.s3.GetObject", self._plugin.on_get_object_after_call) def put_object(self, **kwargs): """Encrypt and upload an object to S3. @@ -71,52 +162,28 @@ def put_object(self, **kwargs): Returns: The response from the S3 client's put_object method. + + Raises: + S3EncryptionClientError: Any problem with encryption, including if the Body parameter + has an invalid type. """ - # Extract required parameters from kwargs - bucket = kwargs.pop("Bucket") - key = kwargs.pop("Key") - body = kwargs.pop("Body", b"") # Default to empty bytes when Body is not provided + # Extract EncryptionContext if provided (not a standard S3 parameter) encryption_context = kwargs.pop("EncryptionContext", None) - # Create a pipeline for this operation - pipeline = PutEncryptedObjectPipeline(self.config.cmm) - - # The documentation for boto3 asks for bytes or a file-like object, - # but in reality, it is possible to pass strings. - # Strings will be encoded using DEFAULT_ENCODING, - # which MUST match the default encoding defined int the Serializer class in botocore. 
- if isinstance(body, str): - data_bytes = body.encode(DEFAULT_ENCODING) - elif isinstance(body, bytes): - data_bytes = body - elif isinstance(body, io.IOBase): - # TODO: Streaming support - raise S3EncryptionClientError( - f"Body parameter of type {type(body)} is not an acceptable type! " - f"Streaming operations are not yet supported." - ) - else: - raise S3EncryptionClientError( - f"Body parameter of type {type(body)} is not an acceptable type! " - f"Use bytes or a file-like object." - ) - - # Now encrypt the bytes/file-like IOBase object - encrypted_data, encryption_metadata = pipeline.encrypt( - data_bytes, encryption_context=encryption_context - ) - - # Add encryption metadata to the request parameters - params = {"Bucket": bucket, "Key": key, "Body": encrypted_data, **kwargs} - - # Add encryption metadata to the parameters - if encryption_metadata: - # Merge any existing metadata with our encryption metadata - metadata = params.get("Metadata", {}) - metadata.update(encryption_metadata) - params["Metadata"] = metadata - - return self.wrapped_s3_client.put_object(**params) + # Store encryption context in thread-local storage for the event handler + self._plugin._context.encryption_context = encryption_context + + try: + return self.wrapped_s3_client.put_object(**kwargs) + except S3EncryptionClientError: + # Re-raise our own exceptions without wrapping + raise + except Exception as e: + raise S3EncryptionClientError(f"Failed to encrypt object: {str(e)}") from e + finally: + # Clean up thread-local storage + if hasattr(self._plugin._context, "encryption_context"): + delattr(self._plugin._context, "encryption_context") def get_object(self, **kwargs): """Download and decrypt an object from S3. @@ -131,29 +198,25 @@ def get_object(self, **kwargs): Returns: The response from the S3 client's get_object method with the Body replaced with a StreamingBody containing the decrypted data. 
+ + Raises: + S3EncryptionClientError: If decryption fails or the object is not properly encrypted. """ - # Extract encryption context if provided + # Extract EncryptionContext if provided (not a standard S3 parameter) encryption_context = kwargs.pop("EncryptionContext", None) - # Create params for the S3 client - params = {**kwargs} - - # Get the encrypted object from S3 - response = self.wrapped_s3_client.get_object(**params) - - # Create a pipeline for this operation - pipeline = GetEncryptedObjectPipeline(self.config.cmm) - - # Decrypt the data using the pipeline - decrypted_data = pipeline.decrypt( - response, encryption_context - ) # encrypted_data, encryption_metadata) - - # Create a new streaming body with the decrypted data - stream = io.BytesIO(decrypted_data) - streaming_body = StreamingBody(stream, len(decrypted_data)) - - # Update the response with the decrypted data - response["Body"] = streaming_body - - return response + # Store encryption context in thread-local storage for the event handler + self._plugin._context.encryption_context = encryption_context + + try: + return self.wrapped_s3_client.get_object(**kwargs) + except S3EncryptionClientError: + # Re-raise our own exceptions without wrapping + raise + except Exception as e: + # Wrap any unexpected errors during decryption + raise S3EncryptionClientError(f"Failed to decrypt object: {str(e)}") from e + finally: + # Clean up thread-local storage + if hasattr(self._plugin._context, "encryption_context"): + delattr(self._plugin._context, "encryption_context") diff --git a/src/s3_encryption/materials/kms_keyring.py b/src/s3_encryption/materials/kms_keyring.py index 7bc8f7bd..f8bc4997 100644 --- a/src/s3_encryption/materials/kms_keyring.py +++ b/src/s3_encryption/materials/kms_keyring.py @@ -46,7 +46,6 @@ def on_encrypt(self, enc_materials): # Call parent class validation enc_materials = super().on_encrypt(enc_materials) - # Add default encryption context encryption_context = 
enc_materials.encryption_context encryption_context["aws:x-amz-cek-alg"] = "AES/GCM/NoPadding" @@ -111,6 +110,7 @@ def on_decrypt(self, dec_materials, encrypted_data_keys=None): encryption_context_stored_copy = encryption_context_stored.copy() encryption_context_stored_copy.pop(KMS_V1_DEFAULT_KEY, None) encryption_context_stored_copy.pop(KMS_CONTEXT_DEFAULT_KEY, None) + if encryption_context_stored_copy != encryption_context_from_request: # TODO: modeled error raise S3EncryptionClientError( diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 37093803..6867ed7c 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -39,9 +39,9 @@ def encrypt(self, plaintext, encryption_context=None): bytes: The encrypted data dict: Metadata about the encryption to be stored with the object """ - # Create encryption materials request with encryption context + # Create encryption materials request with encryption context copy enc_mats_request = EncryptionMaterials( - encryption_context={} if encryption_context is None else encryption_context + encryption_context={} if encryption_context is None else encryption_context.copy() ) # Get encryption materials from the crypto materials manager @@ -102,6 +102,7 @@ def decrypt(self, response, encryption_context=None): bytes: The decrypted data """ # Convert the metadata dictionary to an ObjectMetadata instance + # TODO: Stream + Buffered Decryption encrypted_data = response.get("Body").read() encryption_metadata = response.get("Metadata", {}) metadata = ObjectMetadata.from_dict(encryption_metadata) diff --git a/test/integration/test_i_s3_encryption.py b/test/integration/test_i_s3_encryption.py index 2c8ea73a..616f8da4 100644 --- a/test/integration/test_i_s3_encryption.py +++ b/test/integration/test_i_s3_encryption.py @@ -216,7 +216,7 @@ def test_binary_data_roundtrip(): key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") # Create some binary data (not valid in any particular encoding) 
- data = bytes([i for i in range(256)]) + data = bytes(range(256)) kms_client = boto3.client("kms", region_name=region) @@ -259,31 +259,214 @@ def test_invalid_body_types(): # Test with integer with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=42) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with float with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=3.14) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with list with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=[1, 2, 3]) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with dictionary with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body={"key": "value"}) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with boolean with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=True) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) # Test with None (also raises an exception) with pytest.raises(S3EncryptionClientError) as excinfo: s3ec.put_object(Bucket=bucket, Key=key, Body=None) - assert "not an acceptable type" in str(excinfo.value) + assert "Invalid type for parameter Body" in str(excinfo.value) print("Success! 
All invalid body types correctly raised exceptions.") + + +def test_user_metadata_preservation(): + """Test that user-provided metadata is preserved during encryption.""" + key = "metadata-preservation-rt" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with user metadata" + + # User metadata to include + user_metadata = { + "author": "test-user", + "version": "1.0", + "description": "Test object with custom metadata", + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with user metadata + s3ec.put_object(Bucket=bucket, Key=key, Body=data, Metadata=user_metadata) + + # Get the object back + get_req = {"Bucket": bucket, "Key": key} + response = s3ec.get_object(**get_req) + + # Verify the data decrypts correctly + output = response["Body"].read().decode("utf-8") + if output != data: + print("Uh oh! Input and output don't match!") + print("Input:") + print(repr(data)) + print("Output:") + print(repr(output)) + raise RuntimeError + + # Verify user metadata is preserved + returned_metadata = response.get("Metadata", {}) + + for key_name, expected_value in user_metadata.items(): + if key_name not in returned_metadata: + print(f"Uh oh! User metadata key '{key_name}' is missing!") + print("Expected metadata:") + print(user_metadata) + print("Returned metadata:") + print(returned_metadata) + raise RuntimeError + + if returned_metadata[key_name] != expected_value: + print(f"Uh oh! User metadata value for '{key_name}' doesn't match!") + print(f"Expected: {expected_value}") + print(f"Got: {returned_metadata[key_name]}") + raise RuntimeError + + print("Success! 
User metadata preserved correctly during encryption/decryption.") + print(f"User metadata: {user_metadata}") + print(f"Returned metadata keys: {list(returned_metadata.keys())}") + + +def test_encryption_context_roundtrip(): + """Test that EncryptionContext is properly used during encryption and required for decryption.""" + key = "encryption-context-rt" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Encryption context to use for additional authenticated data + encryption_context = { + "department": "engineering", + "project": "s3-encryption", + "environment": "test", + } + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Get the object back WITH the same encryption context + get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": encryption_context} + response = s3ec.get_object(**get_req) + + # Verify the data decrypts correctly + output = response["Body"].read().decode("utf-8") + if output != data: + print("Uh oh! Input and output don't match!") + print("Input:") + print(repr(data)) + print("Output:") + print(repr(output)) + raise RuntimeError + + print("Success! 
Encryption context used correctly during encryption/decryption.") + print(f"Encryption context: {encryption_context}") + + +def test_encryption_context_mismatch(): + """Test that decryption fails when EncryptionContext doesn't match.""" + key = "encryption-context-mismatch" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Original encryption context + encryption_context = {"department": "engineering", "project": "s3-encryption"} + + # Wrong encryption context for decryption + wrong_encryption_context = {"department": "marketing", "project": "s3-encryption"} + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Try to get the object back with WRONG encryption context - should fail + get_req = {"Bucket": bucket, "Key": key, "EncryptionContext": wrong_encryption_context} + + try: + s3ec.get_object(**get_req) + # If we get here, the test failed - decryption should have failed + print("Uh oh! Decryption succeeded with wrong encryption context!") + print(f"Original context: {encryption_context}") + print(f"Wrong context used: {wrong_encryption_context}") + raise RuntimeError("Expected decryption to fail with mismatched encryption context") + except S3EncryptionClientError as e: + # This is expected - decryption should fail + print("Success! 
Decryption correctly failed with mismatched encryption context.") + print(f"Error message: {str(e)}") + except Exception as e: + # Some other error occurred + print(f"Unexpected error type: {type(e).__name__}") + print(f"Error message: {str(e)}") + raise + + +def test_encryption_context_missing_on_decrypt(): + """Test that decryption fails when encryption context is not provided for an object encrypted with context.""" + key = "encryption-context-missing" + key += datetime.now().strftime("%Y-%m-%d-%H:%M:%S") + + data = "Test data with encryption context" + + # Encryption context used during encryption + encryption_context = {"department": "engineering", "project": "s3-encryption"} + + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Put object with encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Try to get the object back WITHOUT encryption context - should fail + get_req = {"Bucket": bucket, "Key": key} + + try: + s3ec.get_object(**get_req) + # If we get here, the test failed - decryption should have failed + print("Uh oh! Decryption succeeded without providing required encryption context!") + print(f"Original context: {encryption_context}") + raise RuntimeError("Expected decryption to fail when encryption context not provided") + except S3EncryptionClientError as e: + # This is expected - decryption should fail + print("Success! 
Decryption correctly failed when encryption context was not provided.") + print(f"Error message: {str(e)}") + except Exception as e: + # Some other error occurred + print(f"Unexpected error type: {type(e).__name__}") + print(f"Error message: {str(e)}") + raise diff --git a/test/integration/test_i_s3_encryption_multithreaded.py b/test/integration/test_i_s3_encryption_multithreaded.py new file mode 100644 index 00000000..419ca7ea --- /dev/null +++ b/test/integration/test_i_s3_encryption_multithreaded.py @@ -0,0 +1,303 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Multi-threaded integration tests for S3 Encryption Client. + +These tests verify that the thread-local storage of encryption context +is properly isolated between threads when using a single S3EncryptionClient +instance across multiple threads. +""" + +import os +from concurrent.futures import ThreadPoolExecutor, as_completed +from datetime import datetime + +import boto3 + +from s3_encryption import S3EncryptionClient, S3EncryptionClientConfig +from s3_encryption.exceptions import S3EncryptionClientError +from s3_encryption.materials.kms_keyring import KmsKeyring + +bucket = os.environ.get("CI_S3_BUCKET", "s3ec-python-github-test-bucket") +region = os.environ.get("CI_AWS_REGION", "us-west-2") +kms_key_id = os.environ.get( + "CI_KMS_KEY_ALIAS", "arn:aws:kms:us-west-2:370957321024:alias/S3EC-Python-Github-KMS-Key" +) + + +def test_multithreaded_encryption_context_isolation(): + """Test that encryption context is properly isolated between threads. + + This test creates a single S3EncryptionClient instance and uses it + from multiple threads simultaneously, each with a different encryption + context. It verifies that: + 1. Each thread can encrypt with its own encryption context + 2. Each thread can decrypt only with the correct encryption context + 3. 
Thread-local storage doesn't leak between threads + """ + # Create a single S3EncryptionClient instance to be shared across threads + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + # Number of threads to test with + num_threads = 10 + results = {} + errors = [] + + def thread_worker(thread_id): + """Worker function for each thread.""" + try: + # Each thread has its own unique encryption context + encryption_context = { + "thread_id": str(thread_id), + "department": f"dept-{thread_id}", + "project": f"project-{thread_id}", + } + + # Unique key for this thread + key = f"multithread-test-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} test data with unique encryption context" + + # Encrypt with thread-specific encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + # Decrypt with the SAME encryption context - should succeed + response = s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=encryption_context) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return { + "thread_id": thread_id, + "success": False, + "error": f"Data mismatch: expected '{data}', got '{decrypted_data}'", + } + + # Try to decrypt with a DIFFERENT encryption context - should fail + wrong_context = { + "thread_id": str(thread_id + 1000), + "department": "wrong-dept", + "project": "wrong-project", + } + + try: + s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=wrong_context) + return { + "thread_id": thread_id, + "success": False, + "error": "Decryption succeeded with wrong encryption context!", + } + except S3EncryptionClientError: + # Expected - decryption should fail with wrong context + pass + + # Try to decrypt with NO encryption context - should 
also fail + try: + s3ec.get_object(Bucket=bucket, Key=key) + return { + "thread_id": thread_id, + "success": False, + "error": "Decryption succeeded without encryption context!", + } + except S3EncryptionClientError: + # Expected - decryption should fail without context + pass + + return { + "thread_id": thread_id, + "success": True, + "key": key, + "encryption_context": encryption_context, + } + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Execute threads concurrently + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(thread_worker, i) for i in range(num_threads)] + + for future in as_completed(futures): + result = future.result() + thread_id = result["thread_id"] + results[thread_id] = result + + if not result["success"]: + errors.append(f"Thread {thread_id}: {result['error']}") + + # Verify all threads succeeded + if errors: + print("Errors occurred during multi-threaded test:") + for error in errors: + print(f" - {error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print(f"Success! All {num_threads} threads completed successfully.") + print("Each thread:") + print(" - Encrypted with its own unique encryption context") + print(" - Decrypted successfully with the correct context") + print(" - Failed to decrypt with wrong context (as expected)") + print(" - Failed to decrypt without context (as expected)") + + +def test_multithreaded_rapid_context_switching(): + """Test rapid switching of encryption contexts in the same thread. + + This test verifies that encryption context is properly cleaned up + between operations within the same thread. 
+ """ + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + num_iterations = 20 + errors = [] + + def rapid_context_worker(thread_id): + """Worker that rapidly switches between different encryption contexts.""" + try: + for i in range(num_iterations): + # Alternate between different encryption contexts + if i % 3 == 0: + encryption_context = {"iteration": str(i), "type": "typeA"} + elif i % 3 == 1: + encryption_context = {"iteration": str(i), "type": "typeB"} + else: + encryption_context = {"iteration": str(i), "type": "typeC"} + + key = ( + f"rapid-switch-t{thread_id}-i{i}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + ) + data = f"Thread {thread_id} iteration {i}" + + # Encrypt + s3ec.put_object( + Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context + ) + + # Decrypt with correct context + response = s3ec.get_object( + Bucket=bucket, Key=key, EncryptionContext=encryption_context + ) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return { + "thread_id": thread_id, + "iteration": i, + "success": False, + "error": f"Data mismatch at iteration {i}", + } + + return {"thread_id": thread_id, "success": True} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Run multiple threads doing rapid context switching + num_threads = 5 + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(rapid_context_worker, i) for i in range(num_threads)] + + for future in as_completed(futures): + result = future.result() + if not result["success"]: + errors.append( + f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}" + ) + + if errors: + print("Errors occurred during rapid context switching test:") + for error in errors: + print(f" - 
{error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print(f"Success! {num_threads} threads completed {num_iterations} iterations each.") + print("Encryption context was properly isolated across rapid context switches.") + + +def test_multithreaded_mixed_with_and_without_context(): + """Test threads using encryption context mixed with threads not using it. + + This verifies that threads without encryption context don't interfere + with threads that use encryption context. + """ + kms_client = boto3.client("kms", region_name=region) + keyring = KmsKeyring(kms_client, kms_key_id) + wrapped_client = boto3.client("s3") + config = S3EncryptionClientConfig(keyring) + s3ec = S3EncryptionClient(wrapped_client, config) + + errors = [] + + def worker_with_context(thread_id): + """Worker that uses encryption context.""" + try: + encryption_context = {"thread_id": str(thread_id), "has_context": "true"} + key = f"mixed-with-ctx-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} WITH context" + + s3ec.put_object(Bucket=bucket, Key=key, Body=data, EncryptionContext=encryption_context) + + response = s3ec.get_object(Bucket=bucket, Key=key, EncryptionContext=encryption_context) + decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return {"thread_id": thread_id, "success": False, "error": "Data mismatch"} + + return {"thread_id": thread_id, "success": True, "type": "with_context"} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + def worker_without_context(thread_id): + """Worker that does NOT use encryption context.""" + try: + key = f"mixed-no-ctx-{thread_id}-{datetime.now().strftime('%Y%m%d-%H%M%S-%f')}" + data = f"Thread {thread_id} WITHOUT context" + + # No encryption context + s3ec.put_object(Bucket=bucket, Key=key, Body=data) + + # No encryption context on decrypt either + response = s3ec.get_object(Bucket=bucket, Key=key) + 
decrypted_data = response["Body"].read().decode("utf-8") + + if decrypted_data != data: + return {"thread_id": thread_id, "success": False, "error": "Data mismatch"} + + return {"thread_id": thread_id, "success": True, "type": "without_context"} + + except Exception as e: + return {"thread_id": thread_id, "success": False, "error": str(e)} + + # Mix threads with and without encryption context + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + + # Submit 5 threads with context + for i in range(5): + futures.append(executor.submit(worker_with_context, i)) + + # Submit 5 threads without context + for i in range(5, 10): + futures.append(executor.submit(worker_without_context, i)) + + for future in as_completed(futures): + result = future.result() + if not result["success"]: + errors.append( + f"Thread {result['thread_id']}: {result.get('error', 'Unknown error')}" + ) + + if errors: + print("Errors occurred during mixed context test:") + for error in errors: + print(f" - {error}") + raise RuntimeError(f"{len(errors)} thread(s) failed") + + print("Success! Mixed threads (with and without encryption context) completed successfully.") + print("Thread-local storage properly isolated context between threads.") diff --git a/test/test_decryption_materials_integration.py b/test/test_decryption_materials_integration.py index 1cfab083..35b7d9e8 100644 --- a/test/test_decryption_materials_integration.py +++ b/test/test_decryption_materials_integration.py @@ -1,7 +1,7 @@ # Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from src.s3_encryption.materials.crypto_materials_manager import DefaultCryptoMaterialsManager from src.s3_encryption.materials.encrypted_data_key import EncryptedDataKey @@ -30,17 +30,15 @@ def test_keyring_on_decrypt(self): encryption_context_from_request={"key2": "value2"}, ) - # Mock the validation method to return the materials - with patch.object(S3Keyring, "on_decrypt", return_value=materials) as mock_on_decrypt: - # Call on_decrypt - result = keyring.on_decrypt(materials, [edk]) + # Call on_decrypt + result = keyring.on_decrypt(materials, [edk]) - # Verify the result is a DecryptionMaterials instance - assert isinstance(result, DecryptionMaterials) - assert result.iv == b"initialization-vector" - assert result.encrypted_data_keys == [edk] - assert result.encryption_context_stored == {"key1": "value1"} - assert result.encryption_context_from_request == {"key2": "value2"} + # Verify the result is a DecryptionMaterials instance + assert isinstance(result, DecryptionMaterials) + assert result.iv == b"initialization-vector" + assert result.encrypted_data_keys == [edk] + assert result.encryption_context_stored == {"key1": "value1"} + assert result.encryption_context_from_request == {"key2": "value2"} def test_keyring_on_decrypt_default_enc_ctx(self): """Test that S3Keyring.on_decrypt properly handles DecryptionMaterials.""" @@ -63,16 +61,15 @@ def test_keyring_on_decrypt_default_enc_ctx(self): ) # Mock the validation method to return the materials - with patch.object(S3Keyring, "on_decrypt", return_value=materials) as mock_on_decrypt: - # Call on_decrypt - result = keyring.on_decrypt(materials, [edk]) - - # Verify the result is a DecryptionMaterials instance - assert isinstance(result, DecryptionMaterials) - assert result.iv == b"initialization-vector" - assert result.encrypted_data_keys == [edk] - assert result.encryption_context_stored == {} - 
assert result.encryption_context_from_request == {} + # Call on_decrypt + result = keyring.on_decrypt(materials, [edk]) + + # Verify the result is a DecryptionMaterials instance + assert isinstance(result, DecryptionMaterials) + assert result.iv == b"initialization-vector" + assert result.encrypted_data_keys == [edk] + assert result.encryption_context_stored == {} + assert result.encryption_context_from_request == {} def test_cmm_decrypt_materials_with_dict(self): """Test that DefaultCryptoMaterialsManager.decrypt_materials properly handles dictionary input."""