diff --git a/decart/__init__.py b/decart/__init__.py index e6ec580..a9416d4 100644 --- a/decart/__init__.py +++ b/decart/__init__.py @@ -11,6 +11,8 @@ QueueStatusError, QueueResultError, TokenCreateError, + FileTooLargeError, + MAX_FILE_SIZE, ) from .models import models, ModelDefinition, VideoRestyleInput from .types import FileInput, ModelState, Prompt @@ -66,6 +68,8 @@ "QueueSubmitError", "QueueStatusError", "QueueResultError", + "FileTooLargeError", + "MAX_FILE_SIZE", "models", "ModelDefinition", "VideoRestyleInput", diff --git a/decart/client.py b/decart/client.py index 247976a..dd9a5e8 100644 --- a/decart/client.py +++ b/decart/client.py @@ -169,7 +169,7 @@ async def process(self, options: dict[str, Any]) -> bytes: inputs = {k: v for k, v in options.items() if k not in ("model", "cancel_token")} # File fields that need special handling (not validated by Pydantic) - FILE_FIELDS = {"data", "start", "end"} + FILE_FIELDS = {"data", "start", "end", "reference_image"} # Separate file inputs from regular inputs file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS} diff --git a/decart/errors.py b/decart/errors.py index 258a5e5..597923e 100644 --- a/decart/errors.py +++ b/decart/errors.py @@ -1,5 +1,8 @@ from typing import Any, Optional +# Maximum file size allowed for uploads (20MB) +MAX_FILE_SIZE = 20 * 1024 * 1024 + class DecartSDKError(Exception): """Base exception for all Decart SDK errors.""" @@ -88,3 +91,20 @@ class TokenCreateError(DecartSDKError): """Raised when token creation fails.""" pass + + +class FileTooLargeError(DecartSDKError): + """Raised when a file exceeds the maximum allowed size.""" + + def __init__(self, file_size: int, max_size: int, field_name: str | None = None) -> None: + file_size_mb = f"{file_size / (1024 * 1024):.1f}" + max_size_mb = f"{max_size / (1024 * 1024):.0f}" + field = f" for field '{field_name}'" if field_name else "" + super().__init__( + f"File size{field} ({file_size_mb}MB) exceeds the maximum allowed size of {max_size_mb}MB. " + f"Please reduce the file size or resolution before uploading.", + data={"file_size": file_size, "max_size": max_size, "field_name": field_name}, + ) + self.file_size = file_size + self.max_size = max_size + self.field_name = field_name diff --git a/decart/models.py b/decart/models.py index 630d1e2..14bc413 100644 --- a/decart/models.py +++ b/decart/models.py @@ -32,6 +32,7 @@ class ModelDefinition(DecartBaseModel, Generic[ModelT]): fps: int = Field(ge=1) width: int = Field(ge=1) height: int = Field(ge=1) + max_file_size: Optional[int] = None input_schema: type[BaseModel] @@ -255,6 +256,7 @@ class ImageToImageInput(DecartBaseModel): fps=25, width=1280, height=704, + max_file_size=100 * 1024 * 1024, input_schema=VideoRestyleInput, ), }, diff --git a/decart/process/request.py b/decart/process/request.py index 9f6b132..cf68ed4 100644 --- a/decart/process/request.py +++ b/decart/process/request.py @@ -5,7 +5,7 @@ from typing import Any, Optional from ..types import FileInput from ..models import ModelDefinition -from ..errors import InvalidInputError, ProcessingError +from ..errors import InvalidInputError, ProcessingError, FileTooLargeError, MAX_FILE_SIZE from .._user_agent import build_user_agent @@ -91,6 +91,9 @@ async def send_request( if value is not None: if key in ("data", "start", "end"): content, content_type = await file_input_to_bytes(value, session) + limit = model.max_file_size or MAX_FILE_SIZE + if len(content) > limit: + raise FileTooLargeError(len(content), limit, key) form_data.add_field(key, content, content_type=content_type) else: form_data.add_field(key, str(value)) diff --git a/decart/queue/client.py b/decart/queue/client.py index c06e0ec..cdba1e5 100644 --- a/decart/queue/client.py +++ b/decart/queue/client.py @@ -93,7 +93,7 @@ async def submit(self, options: dict[str, Any]) -> JobSubmitResponse: inputs = {k: v for k, v in options.items() if k not in ("model", "cancel_token")} # File fields that need special handling - FILE_FIELDS = {"data", "start", "end"} + FILE_FIELDS = {"data", "start", "end", "reference_image"} # Separate file inputs from regular inputs file_inputs = {k: v for k, v in inputs.items() if k in FILE_FIELDS} diff --git a/decart/queue/request.py b/decart/queue/request.py index c242b62..a6f4642 100644 --- a/decart/queue/request.py +++ b/decart/queue/request.py @@ -2,7 +2,13 @@ from typing import Any, Optional from ..models import ModelDefinition -from ..errors import QueueSubmitError, QueueStatusError, QueueResultError +from ..errors import ( + QueueSubmitError, + QueueStatusError, + QueueResultError, + FileTooLargeError, + MAX_FILE_SIZE, +) from .._user_agent import build_user_agent from ..process.request import file_input_to_bytes from .types import JobSubmitResponse, JobStatusResponse @@ -26,6 +32,9 @@ async def submit_job( if value is not None: if key in ("data", "start", "end", "reference_image"): content, content_type = await file_input_to_bytes(value, session) + limit = model.max_file_size or MAX_FILE_SIZE + if len(content) > limit: + raise FileTooLargeError(len(content), limit, key) form_data.add_field(key, content, content_type=content_type) else: form_data.add_field(key, str(value)) diff --git a/tests/test_process.py b/tests/test_process.py index e479378..a0ba2ad 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -219,3 +219,52 @@ async def test_process_includes_integration_in_user_agent() -> None: assert "lang/py" in headers["User-Agent"] assert "langchain/0.1.0" in headers["User-Agent"] assert headers["User-Agent"].endswith(" langchain/0.1.0") + + +@pytest.mark.asyncio +async def test_process_rejects_file_exceeding_20mb() -> None: + """Test that files exceeding 20MB are rejected before upload.""" + client = DecartClient(api_key="test-key") + + large_data = b"x" * (21 * 1024 * 1024) + + with pytest.raises(DecartSDKError, match="exceeds the maximum allowed size of 20MB"): + await client.process( + { + "model": models.image("lucy-pro-i2i"), + "prompt": "test", + "data": large_data, + } + ) + + +@pytest.mark.asyncio +async def test_process_accepts_file_at_20mb_limit() -> None: + """Test that files at exactly 20MB are accepted.""" + client = DecartClient(api_key="test-key") + + exact_data = b"x" * (20 * 1024 * 1024) + + with patch("aiohttp.ClientSession") as mock_session_cls: + mock_response = MagicMock() + mock_response.ok = True + mock_response.read = AsyncMock(return_value=b"fake image data") + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock() + mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response) + mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_session_cls.return_value = mock_session + + result = await client.process( + { + "model": models.image("lucy-pro-i2i"), + "prompt": "test", + "data": exact_data, + } + ) + + assert result == b"fake image data" diff --git a/tests/test_queue.py b/tests/test_queue.py index 3d287ee..d9e4798 100644 --- a/tests/test_queue.py +++ b/tests/test_queue.py @@ -121,7 +121,6 @@ async def test_queue_submit_and_poll_completed() -> None: patch("decart.queue.client.get_job_content") as mock_content, patch("asyncio.sleep", new_callable=AsyncMock), ): - mock_submit.return_value = MagicMock(job_id="job-123", status="pending") mock_status.return_value = MagicMock(job_id="job-123", status="completed") mock_content.return_value = b"fake video data" @@ -147,7 +146,6 @@ async def test_queue_submit_and_poll_failed() -> None: patch("decart.queue.client.get_job_status") as mock_status, patch("asyncio.sleep", new_callable=AsyncMock), ): - mock_submit.return_value = MagicMock(job_id="job-123", status="pending") mock_status.return_value = MagicMock(job_id="job-123", status="failed") @@ -177,7 +175,6 @@ def on_status_change(job): patch("decart.queue.client.get_job_content") as mock_content, patch("asyncio.sleep", new_callable=AsyncMock), ): - mock_submit.return_value = MagicMock(job_id="job-123", status="pending") mock_status.side_effect = [ MagicMock(job_id="job-123", status="processing"), @@ -359,3 +356,59 @@ async def test_queue_restyle_rejects_enhance_prompt_with_reference_image() -> No ) assert "enhance_prompt" in str(exc_info.value).lower() + + +@pytest.mark.asyncio +async def test_queue_rejects_file_exceeding_20mb() -> None: + """Test that files exceeding 20MB are rejected before upload.""" + client = DecartClient(api_key="test-key") + + large_data = b"x" * (21 * 1024 * 1024) + + with pytest.raises(DecartSDKError, match="exceeds the maximum allowed size of 20MB"): + await client.queue.submit( + { + "model": models.video("lucy-pro-v2v"), + "prompt": "test", + "data": large_data, + } + ) + + +@pytest.mark.asyncio +async def test_queue_accepts_over_20mb_for_restyle() -> None: + """Test that lucy-restyle-v2v accepts files over 20MB (up to 100MB).""" + client = DecartClient(api_key="test-key") + + data_50mb = b"x" * (50 * 1024 * 1024) + + with patch("decart.queue.client.submit_job") as mock_submit: + mock_submit.return_value = MagicMock(job_id="job-restyle", status="pending") + + job = await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Restyle this", + "data": data_50mb, + } + ) + + assert job.job_id == "job-restyle" + mock_submit.assert_called_once() + + +@pytest.mark.asyncio +async def test_queue_rejects_file_exceeding_100mb_for_restyle() -> None: + """Test that lucy-restyle-v2v rejects files over 100MB.""" + client = DecartClient(api_key="test-key") + + data_101mb = b"x" * (101 * 1024 * 1024) + + with pytest.raises(DecartSDKError, match="exceeds the maximum allowed size of 100MB"): + await client.queue.submit( + { + "model": models.video("lucy-restyle-v2v"), + "prompt": "Restyle this", + "data": data_101mb, + } + )