Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions decart/_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""User-Agent header construction for SDK requests."""

from typing import Optional
from ._version import __version__


def build_user_agent(integration: Optional[str] = None) -> str:
"""
Builds the User-Agent string for the SDK.

Format: decart-python-sdk/{version} lang/py {integration?}

Args:
integration: Optional integration identifier (e.g., "langchain/0.1.0")

Returns:
Complete User-Agent string

Examples:
>>> build_user_agent()
'decart-python-sdk/0.0.6 lang/py'

>>> build_user_agent("langchain/0.1.0")
'decart-python-sdk/0.0.6 lang/py langchain/0.1.0'
"""
parts = [f"decart-python-sdk/{__version__}", "lang/py"]

if integration:
parts.append(integration)

return " ".join(parts)
9 changes: 9 additions & 0 deletions decart/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""SDK version information."""

from importlib.metadata import version, PackageNotFoundError

try:
__version__ = version("decart")
except PackageNotFoundError:
# Development version when package is not installed
__version__ = "0.0.0-dev"
10 changes: 9 additions & 1 deletion decart/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class DecartClient:
Args:
api_key: Your Decart API key
base_url: API base URL (defaults to production)
integration: Optional integration identifier (e.g., "langchain/0.1.0")

Example:
```python
Expand All @@ -32,7 +33,12 @@ class DecartClient:
```
"""

def __init__(self, api_key: str, base_url: str = "https://api.decart.ai") -> None:
def __init__(
self,
api_key: str,
base_url: str = "https://api.decart.ai",
integration: Optional[str] = None,
) -> None:
if not api_key or not api_key.strip():
raise InvalidAPIKeyError()

Expand All @@ -41,6 +47,7 @@ def __init__(self, api_key: str, base_url: str = "https://api.decart.ai") -> Non

self.api_key = api_key
self.base_url = base_url
self.integration = integration
self._session: Optional[aiohttp.ClientSession] = None

async def _get_session(self) -> aiohttp.ClientSession:
Expand Down Expand Up @@ -117,6 +124,7 @@ async def process(self, options: dict[str, Any]) -> bytes:
model=model,
inputs=processed_inputs,
cancel_token=cancel_token,
integration=self.integration,
)

return response
7 changes: 6 additions & 1 deletion decart/process/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ..types import FileInput
from ..models import ModelDefinition
from ..errors import InvalidInputError, ProcessingError
from .._user_agent import build_user_agent


async def file_input_to_bytes(
Expand Down Expand Up @@ -82,6 +83,7 @@ async def send_request(
model: ModelDefinition,
inputs: dict[str, Any],
cancel_token: Optional[asyncio.Event] = None,
integration: Optional[str] = None,
) -> bytes:
form_data = aiohttp.FormData()

Expand All @@ -98,7 +100,10 @@ async def send_request(
async def make_request() -> bytes:
async with session.post(
endpoint,
headers={"X-API-KEY": api_key},
headers={
"X-API-KEY": api_key,
"User-Agent": build_user_agent(integration),
},
data=form_data,
) as response:
if not response.ok:
Expand Down
4 changes: 3 additions & 1 deletion decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable
from typing import Callable, Optional
import logging
import uuid
from aiortc import MediaStreamTrack
Expand All @@ -25,6 +25,7 @@ async def connect(
api_key: str,
local_track: MediaStreamTrack,
options: RealtimeConnectOptions,
integration: Optional[str] = None,
) -> "RealtimeClient":
session_id = str(uuid.uuid4())
ws_url = f"{base_url}{options.model.url_path}"
Expand All @@ -40,6 +41,7 @@ async def connect(
on_error=None,
initial_state=options.initial_state,
customize_offer=options.customize_offer,
integration=integration,
)

manager = WebRTCManager(config)
Expand Down
15 changes: 14 additions & 1 deletion decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
from typing import Optional, Callable
from urllib.parse import quote
import aiohttp
from aiortc import (
RTCPeerConnection,
Expand All @@ -13,6 +14,7 @@
)

from ..errors import WebRTCError
from .._user_agent import build_user_agent
from .messages import (
parse_incoming_message,
message_to_json,
Expand Down Expand Up @@ -46,12 +48,23 @@ def __init__(
self._ws_task: Optional[asyncio.Task] = None
self._ice_candidates_queue: list[RTCIceCandidate] = []

async def connect(self, url: str, local_track: MediaStreamTrack, timeout: float = 30) -> None:
async def connect(
self,
url: str,
local_track: MediaStreamTrack,
timeout: float = 30,
integration: Optional[str] = None,
) -> None:
try:
await self._set_state("connecting")

ws_url = url.replace("https://", "wss://").replace("http://", "ws://")

# Add user agent as query parameter (browsers don't support WS headers)
user_agent = build_user_agent(integration)
separator = "&" if "?" in ws_url else "?"
ws_url = f"{ws_url}{separator}user_agent={quote(user_agent)}"

self._session = aiohttp.ClientSession()
self._ws = await self._session.ws_connect(ws_url)

Expand Down
2 changes: 2 additions & 0 deletions decart/realtime/webrtc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class WebRTCConfiguration:
on_error: Optional[Callable[[Exception], None]] = None
initial_state: Optional[ModelState] = None
customize_offer: Optional[Callable] = None
integration: Optional[str] = None


def _is_retryable_error(exception: Exception) -> bool:
Expand All @@ -55,6 +56,7 @@ async def connect(self, local_track: MediaStreamTrack) -> bool:
await self._connection.connect(
url=self._config.webrtc_url,
local_track=local_track,
integration=self._config.integration,
)
return True
except Exception as e:
Expand Down
74 changes: 74 additions & 0 deletions tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,77 @@ async def test_process_with_cancellation() -> None:
"cancel_token": cancel_token,
}
)


@pytest.mark.asyncio
async def test_process_includes_user_agent_header() -> None:
"""Test that User-Agent header is included in requests."""
client = DecartClient(api_key="test-key")

with patch("aiohttp.ClientSession") as mock_session_cls:
mock_response = MagicMock()
mock_response.ok = True
mock_response.read = AsyncMock(return_value=b"fake video data")

mock_session = MagicMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
mock_session.post = MagicMock()
mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None)

mock_session_cls.return_value = mock_session

await client.process(
{
"model": models.video("lucy-pro-t2v"),
"prompt": "Test prompt",
}
)

# Verify post was called with User-Agent header
mock_session.post.assert_called_once()
call_kwargs = mock_session.post.call_args[1]
headers = call_kwargs.get("headers", {})

assert "User-Agent" in headers
assert headers["User-Agent"].startswith("decart-python-sdk/")
assert "lang/py" in headers["User-Agent"]


@pytest.mark.asyncio
async def test_process_includes_integration_in_user_agent() -> None:
"""Test that integration parameter is included in User-Agent header."""
client = DecartClient(api_key="test-key", integration="langchain/0.1.0")

with patch("aiohttp.ClientSession") as mock_session_cls:
mock_response = MagicMock()
mock_response.ok = True
mock_response.read = AsyncMock(return_value=b"fake video data")

mock_session = MagicMock()
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
mock_session.__aexit__ = AsyncMock(return_value=None)
mock_session.post = MagicMock()
mock_session.post.return_value.__aenter__ = AsyncMock(return_value=mock_response)
mock_session.post.return_value.__aexit__ = AsyncMock(return_value=None)

mock_session_cls.return_value = mock_session

await client.process(
{
"model": models.video("lucy-pro-t2v"),
"prompt": "Test prompt",
}
)

# Verify post was called with User-Agent header including integration
mock_session.post.assert_called_once()
call_kwargs = mock_session.post.call_args[1]
headers = call_kwargs.get("headers", {})

assert "User-Agent" in headers
assert headers["User-Agent"].startswith("decart-python-sdk/")
assert "lang/py" in headers["User-Agent"]
assert "langchain/0.1.0" in headers["User-Agent"]
assert headers["User-Agent"].endswith(" langchain/0.1.0")
27 changes: 27 additions & 0 deletions tests/test_user_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""Tests for User-Agent header construction."""

from decart._user_agent import build_user_agent
from decart._version import __version__


def test_build_user_agent_without_integration():
"""Test User-Agent without integration parameter."""
user_agent = build_user_agent()

assert user_agent == f"decart-python-sdk/{__version__} lang/py"
assert user_agent.startswith("decart-python-sdk/")
assert "lang/py" in user_agent


def test_build_user_agent_with_integration():
"""Test User-Agent with integration parameter."""
user_agent = build_user_agent("langchain/0.1.0")

expected = f"decart-python-sdk/{__version__} lang/py langchain/0.1.0"
assert user_agent == expected

parts = user_agent.split(" ")
assert len(parts) == 3
assert parts[0].startswith("decart-python-sdk/")
assert parts[1] == "lang/py"
assert parts[2] == "langchain/0.1.0"
Loading