Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/s3_encryption/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from botocore.exceptions import ClientError
from botocore.response import StreamingBody

from ._utils import safe_get_dict
from .exceptions import S3EncryptionClientError
from .instruction_file import parse_instruction_file
from .instruction_file_config import InstructionFileConfig
Expand Down Expand Up @@ -198,7 +199,7 @@ def on_put_object_before_call(self, params, **kwargs):

params["body"] = encrypted_data

headers = params.get("headers", {})
headers = safe_get_dict(params, "headers")

# Add encryption metadata to headers
if encryption_metadata:
Expand Down Expand Up @@ -244,7 +245,7 @@ def on_get_object_after_call(self, parsed, **kwargs):
# Create a response dict that matches what the pipeline expects
response = {
"Body": parsed.get("Body"),
"Metadata": parsed.get("Metadata", {}),
"Metadata": safe_get_dict(parsed, "Metadata"),
"ContentLength": content_length,
}

Expand Down Expand Up @@ -286,8 +287,7 @@ def process_instruction_file(self, parsed):
)

# In plaintext mode, parse instruction file and append to metadata
# Metadata may be present but None, so `or {}` handles that case
existing_metadata = parsed.get("Metadata", {}) or {}
existing_metadata = safe_get_dict(parsed, "Metadata")
instruction_data = body.read()
instruction_metadata = parse_instruction_file(instruction_data, instruction_key)

Expand Down
12 changes: 12 additions & 0 deletions src/s3_encryption/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Internal utility helpers for the S3 Encryption Client."""


def safe_get_dict(source: dict, key: str) -> dict:
"""Get a dict value from *source*, defaulting to {} if missing or None.

