def safe_get_dict(source: dict, key: str) -> dict:
    """Return the dict stored under *key* in *source*, or ``{}`` as a fallback.

    ``d.get(key, {})`` only applies its default when the key is absent, so a
    key that is present with an explicit ``None`` value still yields ``None``.
    This helper covers both the missing-key and the ``None``-value case.

    Args:
        source: Dictionary to read from.
        key: Key whose value is expected to be a dict.

    Returns:
        The stored object itself when the value is a ``dict`` (including an
        empty one), so in-place updates made by the caller stay visible
        through ``source``.  A new empty dict when the key is missing or the
        value is ``None`` (or any other falsy non-dict value); that fallback
        is not written back into ``source``.  A truthy non-dict value is
        returned unchanged.
    """
    value = source.get(key)
    if isinstance(value, dict):
        # Hand back the stored object even when it is empty: ``value or {}``
        # would swap an empty dict for a fresh one and silently detach the
        # caller's later mutations from ``source``.
        return value
    return value or {}
b/src/s3_encryption/materials/materials.py @@ -12,6 +12,7 @@ from attrs import define, field +from .._utils import safe_get_dict from .encrypted_data_key import EncryptedDataKey @@ -232,7 +233,7 @@ def from_dict(cls, materials_dict: dict[str, Any]) -> "EncryptionMaterials": EncryptionMaterials: A new instance with fields populated from the dictionary """ return cls( - encryption_context=materials_dict.get("encryption_context", {}), + encryption_context=safe_get_dict(materials_dict, "encryption_context"), encrypted_data_key=materials_dict.get("encrypted_data_key"), plaintext_data_key=materials_dict.get("plaintext_data_key"), ) @@ -292,9 +293,9 @@ def from_dict(cls, materials_dict: dict[str, Any]) -> "DecryptionMaterials": return cls( iv=materials_dict.get("iv"), encrypted_data_keys=materials_dict.get("encrypted_data_keys", []), - encryption_context_stored=materials_dict.get("encryption_context_stored", {}), - encryption_context_from_request=materials_dict.get( - "encryption_context_from_request", {} + encryption_context_stored=safe_get_dict(materials_dict, "encryption_context_stored"), + encryption_context_from_request=safe_get_dict( + materials_dict, "encryption_context_from_request" ), plaintext_data_key=materials_dict.get("plaintext_data_key"), ) diff --git a/src/s3_encryption/pipelines.py b/src/s3_encryption/pipelines.py index 15255173..2b1fe061 100644 --- a/src/s3_encryption/pipelines.py +++ b/src/s3_encryption/pipelines.py @@ -16,6 +16,7 @@ from cryptography.hazmat.primitives.ciphers.aead import AESGCM from cryptography.hazmat.primitives.padding import PKCS7 +from ._utils import safe_get_dict from .buffered_decrypt import one_shot_decrypt from .decryptor import AesCbcDecryptor, AesGcmDecryptor from .exceptions import S3EncryptionClientError @@ -250,7 +251,7 @@ def decrypt( # Convert the metadata dictionary to an ObjectMetadata instance streaming_body: StreamingBody = response.get("Body") content_length = response.get("ContentLength") - encryption_metadata = 
def test_from_dict_with_none_encryption_contexts(self):
    """Both encryption-context fields come back as {} when the dict holds None."""
    context_fields = (
        "encryption_context_stored",
        "encryption_context_from_request",
    )
    # dict.fromkeys maps every listed key to None, i.e. "present but None".
    restored = DecryptionMaterials.from_dict(dict.fromkeys(context_fields))
    for field_name in context_fields:
        assert getattr(restored, field_name) == {}

def test_from_dict_with_missing_encryption_contexts(self):
    """Both encryption-context fields default to {} when their keys are absent."""
    restored = DecryptionMaterials.from_dict({})
    assert restored.encryption_context_stored == {}
    assert restored.encryption_context_from_request == {}
"encryption_context": None, + "encrypted_data_key": None, + "plaintext_data_key": None, + } + materials = EncryptionMaterials.from_dict(materials_dict) + assert materials.encryption_context == {} + + def test_from_dict_with_missing_encryption_context(self): + """EncryptionMaterials.from_dict should default to {} when key is missing.""" + materials_dict = {} + materials = EncryptionMaterials.from_dict(materials_dict) + assert materials.encryption_context == {} diff --git a/test/test_encryption_materials_integration.py b/test/test_encryption_materials_integration.py index e9e59023..a02343a1 100644 --- a/test/test_encryption_materials_integration.py +++ b/test/test_encryption_materials_integration.py @@ -90,3 +90,19 @@ def test_cmm_get_encryption_materials_with_materials(self): assert result.encryption_context == {"key1": "value1"} assert result.encrypted_data_key is not None assert result.plaintext_data_key is not None + + def test_cmm_get_encryption_materials_with_none_encryption_context(self): + """DefaultCryptoMaterialsManager handles None encryption_context in dict request.""" + keyring = MagicMock() + keyring.on_encrypt.return_value = EncryptionMaterials( + encryption_context={}, + plaintext_data_key=b"key", + ) + cmm = DefaultCryptoMaterialsManager(keyring=keyring) + + # Pass a dict with None encryption_context — should not raise TypeError + cmm.get_encryption_materials({"encryption_context": None}) + + # Keyring should receive empty dict, not None + call_args = keyring.on_encrypt.call_args[0][0] + assert call_args.encryption_context == {} diff --git a/test/test_exceptions.py b/test/test_exceptions.py index f93e3d9d..4fe46bcc 100644 --- a/test/test_exceptions.py +++ b/test/test_exceptions.py @@ -49,3 +49,24 @@ def test_inherits_from_botocore_error(self): def test_can_be_caught_as_botocore_error(self): with pytest.raises(BotoCoreError): raise S3EncryptionClientSecurityError("test security error") + + +from s3_encryption._utils import safe_get_dict + + +class 
class TestGetEncryptedObjectPipelineNoneMetadata:
    """Decrypting a response whose Metadata is None or absent fails cleanly."""

    @staticmethod
    def _expect_client_error(response):
        """Run decrypt on *response* and require S3EncryptionClientError.

        Any other exception type (for example a TypeError) propagates out of
        ``pytest.raises`` and fails the calling test.
        """
        pipeline = GetEncryptedObjectPipeline(
            cmm=Mock(),
            commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT,
        )
        decrypt_kwargs = {
            "instruction_suffix": ".instruction",
            "enable_delayed_authentication": False,
        }
        with pytest.raises(S3EncryptionClientError):
            pipeline.decrypt(response, **decrypt_kwargs)

    def test_decrypt_with_none_metadata(self):
        """Pipeline should not raise TypeError when Metadata is None."""
        self._expect_client_error(
            {"Body": BytesIO(b"test"), "ContentLength": 4, "Metadata": None}
        )

    def test_decrypt_with_missing_metadata(self):
        """Pipeline should not raise TypeError when Metadata key is absent."""
        self._expect_client_error({"Body": BytesIO(b"test"), "ContentLength": 4})