Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ install:

# Run linting checks
lint:
uv run black --check .
uv run black --check src/ test/
# Enforce ruff checks on src/ but allow test/ to fail
uv run ruff check src/
uv run ruff check test/ || true

# Format code with Black and Ruff
format:
uv run black .
uv run black src/ test/
uv run ruff check --fix src/ test/

# Run all tests
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ exclude = [".git", "__pycache__", "build", "dist"]
[tool.ruff.lint]
# Enable all rules by default, then configure specific rule settings below
select = ["E", "F", "W", "I", "N", "D", "UP", "B", "A", "C4", "PT", "RET", "SIM", "ARG", "ERA"]
ignore = ["ARG002"] # Allow unused method arguments (e.g., **kwargs for API compatibility)

[tool.ruff.lint.pydocstyle]
convention = "google"
Expand Down
223 changes: 143 additions & 80 deletions src/s3_encryption/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"""Top-level S3 Encryption Client v3 for Python package."""

import io
import threading

from attrs import define, field
from botocore import serialize
from botocore.response import StreamingBody

from .exceptions import S3EncryptionClientError
Expand All @@ -16,7 +16,7 @@
from .materials.keyring import AbstractKeyring
from .pipelines import GetEncryptedObjectPipeline, PutEncryptedObjectPipeline

DEFAULT_ENCODING = "utf-8"
S3_METADATA_PREFIX = "x-amz-meta-"


@define
Expand All @@ -31,31 +31,122 @@ def _default_cmm_for_keyring(self):
return DefaultCryptoMaterialsManager(self.keyring)


class S3EncryptionClientPlugin:
"""Plugin that adds encryption/decryption capabilities to a boto3 S3 client.

This plugin uses boto3's event system to intercept put_object and get_object
calls to provide transparent encryption and decryption of S3 objects.
"""

def __init__(self, config: S3EncryptionClientConfig):
"""Initialize the plugin with encryption configuration.

Args:
config: S3EncryptionClientConfig containing keyring and CMM
"""
self.config = config
self._context = threading.local()

def on_put_object_before_call(self, params, **kwargs):
"""Event handler for before-call.s3.PutObject.

This handler encrypts the body after serialization but before the request is sent.

Args:
params: Dictionary of parameters for the PutObject call (after serialization)
**kwargs: Additional event arguments
"""
# At this point, boto3 has already serialized the Body
# Extract the serialized body from the request
body = params.get("body")
if body is None:
body_bytes = b""
elif isinstance(body, bytes):
body_bytes = body
elif hasattr(body, "read"):
# It's a file-like object (BytesIO, etc.)
# TODO(streaming): Add support for streaming encryption without reading entire body
# into memory
body_bytes = body.read()
else:
# Unexpected body type - should not happen as boto3 validates before this point
raise S3EncryptionClientError("Unexpected type of body parameter!")

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: any chance the request ID is in scope here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this answers you're question:

If you're asking if we have access to the x-amz-request-id or x-amz-id-2 headers, those come from the service as a part of the response. Since the request hasn't been sent yet at this point in code execution, so those values don't exist yet. You'd need to wait until we have sent the request and the service has responded.

If you only need it for client side tracking, you can delay this to the request-created event and take a look at the amz-sdk-invocation-id header. If you're looking to add any sort of logic around the request id that's connected with the server's request id, you'll need to hook into after-call or later.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was asking to improve the error message.
I frequent customer ask is for request IDs to be in error messages,
but since the request ID is not in scope,
let it roll.


encryption_context = getattr(self._context, "encryption_context", None)

pipeline = PutEncryptedObjectPipeline(self.config.cmm)
encrypted_data, encryption_metadata = pipeline.encrypt(
body_bytes, encryption_context=encryption_context
)

params["body"] = encrypted_data

headers = params.get("headers", {})

# Add encryption metadata to headers
if encryption_metadata:
for key, value in encryption_metadata.items():
# Add as S3 metadata headers
header_key = f"{S3_METADATA_PREFIX}{key}"
headers[header_key] = value

params["headers"] = headers

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand enough of the business logic to know if this was intentional or not, but you may wind up overwriting existing headers here

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... is the ask to check for header key collision?
That is a reasonable ask, and it might prevent double wrapping of the pipeline.
i.e: a user manages to wrap a encryption client with a encryption client.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not an ask - it's me not understanding. If you want to override it, it's great as is :)

If you're concerned with users setting custom headers, you can always just update the above to

if header_key not in headers:
     headers[header_key] = value

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The headers set by the encryption client are reserved and should not be set by the customer. If the customer does this, those values must be overwritten. This is not explicitly specified, but this is how Java behaves.


