Skip to content
Merged
303 changes: 302 additions & 1 deletion src/s3_encryption/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Top-level S3 Encryption Client v4 for Python package."""

import io
import os
import threading

from attrs import define, field
Expand All @@ -19,10 +20,19 @@
)
from .materials.keyring import AbstractKeyring
from .materials.materials import AlgorithmSuite, CommitmentPolicy
from .pipelines import GetEncryptedObjectPipeline, PutEncryptedObjectPipeline
from .pipelines import (
GetEncryptedObjectPipeline,
MultipartUploadPipeline,
PutEncryptedObjectPipeline,
)

S3_METADATA_PREFIX = "x-amz-meta-"

# Default multipart threshold and chunk size (same as boto3 defaults).
# upload_file sends files smaller than the threshold through put_object;
# anything else is split into parts of the chunk size.
# 8 MiB
_DEFAULT_MULTIPART_THRESHOLD = 8 * 1024 * 1024
# 8 MiB
_DEFAULT_MULTIPART_CHUNKSIZE = 8 * 1024 * 1024
# 5 MiB — S3 minimum for non-final parts; upload_file and upload_fileobj
# reject any multipart_chunksize below this value.
_MIN_MULTIPART_PART_SIZE = 5 * 1024 * 1024

# Thread-local context attribute names
_CTX_ENCRYPTION_CONTEXT = "encryption_context"
_CTX_BUCKET = "bucket"
Expand Down Expand Up @@ -341,6 +351,10 @@ class S3EncryptionClient:
wrapped_s3_client = field()
config: S3EncryptionClientConfig = field()
_plugin: S3EncryptionClientPlugin = field(init=False)
# Each upload gets its own pipeline with independent cipher state, keyed by UploadId.
# Access is protected by a lock for thread safety across all Python runtimes.
_multipart_uploads: dict = field(init=False, factory=dict)
_multipart_lock: threading.Lock = field(init=False, factory=threading.Lock)

def __attrs_post_init__(self):
"""Install the encryption plugin on the wrapped client using boto3 events."""
Expand Down Expand Up @@ -563,3 +577,290 @@ def get_object(self, **kwargs):
for attr in _GET_OBJECT_CLEANUP_ATTRS:
if hasattr(self._plugin._context, attr):
delattr(self._plugin._context, attr)

##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% CreateMultipartUpload MAY be implemented by the S3EC.
def create_multipart_upload(self, **kwargs):
    """Start an encrypted multipart upload and register its pipeline.

    A new ``MultipartUploadPipeline`` is built from the client
    configuration, the pipeline's metadata is merged over any
    caller-supplied ``Metadata``, and the request is forwarded to the
    wrapped S3 client. The pipeline is then stored under the returned
    ``UploadId`` so that later part uploads can look it up.

    Args:
        **kwargs: Arguments for S3 create_multipart_upload. An optional
            ``EncryptionContext`` entry is removed here and handed to
            the pipeline instead of being sent to S3.

    Returns:
        The response from S3 create_multipart_upload.

    Raises:
        S3EncryptionClientError: If the wrapped client's
            create_multipart_upload call fails.
    """
    context = kwargs.pop("EncryptionContext", None)
    _validate_encryption_context(context)

    upload_pipeline = MultipartUploadPipeline(
        cmm=self.config.cmm,
        encryption_algorithm=self.config.encryption_algorithm,
        encryption_context=context or {},
    )

    # Copy the caller's Metadata (never mutate it), then layer the pipeline's
    # metadata on top so its keys win on a name clash.
    merged_metadata = dict(kwargs["Metadata"]) if "Metadata" in kwargs else {}
    merged_metadata.update(upload_pipeline.metadata)
    kwargs["Metadata"] = merged_metadata

    ##= specification/s3-encryption/client.md#optional-api-operations
    ##= type=implementation
    ##% If implemented, CreateMultipartUpload MUST initiate a multipart upload.
    try:
        s3_response = self.wrapped_s3_client.create_multipart_upload(**kwargs)
    except Exception as exc:
        raise S3EncryptionClientError(f"Failed to create multipart upload: {exc}") from exc

    new_upload_id = s3_response["UploadId"]
    with self._multipart_lock:
        # Register only after S3 accepted the request, so failed creations
        # never leave a pipeline behind.
        self._multipart_uploads[new_upload_id] = upload_pipeline
    return s3_response

##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% UploadPart MAY be implemented by the S3EC.
##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% UploadPart MUST encrypt each part.
##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% Each part MUST be encrypted in sequence.
##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% Each part MUST be encrypted using the same cipher instance for each part.
def upload_part(self, **kwargs):
    """Encrypt and upload a single part of a multipart upload.

    Parts must be uploaded in sequential order (1, 2, 3, ...).
    The caller MUST set ``IsLastPart=True`` on the final part so the
    GCM authentication tag is appended to the ciphertext.

    The body may be ``bytes``, a ``str`` (encoded as UTF-8), or a
    file-like object with ``read()``. File-like bodies are read fully
    into memory; if the stream yields ``str`` (text mode) the result is
    UTF-8 encoded just like a plain ``str`` body.

    Args:
        **kwargs: Arguments for S3 upload_part. Must include UploadId,
            PartNumber, and Body. Set IsLastPart=True on the
            final part.

    Returns:
        The response from S3 upload_part (includes ETag).

    Raises:
        S3EncryptionClientError: If the UploadId is not tracked by this
            client, or if encrypting the part fails.
        KeyError: If PartNumber is missing from kwargs.
    """
    upload_id = kwargs.get("UploadId")
    with self._multipart_lock:
        pipeline = self._multipart_uploads.get(upload_id)
    if pipeline is None:
        raise S3EncryptionClientError(
            f"No multipart upload found for UploadId: {upload_id}. "
            "Call create_multipart_upload first."
        )

    part_number = kwargs["PartNumber"]
    # IsLastPart is a client-side flag; pop it so it is never forwarded to S3.
    is_last = kwargs.pop("IsLastPart", False)

    body = kwargs.get("Body", b"")
    # Drain file-like bodies first: a text-mode stream returns str from
    # read(), and that must go through the same UTF-8 encoding as a str body.
    if hasattr(body, "read"):
        body = body.read()
    if isinstance(body, str):
        body = body.encode("utf-8")

    try:
        ciphertext = pipeline.encrypt_part(part_number, body, is_last=is_last)
    except S3EncryptionClientError:
        # Already the client's own error type; do not wrap it a second time.
        raise
    except Exception as e:
        raise S3EncryptionClientError(f"Failed to encrypt part {part_number}: {e}") from e

    kwargs["Body"] = ciphertext
    return self.wrapped_s3_client.upload_part(**kwargs)

##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% CompleteMultipartUpload MAY be implemented by the S3EC.
##% CompleteMultipartUpload MUST complete the multipart upload.
def complete_multipart_upload(self, **kwargs):
    """Finish an encrypted multipart upload and drop its tracked pipeline.

    The upload must be known to this client, and its final part must
    already have been uploaded with ``IsLastPart=True``.

    Args:
        **kwargs: Arguments for S3 complete_multipart_upload.
            MultipartUpload.Parts must include PartNumber and ETag
            for each part.

    Returns:
        The response from S3 complete_multipart_upload.

    Raises:
        S3EncryptionClientError: If the UploadId is not tracked by this
            client, the final part has not been seen, or the wrapped
            S3 call fails.
    """
    upload_id = kwargs.get("UploadId")
    with self._multipart_lock:
        tracked = self._multipart_uploads.get(upload_id)

    if tracked is None:
        raise S3EncryptionClientError(f"No multipart upload found for UploadId: {upload_id}.")
    if not tracked.has_final_part_been_seen:
        raise S3EncryptionClientError(
            "Cannot complete multipart upload: the final part has not been uploaded. "
            "Set IsLastPart=True on the last upload_part call."
        )

    try:
        s3_response = self.wrapped_s3_client.complete_multipart_upload(**kwargs)
    except S3EncryptionClientError:
        raise
    except Exception as exc:
        raise S3EncryptionClientError(f"Failed to complete multipart upload: {exc}") from exc

    # Forget the pipeline only after S3 accepted the completion; when the
    # call fails the entry stays registered.
    with self._multipart_lock:
        self._multipart_uploads.pop(upload_id, None)
    return s3_response

##= specification/s3-encryption/client.md#optional-api-operations
##= type=implementation
##% AbortMultipartUpload MAY be implemented by the S3EC.
##% AbortMultipartUpload MUST abort the multipart upload.
def abort_multipart_upload(self, **kwargs):
    """Discard the tracked pipeline for an upload, then abort it in S3.

    The pipeline entry is removed before the S3 call is made, and an
    UploadId this client is not tracking is forwarded to S3 all the same.

    Args:
        **kwargs: Arguments for S3 abort_multipart_upload.

    Returns:
        The response from S3 abort_multipart_upload.
    """
    with self._multipart_lock:
        # pop() with a default makes an unknown or missing UploadId a no-op.
        self._multipart_uploads.pop(kwargs.get("UploadId"), None)

    s3_response = self.wrapped_s3_client.abort_multipart_upload(**kwargs)
    return s3_response

def upload_file(
    self, filename, bucket, key, multipart_threshold=None, multipart_chunksize=None, **kwargs
):
    """Encrypt and upload a file to S3.

    If the file is smaller than the threshold, uses put_object.
    Otherwise, performs an encrypted multipart upload.

    Args:
        filename: Path to the file to upload.
        bucket: Target S3 bucket.
        key: Target S3 object key.
        multipart_threshold: File size threshold for multipart (default 8MB).
        multipart_chunksize: Size of each part (default 8MB).
        **kwargs: Additional arguments (e.g. EncryptionContext, Metadata).

    Returns:
        The put_object response for files below the threshold, otherwise
        the complete_multipart_upload response.

    Raises:
        S3EncryptionClientError: If multipart_threshold or
            multipart_chunksize is not positive, or multipart_chunksize
            is below the S3 minimum part size.
        OSError: If the file cannot be inspected or opened.
    """
    threshold = (
        _DEFAULT_MULTIPART_THRESHOLD if multipart_threshold is None else multipart_threshold
    )
    chunksize = (
        _DEFAULT_MULTIPART_CHUNKSIZE if multipart_chunksize is None else multipart_chunksize
    )
    if threshold <= 0:
        raise S3EncryptionClientError("multipart_threshold must be a positive integer.")
    if chunksize <= 0:
        raise S3EncryptionClientError("multipart_chunksize must be a positive integer.")
    if chunksize < _MIN_MULTIPART_PART_SIZE:
        raise S3EncryptionClientError(
            f"multipart_chunksize must be at least {_MIN_MULTIPART_PART_SIZE} bytes (5 MB). "
            f"S3 requires all non-final parts to be at least 5 MB."
        )
    file_size = os.path.getsize(filename)

    # Own the handle here with a context manager so it is closed on every
    # path, including a failure before the multipart upload has been created.
    with open(filename, "rb") as f:
        if file_size < threshold:
            kwargs["Bucket"] = bucket
            kwargs["Key"] = key
            kwargs["Body"] = f.read()
            return self.put_object(**kwargs)

        # owns_readable=False: the surrounding `with` block closes the file.
        return self._multipart_upload_from_readable(
            f, bucket, key, chunksize, owns_readable=False, **kwargs
        )

def upload_fileobj(self, fileobj, bucket, key, multipart_chunksize=None, **kwargs):
    """Encrypt a file-like object and send it to S3 as a multipart upload.

    The stream always goes through the multipart path, and ownership
    stays with the caller: this method never closes ``fileobj``.

    Args:
        fileobj: A file-like object with a read() method.
        bucket: Target S3 bucket.
        key: Target S3 object key.
        multipart_chunksize: Size of each part (default 8MB).
        **kwargs: Additional arguments (e.g. EncryptionContext, Metadata).

    Returns:
        Whatever ``_multipart_upload_from_readable`` returns for the upload.

    Raises:
        S3EncryptionClientError: If multipart_chunksize is not positive
            or is below the S3 minimum part size.
    """
    part_size = multipart_chunksize
    if part_size is None:
        part_size = _DEFAULT_MULTIPART_CHUNKSIZE

    if part_size <= 0:
        raise S3EncryptionClientError("multipart_chunksize must be a positive integer.")
    if part_size < _MIN_MULTIPART_PART_SIZE:
        raise S3EncryptionClientError(
            f"multipart_chunksize must be at least {_MIN_MULTIPART_PART_SIZE} bytes (5 MB). "
            f"S3 requires all non-final parts to be at least 5 MB."
        )

    # owns_readable=False leaves closing the stream to the caller.
    return self._multipart_upload_from_readable(
        fileobj, bucket, key, part_size, owns_readable=False, **kwargs
    )

def _multipart_upload_from_readable(
    self, readable, bucket, key, chunksize, *, owns_readable=False, **kwargs
):
    """Perform an encrypted multipart upload from a readable source.

    The source is consumed in ``chunksize`` pieces with one chunk of
    read-ahead, so the last part can be flagged with ``IsLastPart=True``.
    Short reads are topped up until a full chunk is available or the
    stream ends, so only the final part can be smaller than ``chunksize``.
    An empty source is uploaded as a single empty final part.

    Args:
        readable: File-like object to read from.
        bucket: Target S3 bucket.
        key: Target S3 object key.
        chunksize: Size of each part in bytes.
        owns_readable: If True, close readable when done (also when the
            upload could not be created). If False, the caller is
            responsible for closing it.
        **kwargs: Additional S3 parameters forwarded to create_multipart_upload.

    Returns:
        The response from complete_multipart_upload.

    Raises:
        Exception: Whatever create/upload/complete raised. Once the upload
            has been created, it is aborted before the error propagates.
    """

    def read_chunk():
        """Read up to chunksize from the source, looping over short reads."""
        pieces = []
        remaining = chunksize
        while remaining > 0:
            piece = readable.read(remaining)
            if not piece:
                break
            pieces.append(piece)
            remaining -= len(piece)
        if not pieces:
            return b""
        if len(pieces) == 1:
            # Common case: one read filled the chunk, so avoid a copy.
            return pieces[0]
        # Join with an empty separator of the stream's own type (bytes or str).
        return pieces[0][:0].join(pieces)

    try:
        # EncryptionContext and Metadata travel with the remaining kwargs
        # (ACL, ContentType, Tagging, etc.) to create_multipart_upload, which
        # consumes EncryptionContext itself instead of sending it to S3.
        create_kwargs = {"Bucket": bucket, "Key": key}
        create_kwargs.update(kwargs)

        create_resp = self.create_multipart_upload(**create_kwargs)
        upload_id = create_resp["UploadId"]

        try:
            parts = []
            part_number = 0
            # Read ahead so we can detect the last chunk
            current_chunk = read_chunk()
            while True:
                next_chunk = read_chunk()
                part_number += 1
                is_last = not next_chunk
                resp = self.upload_part(
                    Bucket=bucket,
                    Key=key,
                    UploadId=upload_id,
                    PartNumber=part_number,
                    Body=current_chunk,
                    IsLastPart=is_last,
                )
                parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
                if is_last:
                    break
                current_chunk = next_chunk

            return self.complete_multipart_upload(
                Bucket=bucket,
                Key=key,
                UploadId=upload_id,
                MultipartUpload={"Parts": parts},
            )
        except Exception:
            # Abort so the failed upload is not left open in S3, then
            # re-raise the original error.
            self.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id)
            raise
    finally:
        # The outer finally also covers a failing create_multipart_upload,
        # so an owned source is never leaked.
        if owns_readable and hasattr(readable, "close"):
            readable.close()
Comment thread
kessplas marked this conversation as resolved.
Loading
Loading