Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions decart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
try:
from .realtime import (
RealtimeClient,
SetInput,
RealtimeConnectOptions,
ConnectionState,
AvatarOptions,
Expand All @@ -38,6 +39,7 @@
except ImportError:
REALTIME_AVAILABLE = False
RealtimeClient = None # type: ignore
SetInput = None # type: ignore
RealtimeConnectOptions = None # type: ignore
ConnectionState = None # type: ignore
AvatarOptions = None # type: ignore
Expand Down Expand Up @@ -76,6 +78,7 @@
__all__.extend(
[
"RealtimeClient",
"SetInput",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
Expand Down
3 changes: 2 additions & 1 deletion decart/realtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .client import RealtimeClient
from .client import RealtimeClient, SetInput
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions

__all__ = [
"RealtimeClient",
"SetInput",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
Expand Down
98 changes: 67 additions & 31 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,61 @@
from typing import Callable, Optional
from typing import Callable, Optional, Union
import asyncio
import base64
import logging
import uuid
from pathlib import Path
from urllib.parse import urlparse
import aiohttp
from aiortc import MediaStreamTrack
from pydantic import BaseModel

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SetAvatarImageMessage
from .messages import PromptMessage
from .types import ConnectionState, RealtimeConnectOptions
from ..types import FileInput
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
from ..process.request import file_input_to_bytes

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Seconds set_prompt() waits for the prompt acknowledgment before raising DecartSDKError.
PROMPT_TIMEOUT_S = 15.0
# Timeout in seconds that set() passes to the manager's set_image() call.
UPDATE_TIMEOUT_S = 30.0


class SetInput(BaseModel):
    """Payload accepted by ``RealtimeClient.set()``.

    ``set()`` raises ``InvalidInputError`` when both ``prompt`` and ``image``
    are ``None``, so at least one of them has to be supplied.
    """

    # Text prompt; set() rejects a value that is empty after stripping whitespace.
    prompt: Optional[str] = None
    # Forwarded by set() to the manager under the "enhance" option.
    enhance: bool = True
    # Raw bytes or a string; set() converts it to base64 with _image_to_base64().
    image: Union[bytes, str, None] = None


async def _image_to_base64(
    image: Union[bytes, str],
    http_session: "aiohttp.ClientSession",
) -> str:
    """Return *image* as a base64 string.

    Handling depends on the form of *image*:

    * ``bytes``: encoded directly.
    * ``data:`` URI: the payload after the first comma is returned unchanged.
    * ``http``/``https`` URL: downloaded through *http_session*, then encoded.
    * path of an existing file: read via ``file_input_to_bytes``, then encoded.
    * any other string: returned as-is (treated as already-encoded data; it is
      not validated here).
    * any other object: handed to ``file_input_to_bytes``, then encoded.

    Args:
        image: The image in one of the forms listed above.
        http_session: Session used for URL downloads and passed on to
            ``file_input_to_bytes``.

    Returns:
        The base64 text for the image.

    Raises:
        InvalidInputError: If a ``data:`` URI has no comma separating the
            header from the payload.

    Errors from ``resp.raise_for_status()`` and ``file_input_to_bytes`` propagate
    unchanged.
    """
    if isinstance(image, bytes):
        return base64.b64encode(image).decode("utf-8")

    if isinstance(image, str):
        scheme = urlparse(image).scheme

        if scheme == "data":
            _, separator, payload = image.partition(",")
            if not separator:
                # Previously surfaced as a bare IndexError from split(...)[1].
                raise InvalidInputError("Invalid data URI: missing ',' before the payload")
            return payload

        if scheme in ("http", "https"):
            async with http_session.get(image) as resp:
                resp.raise_for_status()
                data = await resp.read()
            return base64.b64encode(data).decode("utf-8")

        try:
            is_existing_path = Path(image).exists()
        except OSError:
            # A long raw base64 string is not a usable path: stat() can fail
            # with ENAMETOOLONG, which Path.exists() does not swallow on every
            # Python version. Treat that as "not a file".
            is_existing_path = False

        if not is_existing_path:
            return image

    # Existing file path, or a non-bytes/non-str input: let the shared helper read it.
    image_bytes, _ = await file_input_to_bytes(image, http_session)
    return base64.b64encode(image_bytes).decode("utf-8")


class RealtimeClient:
def __init__(
Expand Down Expand Up @@ -124,6 +165,28 @@ def _emit_error(self, error: DecartSDKError) -> None:
except Exception as e:
logger.exception(f"Error in error callback: {e}")

async def set(self, input: SetInput) -> None:
    """Validate *input* and hand the prompt and/or image to the manager.

    Args:
        input: Carries an optional prompt, an optional image and the
            ``enhance`` flag. At least one of prompt or image is required.

    Raises:
        InvalidInputError: If neither prompt nor image is given, if the prompt
            is blank after stripping, or if an image is given while no HTTP
            session is available.
    """
    nothing_supplied = input.prompt is None and input.image is None
    if nothing_supplied:
        raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")

    prompt_is_blank = input.prompt is not None and not input.prompt.strip()
    if prompt_is_blank:
        raise InvalidInputError("Prompt cannot be empty")

    if input.image is None:
        encoded_image: Optional[str] = None
    else:
        # Image conversion may need to download or read a file, hence the session check.
        if not self._http_session:
            raise InvalidInputError("HTTP session not available")
        encoded_image = await _image_to_base64(input.image, self._http_session)

    send_options = {
        "prompt": input.prompt,
        "enhance": input.enhance,
        "timeout": UPDATE_TIMEOUT_S,
    }
    await self._manager.set_image(encoded_image, send_options)

async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
if not prompt or not prompt.strip():
raise InvalidInputError("Prompt cannot be empty")
Expand All @@ -136,7 +199,7 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
)

try:
await asyncio.wait_for(event.wait(), timeout=15.0)
await asyncio.wait_for(event.wait(), timeout=PROMPT_TIMEOUT_S)
except asyncio.TimeoutError:
raise DecartSDKError("Prompt acknowledgment timed out")

Expand All @@ -146,43 +209,16 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
self._manager.unregister_prompt_wait(prompt)

async def set_image(self, image: FileInput) -> None:
    """Set or update the avatar image.

    Only available for avatar-live model.

    Args:
        image: The image to set. Can be bytes, Path, URL string, or file-like object.

    Raises:
        InvalidInputError: If not using avatar-live model or no HTTP session is available.
        DecartSDKError: If the manager's ``set_image`` reports a timeout or a failure.
    """
    if not self._is_avatar_live:
        raise InvalidInputError("set_image() is only available for avatar-live model")

    if not self._http_session:
        raise InvalidInputError("HTTP session not available")

    # Convert image to base64
    image_bytes, _ = await file_input_to_bytes(image, self._http_session)
    image_base64 = base64.b64encode(image_bytes).decode("utf-8")

    # Delegate the send / wait-for-result / cleanup cycle to the manager only.
    # Running the inline SetAvatarImageMessage flow here as well would transmit
    # the image twice.
    await self._manager.set_image(image_base64)

def is_connected(self) -> bool:
    """Return the connection status reported by the manager."""
    manager = self._manager
    return manager.is_connected()
Expand Down
6 changes: 4 additions & 2 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ class SetAvatarImageMessage(BaseModel):
"""Set avatar image message."""

type: Literal["set_image"]
image_data: str # Base64-encoded image
image_data: Optional[str] = None
prompt: Optional[str] = None
enhance_prompt: Optional[bool] = None


# Outgoing message union (no discriminator needed - we know what we're sending)
Expand Down Expand Up @@ -161,4 +163,4 @@ def message_to_json(message: OutgoingMessage) -> str:
Returns:
JSON string
"""
return message.model_dump_json()
return message.model_dump_json(exclude_none=True)
4 changes: 2 additions & 2 deletions decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async def connect(
self,
url: str,
local_track: Optional[MediaStreamTrack],
timeout: float = 30,
timeout: float,
integration: Optional[str] = None,
is_avatar_live: bool = False,
avatar_image_base64: Optional[str] = None,
Expand Down Expand Up @@ -107,7 +107,7 @@ async def connect(
self._on_error(e)
raise WebRTCError(str(e), cause=e)

async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 15.0) -> None:
async def _send_avatar_image_and_wait(self, image_base64: str, timeout: float = 30.0) -> None:
"""Send avatar image and wait for acknowledgment."""
event, result = self.register_image_set_wait()

Expand Down
40 changes: 40 additions & 0 deletions decart/realtime/webrtc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ async def connect(
initial_prompt: Optional[dict] = None,
) -> bool:
try:
timeout = 60 * 5 # 5 minutes
await self._connection.connect(
url=self._config.webrtc_url,
local_track=local_track,
timeout=timeout,
integration=self._config.integration,
is_avatar_live=self._config.is_avatar_live,
avatar_image_base64=avatar_image_base64,
Expand All @@ -83,6 +85,44 @@ def _create_connection(self) -> WebRTCConnection:
customize_offer=self._config.customize_offer,
)

async def set_image(
    self,
    image_base64: Optional[str],
    options: Optional[dict] = None,
) -> None:
    """Send a ``set_image`` message and wait for its result.

    Args:
        image_base64: Base64 image payload, or ``None`` to send the message
            without image data.
        options: Optional mapping. Recognised keys are ``"prompt"``,
            ``"enhance"`` (sent as ``enhance_prompt``) and ``"timeout"``
            (seconds to wait for the result; defaults to 30.0).

    Raises:
        DecartSDKError: If no result arrives within the timeout, or the result
            reports failure.
    """
    # Kept function-local, as before, but imported once instead of once per
    # failure branch.
    from ..errors import DecartSDKError
    from .messages import SetAvatarImageMessage

    opts = options or {}
    timeout = opts.get("timeout", 30.0)

    event, result = self._connection.register_image_set_wait()

    try:
        message = SetAvatarImageMessage(
            type="set_image",
            image_data=image_base64,
        )
        # Only set the optional fields the caller actually provided.
        if opts.get("prompt") is not None:
            message.prompt = opts["prompt"]
        if opts.get("enhance") is not None:
            message.enhance_prompt = opts["enhance"]

        await self._connection.send(message)

        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError as exc:
            # Chain explicitly so the SDK error records the timeout as its cause.
            raise DecartSDKError("Image send timed out") from exc

        if not result["success"]:
            raise DecartSDKError(result.get("error") or "Failed to set image")
    finally:
        # Always drop the wait registration, whether the send succeeded or not.
        self._connection.unregister_image_set_wait()

async def send_message(self, message: OutgoingMessage) -> None:
    """Forward *message* to the current connection's ``send`` coroutine."""
    connection = self._connection
    await connection.send(message)

Expand Down
Loading
Loading