def on_get_object_after_call(self, parsed, **kwargs):
"""Event handler for after-call.s3.GetObject.

This handler decrypts the body after the response is received from S3.

Args:
parsed: Dictionary containing the parsed response
**kwargs: Additional event arguments (includes 'params' with request parameters)
"""
# Get encryption context from thread-local storage (set by get_object wrapper)
encryption_context = getattr(self._context, "encryption_context", None)

# The parsed response already has the Body as a StreamingBody
# We need to read it, decrypt it, and replace it

# Create a response dict that matches what the pipeline expects
response = {
"Body": parsed.get("Body"),
"Metadata": parsed.get("Metadata", {}),
}

# Create a pipeline and decrypt the data
pipeline = GetEncryptedObjectPipeline(self.config.cmm)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: To handle instruction files, we are going to need the S3EC client to be passed to the GetEncryptedObjectPipeline.

Is there an S3 Client in scope here?

Note: This does not have to be resolved for this PR, but I am asking @kessplas about this, as Kess has context on the plugin nature that I do not.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Short answer, no.

Long answer, there can be. I think the easiest way to deal with this is to make the original wrapped client an instance variable of the plugin so that the plugin can call out to get_object for the instruction file if needed. IIRC Python uses references everywhere so it should be too heavy, and the client should be thread safe enough to reuse.

decrypted_data = pipeline.decrypt(response, encryption_context)

# Replace body with decrypted data
stream = io.BytesIO(decrypted_data)
streaming_body = StreamingBody(stream, len(decrypted_data))
parsed["Body"] = streaming_body


@define
class S3EncryptionClient:
"""Client for encrypting and decrypting S3 objects.

This client wraps a boto3 S3 client and provides encryption and decryption
capabilities for S3 objects using the configured keyring and crypto materials manager.

The encryption/decryption is implemented using boto3's event system, registering
handlers for before-call and after-call events.
"""

wrapped_s3_client = field()
config: S3EncryptionClientConfig = field()
_plugin: S3EncryptionClientPlugin = field(init=False)

def __attrs_post_init__(self):
"""Validate serialization encoding after initialization.
"""Install the encryption plugin on the wrapped client using boto3 events."""
# Create the plugin
object.__setattr__(self, "_plugin", S3EncryptionClientPlugin(self.config))

Ensures boto3 serializers are using the expected default encoding.
"""
# Sanity check that boto3 serialization are ONLY using the default encoding (utf-8)
# This should always be the case, but changes in encoding would break the assumption that
# the decrypted plaintext adheres to the non-utf8 encoding scheme. So we avoid that.
for sz_name, sz in serialize.SERIALIZERS.items():
if sz.DEFAULT_ENCODING != DEFAULT_ENCODING:
raise S3EncryptionClientError(
f"All Serializers MUST only support utf-8 encoding, but {sz_name} is using "
f"{sz.DEFAULT_ENCODING}!"
)
# Register event handlers using boto3's event system
event_system = self.wrapped_s3_client.meta.events
event_system.register("before-call.s3.PutObject", self._plugin.on_put_object_before_call)
event_system.register("after-call.s3.GetObject", self._plugin.on_get_object_after_call)

def put_object(self, **kwargs):
"""Encrypt and upload an object to S3.
Expand All @@ -71,52 +162,28 @@ def put_object(self, **kwargs):

Returns:
The response from the S3 client's put_object method.

Raises:
S3EncryptionClientError: Any problem with encryption, including if the Body parameter
has an invalid type.
"""
# Extract required parameters from kwargs
bucket = kwargs.pop("Bucket")
key = kwargs.pop("Key")
body = kwargs.pop("Body", b"") # Default to empty bytes when Body is not provided
# Extract EncryptionContext if provided (not a standard S3 parameter)
encryption_context = kwargs.pop("EncryptionContext", None)

# Create a pipeline for this operation
pipeline = PutEncryptedObjectPipeline(self.config.cmm)

# The documentation for boto3 asks for bytes or a file-like object,
# but in reality, it is possible to pass strings.
# Strings will be encoded using DEFAULT_ENCODING,
# which MUST match the default encoding defined int the Serializer class in botocore.
if isinstance(body, str):
data_bytes = body.encode(DEFAULT_ENCODING)
elif isinstance(body, bytes):
data_bytes = body
elif isinstance(body, io.IOBase):
# TODO: Streaming support
raise S3EncryptionClientError(
f"Body parameter of type {type(body)} is not an acceptable type! "
f"Streaming operations are not yet supported."
)
else:
raise S3EncryptionClientError(
f"Body parameter of type {type(body)} is not an acceptable type! "
f"Use bytes or a file-like object."
)