This avoids the common pitfall where ``d.get(key, {})`` returns None
when the key exists but its value is explicitly None.
"""
return source.get(key, {}) or {}
3 changes: 2 additions & 1 deletion src/s3_encryption/instruction_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from botocore.exceptions import ClientError

from ._utils import safe_get_dict
from .exceptions import S3EncryptionClientError
from .metadata import VALID_S3EC_METADATA_KEYS

Expand Down Expand Up @@ -109,7 +110,7 @@ def fetch_instruction_file(s3_client, bucket: str, key: str) -> dict[str, Any]:
s3_client._s3ec_plugin_context.instruction_file_mode = False

# In plaintext mode, the event handler places parsed metadata in Metadata field
metadata = response.get("Metadata", {})
metadata = safe_get_dict(response, "Metadata")

# Verify metadata is not empty
if not metadata:
Expand Down
3 changes: 2 additions & 1 deletion src/s3_encryption/materials/crypto_materials_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from attrs import define

from .._utils import safe_get_dict
from .keyring import AbstractKeyring
from .materials import DecryptionMaterials, EncryptionMaterials

Expand Down Expand Up @@ -74,7 +75,7 @@ def get_encryption_materials(self, enc_mats_request):
# Convert dictionary to EncryptionMaterials if needed
if isinstance(enc_mats_request, dict):
materials = EncryptionMaterials(
encryption_context=enc_mats_request.get("encryption_context", {})
encryption_context=safe_get_dict(enc_mats_request, "encryption_context")
)
else:
materials = enc_mats_request
Expand Down
9 changes: 5 additions & 4 deletions src/s3_encryption/materials/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from attrs import define, field

from .._utils import safe_get_dict
from .encrypted_data_key import EncryptedDataKey


Expand Down Expand Up @@ -232,7 +233,7 @@ def from_dict(cls, materials_dict: dict[str, Any]) -> "EncryptionMaterials":
EncryptionMaterials: A new instance with fields populated from the dictionary
"""
return cls(
encryption_context=materials_dict.get("encryption_context", {}),
encryption_context=safe_get_dict(materials_dict, "encryption_context"),
encrypted_data_key=materials_dict.get("encrypted_data_key"),
plaintext_data_key=materials_dict.get("plaintext_data_key"),
)
Expand Down Expand Up @@ -292,9 +293,9 @@ def from_dict(cls, materials_dict: dict[str, Any]) -> "DecryptionMaterials":
return cls(
iv=materials_dict.get("iv"),
encrypted_data_keys=materials_dict.get("encrypted_data_keys", []),
encryption_context_stored=materials_dict.get("encryption_context_stored", {}),
encryption_context_from_request=materials_dict.get(
"encryption_context_from_request", {}
encryption_context_stored=safe_get_dict(materials_dict, "encryption_context_stored"),
encryption_context_from_request=safe_get_dict(
materials_dict, "encryption_context_from_request"
),
plaintext_data_key=materials_dict.get("plaintext_data_key"),
)
Expand Down
3 changes: 2 additions & 1 deletion src/s3_encryption/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.hazmat.primitives.padding import PKCS7

from ._utils import safe_get_dict
from .buffered_decrypt import one_shot_decrypt
from .decryptor import AesCbcDecryptor, AesGcmDecryptor
from .exceptions import S3EncryptionClientError
Expand Down Expand Up @@ -250,7 +251,7 @@ def decrypt(
# Convert the metadata dictionary to an ObjectMetadata instance
streaming_body: StreamingBody = response.get("Body")
content_length = response.get("ContentLength")
encryption_metadata = response.get("Metadata", {})
encryption_metadata = safe_get_dict(response, "Metadata")
metadata = ObjectMetadata.from_dict(encryption_metadata)

# Use empty dict if encryption_context is None
Expand Down
17 changes: 17 additions & 0 deletions test/test_decryption_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,20 @@ def test_to_dict(self):
assert materials_dict["encryption_context_stored"] == {"key1": "value1"}
assert materials_dict["encryption_context_from_request"] == {"key2": "value2"}
assert materials_dict["plaintext_data_key"] == b"plaintext-data-key"

def test_from_dict_with_none_encryption_contexts(self):
"""DecryptionMaterials.from_dict should handle None encryption contexts."""
materials_dict = {
"encryption_context_stored": None,
"encryption_context_from_request": None,
}
materials = DecryptionMaterials.from_dict(materials_dict)
assert materials.encryption_context_stored == {}
assert materials.encryption_context_from_request == {}

def test_from_dict_with_missing_encryption_contexts(self):
"""DecryptionMaterials.from_dict should default to {} when context keys are missing."""
materials_dict = {}
materials = DecryptionMaterials.from_dict(materials_dict)
assert materials.encryption_context_stored == {}
assert materials.encryption_context_from_request == {}
16 changes: 16 additions & 0 deletions test/test_encryption_materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,19 @@ def test_to_dict(self):
assert materials_dict["encryption_context"] == {"key1": "value1"}
assert materials_dict["encrypted_data_key"] == edk
assert materials_dict["plaintext_data_key"] == b"plaintext-data-key"

def test_from_dict_with_none_encryption_context(self):
"""EncryptionMaterials.from_dict should handle None encryption_context."""
materials_dict = {
"encryption_context": None,
"encrypted_data_key": None,
"plaintext_data_key": None,
}
materials = EncryptionMaterials.from_dict(materials_dict)
assert materials.encryption_context == {}

def test_from_dict_with_missing_encryption_context(self):
"""EncryptionMaterials.from_dict should default to {} when key is missing."""
materials_dict = {}
materials = EncryptionMaterials.from_dict(materials_dict)
assert materials.encryption_context == {}
16 changes: 16 additions & 0 deletions test/test_encryption_materials_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,19 @@ def test_cmm_get_encryption_materials_with_materials(self):
assert result.encryption_context == {"key1": "value1"}
assert result.encrypted_data_key is not None
assert result.plaintext_data_key is not None

def test_cmm_get_encryption_materials_with_none_encryption_context(self):
"""DefaultCryptoMaterialsManager handles None encryption_context in dict request."""
keyring = MagicMock()
keyring.on_encrypt.return_value = EncryptionMaterials(
encryption_context={},
plaintext_data_key=b"key",
)
cmm = DefaultCryptoMaterialsManager(keyring=keyring)

# Pass a dict with None encryption_context — should not raise TypeError
cmm.get_encryption_materials({"encryption_context": None})

# Keyring should receive empty dict, not None
call_args = keyring.on_encrypt.call_args[0][0]
assert call_args.encryption_context == {}
21 changes: 21 additions & 0 deletions test/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,24 @@ def test_inherits_from_botocore_error(self):
def test_can_be_caught_as_botocore_error(self):
with pytest.raises(BotoCoreError):
raise S3EncryptionClientSecurityError("test security error")


from s3_encryption._utils import safe_get_dict


class TestSafeGetDict:
def test_returns_value_when_present(self):
assert safe_get_dict({"key": {"a": 1}}, "key") == {"a": 1}

def test_returns_empty_dict_when_key_missing(self):
assert safe_get_dict({}, "key") == {}

def test_returns_empty_dict_when_value_is_none(self):
assert safe_get_dict({"key": None}, "key") == {}

def test_returns_empty_dict_for_empty_value(self):
assert safe_get_dict({"key": {}}, "key") == {}

def test_preserves_non_empty_dict(self):
data = {"x": "y", "z": "w"}
assert safe_get_dict({"meta": data}, "meta") == data
45 changes: 45 additions & 0 deletions test/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,3 +469,48 @@ def test_decrypt_rejects_exclusive_key_collision(self):

with pytest.raises(S3EncryptionClientError, match="multiple format versions"):
pipeline.decrypt(mock_response, ".instruction", enable_delayed_authentication=False)


class TestGetEncryptedObjectPipelineNoneMetadata:
"""Tests that None Metadata in response is handled gracefully."""

def test_decrypt_with_none_metadata(self):
"""Pipeline should not raise TypeError when Metadata is None."""
mock_cmm = Mock()
pipeline = GetEncryptedObjectPipeline(
cmm=mock_cmm,
commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT,
)

response = {
"Body": BytesIO(b"test"),
"ContentLength": 4,
"Metadata": None,
}

with pytest.raises(S3EncryptionClientError):
pipeline.decrypt(
response,
instruction_suffix=".instruction",
enable_delayed_authentication=False,
)

def test_decrypt_with_missing_metadata(self):
"""Pipeline should not raise TypeError when Metadata key is absent."""
mock_cmm = Mock()
pipeline = GetEncryptedObjectPipeline(
cmm=mock_cmm,
commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT,
)

response = {
"Body": BytesIO(b"test"),
"ContentLength": 4,
}

with pytest.raises(S3EncryptionClientError):
pipeline.decrypt(
response,
instruction_suffix=".instruction",
enable_delayed_authentication=False,
)
Loading