diff --git a/decart/models.py b/decart/models.py index 2be8c64..2e59c47 100644 --- a/decart/models.py +++ b/decart/models.py @@ -1,7 +1,7 @@ -from typing import Literal, Optional +from typing import Literal, Optional, List from pydantic import BaseModel, Field, ConfigDict from .errors import ModelNotFoundError -from .types import FileInput +from .types import FileInput, MotionTrajectoryInput RealTimeModels = Literal["mirage", "mirage_v2", "lucy_v2v_720p_rt"] @@ -12,6 +12,7 @@ "lucy-pro-i2v", "lucy-pro-v2v", "lucy-pro-flf2v", + "lucy-motion", ] ImageModels = Literal["lucy-pro-t2i", "lucy-pro-i2i"] Model = Literal[RealTimeModels, VideoModels, ImageModels] @@ -61,6 +62,13 @@ class FirstLastFrameInput(DecartBaseModel): resolution: Optional[str] = None +class ImageToMotionVideoInput(DecartBaseModel): + data: FileInput + trajectory: List[MotionTrajectoryInput] = Field(..., min_length=2, max_length=121) + seed: Optional[int] = None + resolution: Optional[str] = None + + class TextToImageInput(BaseModel): prompt: str = Field(..., min_length=1) seed: Optional[int] = None @@ -152,6 +160,14 @@ class ImageToImageInput(DecartBaseModel): height=704, input_schema=FirstLastFrameInput, ), + "lucy-motion": ModelDefinition( + name="lucy-motion", + url_path="/v1/generate/lucy-motion", + fps=25, + width=1280, + height=704, + input_schema=ImageToMotionVideoInput, + ), }, "image": { "lucy-pro-t2i": ModelDefinition( diff --git a/decart/types.py b/decart/types.py index 31bdac4..a0e78cc 100644 --- a/decart/types.py +++ b/decart/types.py @@ -21,3 +21,9 @@ class Prompt(BaseModel): class ModelState(BaseModel): prompt: Optional[Prompt] = None mirror: bool = Field(default=False) + + +class MotionTrajectoryInput(BaseModel): + frame: int = Field(..., ge=0) + x: float = Field(..., ge=0) + y: float = Field(..., ge=0) diff --git a/tests/test_process.py b/tests/test_process.py index 383e1fa..fcd6e92 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -2,6 +2,7 @@ import asyncio from unittest.mock import AsyncMock, patch, MagicMock from decart import DecartClient, models, DecartSDKError +from decart.types import MotionTrajectoryInput @pytest.mark.asyncio @@ -86,6 +87,60 @@ async def test_process_video_to_video() -> None: assert result == b"fake video data" +@pytest.mark.asyncio +async def test_process_image_to_motion_video() -> None: + client = DecartClient(api_key="test-key") + + with patch("aiohttp.ClientSession") as mock_session_cls: + mock_response = MagicMock() + mock_response.ok = True + mock_response.read = AsyncMock(return_value=b"fake video data") + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock() + mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response) + mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None) + + mock_session_cls.return_value = mock_session + + result = await client.process( + { + "model": models.video("lucy-motion"), + "data": b"fake input image", + "trajectory": [ + MotionTrajectoryInput(frame=0, x=0, y=0), + MotionTrajectoryInput(frame=1, x=0.5, y=0.5), + MotionTrajectoryInput(frame=2, x=1, y=1), + MotionTrajectoryInput(frame=3, x=1.5, y=1.5), + MotionTrajectoryInput(frame=4, x=2, y=2), + ], + } + ) + + assert result == b"fake video data" + + +@pytest.mark.asyncio +async def test_process_image_to_motion_video_invalid_trajectory() -> None: + client = DecartClient(api_key="test-key") + + with pytest.raises(DecartSDKError) as exception: + await client.process( + { + "model": models.video("lucy-motion"), + "data": b"fake input image", + "trajectory": [ + MotionTrajectoryInput(frame=0, x=0, y=0), + ], + } + ) + assert "Invalid inputs for lucy-motion: 1 validation error for ImageToMotionVideoInput" in str( + exception + ) + + @pytest.mark.asyncio async def test_process_with_cancellation() -> None: client = DecartClient(api_key="test-key") diff --git a/uv.lock b/uv.lock index d84e211..a15cbf3 100644 --- a/uv.lock +++ b/uv.lock @@ -900,7 +900,7 @@ wheels = [ [[package]] name = "decart" -version = "0.0.5" +version = "0.0.7" source = { editable = "." } dependencies = [ { name = "aiofiles" },