# Now encrypt the bytes/file-like IOBase object
encrypted_data, encryption_metadata = pipeline.encrypt(
data_bytes, encryption_context=encryption_context
)

# Add encryption metadata to the request parameters
params = {"Bucket": bucket, "Key": key, "Body": encrypted_data, **kwargs}

# Add encryption metadata to the parameters
if encryption_metadata:
# Merge any existing metadata with our encryption metadata
metadata = params.get("Metadata", {})
metadata.update(encryption_metadata)
params["Metadata"] = metadata

return self.wrapped_s3_client.put_object(**params)
# Store encryption context in thread-local storage for the event handler
self._plugin._context.encryption_context = encryption_context

try:
return self.wrapped_s3_client.put_object(**kwargs)
except S3EncryptionClientError:
# Re-raise our own exceptions without wrapping
raise
except Exception as e:
raise S3EncryptionClientError(f"Failed to encrypt object: {str(e)}") from e
finally:
# Clean up thread-local storage
if hasattr(self._plugin._context, "encryption_context"):
delattr(self._plugin._context, "encryption_context")

def get_object(self, **kwargs):
"""Download and decrypt an object from S3.
Expand All @@ -131,29 +198,25 @@ def get_object(self, **kwargs):
Returns:
The response from the S3 client's get_object method with the Body
replaced with a StreamingBody containing the decrypted data.

Raises:
S3EncryptionClientError: If decryption fails or the object is not properly encrypted.
"""
# Extract encryption context if provided
# Extract EncryptionContext if provided (not a standard S3 parameter)
encryption_context = kwargs.pop("EncryptionContext", None)

# Create params for the S3 client
params = {**kwargs}

# Get the encrypted object from S3
response = self.wrapped_s3_client.get_object(**params)

# Create a pipeline for this operation
pipeline = GetEncryptedObjectPipeline(self.config.cmm)

# Decrypt the data using the pipeline
decrypted_data = pipeline.decrypt(
response, encryption_context
) # encrypted_data, encryption_metadata)

# Create a new streaming body with the decrypted data
stream = io.BytesIO(decrypted_data)
streaming_body = StreamingBody(stream, len(decrypted_data))

# Update the response with the decrypted data
response["Body"] = streaming_body

return response
# Store encryption context in thread-local storage for the event handler
self._plugin._context.encryption_context = encryption_context

try:
return self.wrapped_s3_client.get_object(**kwargs)
except S3EncryptionClientError:
# Re-raise our own exceptions without wrapping
raise
except Exception as e:
# Wrap any unexpected errors during decryption
raise S3EncryptionClientError(f"Failed to decrypt object: {str(e)}") from e
finally:
# Clean up thread-local storage
if hasattr(self._plugin._context, "encryption_context"):
delattr(self._plugin._context, "encryption_context")
2 changes: 1 addition & 1 deletion src/s3_encryption/materials/kms_keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def on_encrypt(self, enc_materials):
# Call parent class validation
enc_materials = super().on_encrypt(enc_materials)

# Add default encryption context
encryption_context = enc_materials.encryption_context
encryption_context["aws:x-amz-cek-alg"] = "AES/GCM/NoPadding"

Expand Down Expand Up @@ -111,6 +110,7 @@ def on_decrypt(self, dec_materials, encrypted_data_keys=None):
encryption_context_stored_copy = encryption_context_stored.copy()
encryption_context_stored_copy.pop(KMS_V1_DEFAULT_KEY, None)
encryption_context_stored_copy.pop(KMS_CONTEXT_DEFAULT_KEY, None)

if encryption_context_stored_copy != encryption_context_from_request:
# TODO: modeled error
raise S3EncryptionClientError(
Expand Down
5 changes: 3 additions & 2 deletions src/s3_encryption/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def encrypt(self, plaintext, encryption_context=None):
bytes: The encrypted data
dict: Metadata about the encryption to be stored with the object
"""
# Create encryption materials request with encryption context
# Create encryption materials request with encryption context copy
enc_mats_request = EncryptionMaterials(
encryption_context={} if encryption_context is None else encryption_context
encryption_context={} if encryption_context is None else encryption_context.copy()
)

# Get encryption materials from the crypto materials manager
Expand Down Expand Up @@ -102,6 +102,7 @@ def decrypt(self, response, encryption_context=None):
bytes: The decrypted data
"""
# Convert the metadata dictionary to an ObjectMetadata instance
# TODO: Stream + Buffered Decryption
encrypted_data = response.get("Body").read()
encryption_metadata = response.get("Metadata", {})
metadata = ObjectMetadata.from_dict(encryption_metadata)
Expand Down
Loading
Loading