Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions decart/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@
from .realtime import (
RealtimeClient,
SetInput,
SubscribeClient,
SubscribeOptions,
encode_subscribe_token,
decode_subscribe_token,
RealtimeConnectOptions,
ConnectionState,
AvatarOptions,
Expand All @@ -40,6 +44,10 @@
REALTIME_AVAILABLE = False
RealtimeClient = None # type: ignore
SetInput = None # type: ignore
SubscribeClient = None # type: ignore
SubscribeOptions = None # type: ignore
encode_subscribe_token = None # type: ignore
decode_subscribe_token = None # type: ignore
RealtimeConnectOptions = None # type: ignore
ConnectionState = None # type: ignore
AvatarOptions = None # type: ignore
Expand Down Expand Up @@ -79,6 +87,10 @@
[
"RealtimeClient",
"SetInput",
"SubscribeClient",
"SubscribeOptions",
"encode_subscribe_token",
"decode_subscribe_token",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
Expand Down
10 changes: 10 additions & 0 deletions decart/realtime/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from .client import RealtimeClient, SetInput
from .subscribe import (
SubscribeClient,
SubscribeOptions,
encode_subscribe_token,
decode_subscribe_token,
)
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions

__all__ = [
"RealtimeClient",
"SetInput",
"SubscribeClient",
"SubscribeOptions",
"encode_subscribe_token",
"decode_subscribe_token",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
Expand Down
172 changes: 143 additions & 29 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@
import asyncio
import base64
import logging
import uuid
from pathlib import Path
from urllib.parse import urlparse
from urllib.parse import urlparse, quote
import aiohttp
from aiortc import MediaStreamTrack
from pydantic import BaseModel

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage
from .messages import PromptMessage, SessionIdMessage
from .subscribe import (
SubscribeClient,
SubscribeOptions,
encode_subscribe_token,
decode_subscribe_token,
)
from .types import ConnectionState, RealtimeConnectOptions
from ..types import FileInput
from ..errors import DecartSDKError, InvalidInputError, WebRTCError
Expand Down Expand Up @@ -51,26 +56,41 @@ async def _image_to_base64(
image_bytes, _ = await file_input_to_bytes(image, http_session)
return base64.b64encode(image_bytes).decode("utf-8")

return image

image_bytes, _ = await file_input_to_bytes(image, http_session)
return base64.b64encode(image_bytes).decode("utf-8")
raise InvalidInputError(
"Invalid image input: string is not a data URI, URL, or valid file path"
)


class RealtimeClient:
def __init__(
self,
manager: WebRTCManager,
session_id: str,
http_session: Optional[aiohttp.ClientSession] = None,
is_avatar_live: bool = False,
):
self._manager = manager
self.session_id = session_id
self._http_session = http_session
self._is_avatar_live = is_avatar_live
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
self._session_id: Optional[str] = None
self._subscribe_token: Optional[str] = None
self._buffering = True
self._buffer: list[tuple[str, object]] = []

@property
def session_id(self) -> Optional[str]:
    """Return the stored session id.

    The value is whatever ``_session_id`` currently holds; it is assigned by
    ``_handle_session_id`` and may be ``None``.
    """
    return self._session_id

@property
def subscribe_token(self) -> Optional[str]:
    """Return the stored subscribe token.

    The value is whatever ``_subscribe_token`` currently holds; it is
    assigned by ``_handle_session_id`` and may be ``None``.
    """
    return self._subscribe_token

def _handle_session_id(self, msg: SessionIdMessage) -> None:
    """Store the session id from ``msg`` and derive the subscribe token.

    The token is produced by ``encode_subscribe_token`` from the message's
    session id, server IP and server port.
    """
    sid = msg.session_id
    self._session_id = sid
    self._subscribe_token = encode_subscribe_token(sid, msg.server_ip, msg.server_port)

@classmethod
async def connect(
Expand All @@ -81,20 +101,20 @@ async def connect(
options: RealtimeConnectOptions,
integration: Optional[str] = None,
) -> "RealtimeClient":
session_id = str(uuid.uuid4())
ws_url = f"{base_url}{options.model.url_path}"
ws_url += f"?api_key={api_key}&model={options.model.name}"
ws_url += f"?api_key={quote(api_key)}&model={quote(options.model.name)}"

is_avatar_live = options.model.name == "avatar-live"

config = WebRTCConfiguration(
webrtc_url=ws_url,
api_key=api_key,
session_id=session_id,
session_id="",
fps=options.model.fps,
on_remote_stream=options.on_remote_stream,
on_connection_state_change=None,
on_error=None,
on_session_id=None,
initial_state=options.initial_state,
customize_offer=options.customize_offer,
integration=integration,
Expand All @@ -107,13 +127,13 @@ async def connect(
manager = WebRTCManager(config)
client = cls(
manager=manager,
session_id=session_id,
http_session=http_session,
is_avatar_live=is_avatar_live,
)

config.on_connection_state_change = client._emit_connection_change
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
config.on_session_id = client._handle_session_id

try:
# For avatar-live, convert and send avatar image before WebRTC connection
Expand Down Expand Up @@ -143,28 +163,97 @@ async def connect(
if options.initial_state.prompt:
await client.set_prompt(
options.initial_state.prompt.text,
enrich=options.initial_state.prompt.enrich,
enhance=options.initial_state.prompt.enhance,
)
except Exception as e:
await manager.cleanup()
await http_session.close()
raise WebRTCError(str(e), cause=e)

client._flush()
return client

def _emit_connection_change(self, state: ConnectionState) -> None:
for callback in self._connection_callbacks:
@classmethod
async def subscribe(
    cls,
    base_url: str,
    api_key: str,
    options: SubscribeOptions,
    integration: Optional[str] = None,
) -> SubscribeClient:
    """Connect to an existing session described by a subscribe token.

    Decodes ``options.token``, builds the ``/subscribe/<sid>`` URL with the
    token's IP and port plus the API key as URL-quoted query parameters, and
    connects a ``WebRTCManager`` to it.

    Args:
        base_url: Prefix of the subscribe URL.
        api_key: Key sent as the ``api_key`` query parameter and stored in
            the WebRTC configuration.
        options: Supplies the subscribe ``token`` and the
            ``on_remote_stream`` callback.
        integration: Forwarded unchanged to ``WebRTCConfiguration``.

    Returns:
        A ``SubscribeClient`` wrapping the connected manager.

    Raises:
        WebRTCError: If ``manager.connect`` fails; the manager is cleaned up
            before the error is raised.
    """
    token = decode_subscribe_token(options.token)

    # Quote each component separately so the assembled URL stays well-formed.
    quoted_sid = quote(token.sid)
    quoted_ip = quote(token.ip)
    quoted_port = quote(str(token.port))
    quoted_key = quote(api_key)
    url = (
        f"{base_url}/subscribe/{quoted_sid}"
        f"?IP={quoted_ip}"
        f"&port={quoted_port}"
        f"&api_key={quoted_key}"
    )

    config = WebRTCConfiguration(
        webrtc_url=url,
        api_key=api_key,
        session_id=token.sid,
        fps=0,
        on_remote_stream=options.on_remote_stream,
        on_connection_state_change=None,
        on_error=None,
        integration=integration,
    )

    manager = WebRTCManager(config)
    sub_client = SubscribeClient(manager)

    # The callbacks need the client instance, so they are wired after it exists.
    config.on_connection_state_change = sub_client._emit_connection_change
    config.on_error = sub_client._emit_error

    try:
        await manager.connect(None)
    except Exception as e:
        await manager.cleanup()
        raise WebRTCError(str(e), cause=e)
    else:
        sub_client._flush()
        return sub_client

def _flush(self) -> None:
    """Schedule ``_do_flush`` on the running event loop.

    The flush is deferred with ``call_soon`` so the caller can register
    handlers before buffered events fire.
    """
    loop = asyncio.get_running_loop()
    loop.call_soon(self._do_flush)

def _do_flush(self) -> None:
    """Turn buffering off and replay the queued events in order.

    Each buffered entry is an ``(event, data)`` pair; ``"connection_change"``
    and ``"error"`` entries go to their dispatchers, any other event name is
    ignored. The buffer is emptied afterwards.
    """
    self._buffering = False
    for kind, payload in self._buffer:
        if kind == "error":
            self._dispatch_error(payload)  # type: ignore[arg-type]
        elif kind == "connection_change":
            self._dispatch_connection_change(payload)  # type: ignore[arg-type]
    self._buffer.clear()

def _dispatch_connection_change(self, state: ConnectionState) -> None:
    """Invoke every registered connection-change callback with ``state``.

    Iterates over a snapshot of the callback list, so callbacks may add or
    remove handlers while being dispatched. An exception raised by a callback
    is logged and does not stop the remaining callbacks.
    """
    snapshot = [*self._connection_callbacks]
    for handler in snapshot:
        try:
            handler(state)
        except Exception as exc:
            logger.exception(f"Error in connection_change callback: {exc}")

def _emit_error(self, error: DecartSDKError) -> None:
for callback in self._error_callbacks:
def _dispatch_error(self, error: DecartSDKError) -> None:
    """Invoke every registered error callback with ``error``.

    Iterates over a snapshot of the callback list, so callbacks may add or
    remove handlers while being dispatched. An exception raised by a callback
    is logged and does not stop the remaining callbacks.
    """
    snapshot = [*self._error_callbacks]
    for handler in snapshot:
        try:
            handler(error)
        except Exception as exc:
            logger.exception(f"Error in error callback: {exc}")

def _emit_connection_change(self, state: ConnectionState) -> None:
    """Deliver a connection-state change, or queue it while buffering is on."""
    if not self._buffering:
        self._dispatch_connection_change(state)
        return
    self._buffer.append(("connection_change", state))

def _emit_error(self, error: DecartSDKError) -> None:
    """Deliver an error to the callbacks, or queue it while buffering is on."""
    if not self._buffering:
        self._dispatch_error(error)
        return
    self._buffer.append(("error", error))

async def set(self, input: SetInput) -> None:
if input.prompt is None and input.image is None:
raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")
Expand All @@ -187,15 +276,29 @@ async def set(self, input: SetInput) -> None:
},
)

async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
async def set_prompt(
self,
prompt: str,
enhance: bool = True,
enrich: Optional[bool] = None,
) -> None:
if enrich is not None:
import warnings

warnings.warn(
"set_prompt(enrich=...) is deprecated, use set_prompt(enhance=...) instead",
DeprecationWarning,
stacklevel=2,
)
enhance = enrich
if not prompt or not prompt.strip():
raise InvalidInputError("Prompt cannot be empty")

event, result = self._manager.register_prompt_wait(prompt)

try:
await self._manager.send_message(
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich)
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enhance)
)

try:
Expand All @@ -208,17 +311,26 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
finally:
self._manager.unregister_prompt_wait(prompt)

async def set_image(self, image: FileInput) -> None:
if not self._is_avatar_live:
raise InvalidInputError("set_image() is only available for avatar-live model")

if not self._http_session:
raise InvalidInputError("HTTP session not available")
async def set_image(
self,
image: Optional[FileInput],
prompt: Optional[str] = None,
enhance: bool = True,
timeout: float = UPDATE_TIMEOUT_S,
) -> None:
image_base64: Optional[str] = None
if image is not None:
if not self._http_session:
raise InvalidInputError("HTTP session not available")
image_bytes, _ = await file_input_to_bytes(image, self._http_session)
image_base64 = base64.b64encode(image_bytes).decode("utf-8")

image_bytes, _ = await file_input_to_bytes(image, self._http_session)
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
opts: dict = {"timeout": timeout}
if prompt is not None:
opts["prompt"] = prompt
opts["enhance"] = enhance

await self._manager.set_image(image_base64)
await self._manager.set_image(image_base64, opts)

def is_connected(self) -> bool:
    """Return whatever the underlying manager's ``is_connected()`` reports."""
    connected = self._manager.is_connected()
    return connected
Expand All @@ -227,6 +339,8 @@ def get_connection_state(self) -> ConnectionState:
return self._manager.get_connection_state()

async def disconnect(self) -> None:
    """Tear down the connection and release the HTTP session.

    Buffering is switched off and any still-queued events are discarded
    before the manager is cleaned up. The HTTP session is closed in a
    ``finally`` clause so it is released even when ``manager.cleanup()``
    raises; that exception still propagates to the caller.
    """
    self._buffering = False
    self._buffer.clear()
    try:
        await self._manager.cleanup()
    finally:
        # Close the session even if cleanup failed, so it is not leaked.
        session = self._http_session
        if session and not session.closed:
            await session.close()
Expand Down
7 changes: 7 additions & 0 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class IceRestartMessage(BaseModel):
turn_config: TurnConfig


class GenerationStartedMessage(BaseModel):
"""Server signals that generation has started."""

# Fixed tag for this variant; the IncomingMessage union discriminates on the "type" field.
type: Literal["generation_started"]


# Discriminated union for incoming messages
IncomingMessage = Annotated[
Union[
Expand All @@ -98,6 +104,7 @@ class IceRestartMessage(BaseModel):
ErrorMessage,
ReadyMessage,
IceRestartMessage,
GenerationStartedMessage,
],
Field(discriminator="type"),
]
Expand Down
Loading
Loading