diff --git a/decart/__init__.py b/decart/__init__.py index f4e7260..e6ec580 100644 --- a/decart/__init__.py +++ b/decart/__init__.py @@ -30,6 +30,10 @@ from .realtime import ( RealtimeClient, SetInput, + SubscribeClient, + SubscribeOptions, + encode_subscribe_token, + decode_subscribe_token, RealtimeConnectOptions, ConnectionState, AvatarOptions, @@ -40,6 +44,10 @@ REALTIME_AVAILABLE = False RealtimeClient = None # type: ignore SetInput = None # type: ignore + SubscribeClient = None # type: ignore + SubscribeOptions = None # type: ignore + encode_subscribe_token = None # type: ignore + decode_subscribe_token = None # type: ignore RealtimeConnectOptions = None # type: ignore ConnectionState = None # type: ignore AvatarOptions = None # type: ignore @@ -79,6 +87,10 @@ [ "RealtimeClient", "SetInput", + "SubscribeClient", + "SubscribeOptions", + "encode_subscribe_token", + "decode_subscribe_token", "RealtimeConnectOptions", "ConnectionState", "AvatarOptions", diff --git a/decart/realtime/__init__.py b/decart/realtime/__init__.py index 0a98cab..7488eb6 100644 --- a/decart/realtime/__init__.py +++ b/decart/realtime/__init__.py @@ -1,9 +1,19 @@ from .client import RealtimeClient, SetInput +from .subscribe import ( + SubscribeClient, + SubscribeOptions, + encode_subscribe_token, + decode_subscribe_token, +) from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions __all__ = [ "RealtimeClient", "SetInput", + "SubscribeClient", + "SubscribeOptions", + "encode_subscribe_token", + "decode_subscribe_token", "RealtimeConnectOptions", "ConnectionState", "AvatarOptions", diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 621d8d8..3f338a8 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -2,15 +2,20 @@ import asyncio import base64 import logging -import uuid from pathlib import Path -from urllib.parse import urlparse +from urllib.parse import urlparse, quote import aiohttp from aiortc import MediaStreamTrack from pydantic import BaseModel from .webrtc_manager import WebRTCManager, WebRTCConfiguration -from .messages import PromptMessage +from .messages import PromptMessage, SessionIdMessage +from .subscribe import ( + SubscribeClient, + SubscribeOptions, + encode_subscribe_token, + decode_subscribe_token, +) from .types import ConnectionState, RealtimeConnectOptions from ..types import FileInput from ..errors import DecartSDKError, InvalidInputError, WebRTCError @@ -51,26 +56,41 @@ async def _image_to_base64( image_bytes, _ = await file_input_to_bytes(image, http_session) return base64.b64encode(image_bytes).decode("utf-8") - return image - - image_bytes, _ = await file_input_to_bytes(image, http_session) - return base64.b64encode(image_bytes).decode("utf-8") + raise InvalidInputError( + "Invalid image input: string is not a data URI, URL, or valid file path" + ) class RealtimeClient: def __init__( self, manager: WebRTCManager, - session_id: str, http_session: Optional[aiohttp.ClientSession] = None, is_avatar_live: bool = False, ): self._manager = manager - self.session_id = session_id self._http_session = http_session self._is_avatar_live = is_avatar_live self._connection_callbacks: list[Callable[[ConnectionState], None]] = [] self._error_callbacks: list[Callable[[DecartSDKError], None]] = [] + self._session_id: Optional[str] = None + self._subscribe_token: Optional[str] = None + self._buffering = True + self._buffer: list[tuple[str, object]] = [] + + @property + def session_id(self) -> Optional[str]: + return self._session_id + + @property + def subscribe_token(self) -> Optional[str]: + return self._subscribe_token + + def _handle_session_id(self, msg: SessionIdMessage) -> None: + self._session_id = msg.session_id + self._subscribe_token = encode_subscribe_token( + msg.session_id, msg.server_ip, msg.server_port + ) @classmethod async def connect( @@ -81,20 +101,20 @@ async def connect( options: RealtimeConnectOptions, integration: Optional[str] = None, ) -> "RealtimeClient": - session_id = str(uuid.uuid4()) ws_url = f"{base_url}{options.model.url_path}" - ws_url += f"?api_key={api_key}&model={options.model.name}" + ws_url += f"?api_key={quote(api_key)}&model={quote(options.model.name)}" is_avatar_live = options.model.name == "avatar-live" config = WebRTCConfiguration( webrtc_url=ws_url, api_key=api_key, - session_id=session_id, + session_id="", fps=options.model.fps, on_remote_stream=options.on_remote_stream, on_connection_state_change=None, on_error=None, + on_session_id=None, initial_state=options.initial_state, customize_offer=options.customize_offer, integration=integration, @@ -107,13 +127,13 @@ async def connect( manager = WebRTCManager(config) client = cls( manager=manager, - session_id=session_id, http_session=http_session, is_avatar_live=is_avatar_live, ) config.on_connection_state_change = client._emit_connection_change config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error)) + config.on_session_id = client._handle_session_id try: # For avatar-live, convert and send avatar image before WebRTC connection @@ -143,28 +163,97 @@ async def connect( if options.initial_state.prompt: await client.set_prompt( options.initial_state.prompt.text, - enrich=options.initial_state.prompt.enrich, + enhance=options.initial_state.prompt.enhance, ) except Exception as e: + await manager.cleanup() await http_session.close() raise WebRTCError(str(e), cause=e) + client._flush() return client - def _emit_connection_change(self, state: ConnectionState) -> None: - for callback in self._connection_callbacks: + @classmethod + async def subscribe( + cls, + base_url: str, + api_key: str, + options: SubscribeOptions, + integration: Optional[str] = None, + ) -> SubscribeClient: + token_data = decode_subscribe_token(options.token) + subscribe_url = ( + f"{base_url}/subscribe/{quote(token_data.sid)}" + f"?IP={quote(token_data.ip)}" + f"&port={quote(str(token_data.port))}" + f"&api_key={quote(api_key)}" + ) + + config = WebRTCConfiguration( + webrtc_url=subscribe_url, + api_key=api_key, + session_id=token_data.sid, + fps=0, + on_remote_stream=options.on_remote_stream, + on_connection_state_change=None, + on_error=None, + integration=integration, + ) + + manager = WebRTCManager(config) + sub_client = SubscribeClient(manager) + + config.on_connection_state_change = sub_client._emit_connection_change + config.on_error = sub_client._emit_error + + try: + await manager.connect(None) + except Exception as e: + await manager.cleanup() + raise WebRTCError(str(e), cause=e) + + sub_client._flush() + return sub_client + + def _flush(self) -> None: + # Defer to next tick so caller can register handlers before buffered events fire + asyncio.get_running_loop().call_soon(self._do_flush) + + def _do_flush(self) -> None: + self._buffering = False + for event, data in self._buffer: + if event == "connection_change": + self._dispatch_connection_change(data) # type: ignore[arg-type] + elif event == "error": + self._dispatch_error(data) # type: ignore[arg-type] + self._buffer.clear() + + def _dispatch_connection_change(self, state: ConnectionState) -> None: + for callback in list(self._connection_callbacks): try: callback(state) except Exception as e: logger.exception(f"Error in connection_change callback: {e}") - def _emit_error(self, error: DecartSDKError) -> None: - for callback in self._error_callbacks: + def _dispatch_error(self, error: DecartSDKError) -> None: + for callback in list(self._error_callbacks): try: callback(error) except Exception as e: logger.exception(f"Error in error callback: {e}") + def _emit_connection_change(self, state: ConnectionState) -> None: + if self._buffering: + self._buffer.append(("connection_change", state)) + else: + self._dispatch_connection_change(state) + + def _emit_error(self, error: DecartSDKError) -> None: + if self._buffering: + self._buffer.append(("error", error)) + else: + self._dispatch_error(error) + async def set(self, input: SetInput) -> None: if input.prompt is None and input.image is None: raise InvalidInputError("At least one of 'prompt' or 'image' must be provided") @@ -187,7 +276,21 @@ async def set(self, input: SetInput) -> None: }, ) - async def set_prompt(self, prompt: str, enrich: bool = True) -> None: + async def set_prompt( + self, + prompt: str, + enhance: bool = True, + enrich: Optional[bool] = None, + ) -> None: + if enrich is not None: + import warnings + + warnings.warn( + "set_prompt(enrich=...) is deprecated, use set_prompt(enhance=...) instead", + DeprecationWarning, + stacklevel=2, + ) + enhance = enrich if not prompt or not prompt.strip(): raise InvalidInputError("Prompt cannot be empty") @@ -195,7 +298,7 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None: try: await self._manager.send_message( - PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich) + PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enhance) ) try: @@ -208,17 +311,26 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None: finally: self._manager.unregister_prompt_wait(prompt) - async def set_image(self, image: FileInput) -> None: - if not self._is_avatar_live: - raise InvalidInputError("set_image() is only available for avatar-live model") - - if not self._http_session: - raise InvalidInputError("HTTP session not available") + async def set_image( + self, + image: Optional[FileInput], + prompt: Optional[str] = None, + enhance: bool = True, + timeout: float = UPDATE_TIMEOUT_S, + ) -> None: + image_base64: Optional[str] = None + if image is not None: + if not self._http_session: + raise InvalidInputError("HTTP session not available") + image_bytes, _ = await file_input_to_bytes(image, self._http_session) + image_base64 = base64.b64encode(image_bytes).decode("utf-8") - image_bytes, _ = await file_input_to_bytes(image, self._http_session) - image_base64 = base64.b64encode(image_bytes).decode("utf-8") + opts: dict = {"timeout": timeout} + if prompt is not None: + opts["prompt"] = prompt + opts["enhance"] = enhance - await self._manager.set_image(image_base64) + await self._manager.set_image(image_base64, opts) def is_connected(self) -> bool: return self._manager.is_connected() @@ -227,6 +339,8 @@ def get_connection_state(self) -> ConnectionState: return self._manager.get_connection_state() async def disconnect(self) -> None: + self._buffering = False + self._buffer.clear() await self._manager.cleanup() if self._http_session and not self._http_session.closed: await self._http_session.close() diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index 1e201e6..f3c3dc6 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -87,6 +87,12 @@ class IceRestartMessage(BaseModel): turn_config: TurnConfig +class GenerationStartedMessage(BaseModel): + """Server signals that generation has started.""" + + type: Literal["generation_started"] + + # Discriminated union for incoming messages IncomingMessage = Annotated[ Union[ @@ -98,6 +104,7 @@ class IceRestartMessage(BaseModel): ErrorMessage, ReadyMessage, IceRestartMessage, + GenerationStartedMessage, ], Field(discriminator="type"), ] diff --git a/decart/realtime/subscribe.py b/decart/realtime/subscribe.py new file mode 100644 index 0000000..ae64cf7 --- /dev/null +++ b/decart/realtime/subscribe.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import asyncio +import base64 +import json +import logging +from typing import TYPE_CHECKING, Callable +from dataclasses import dataclass + +from .types import ConnectionState + +if TYPE_CHECKING: + from aiortc import MediaStreamTrack + from .webrtc_manager import WebRTCManager + +logger = logging.getLogger(__name__) + + +@dataclass +class TokenPayload: + sid: str + ip: str + port: int + + +def encode_subscribe_token(session_id: str, server_ip: str, server_port: int) -> str: + payload = json.dumps({"sid": session_id, "ip": server_ip, "port": server_port}) + return base64.urlsafe_b64encode(payload.encode()).decode() + + +def decode_subscribe_token(token: str) -> TokenPayload: + try: + raw = base64.urlsafe_b64decode(token).decode() + data = json.loads(raw) + if not data.get("sid") or not data.get("ip") or not data.get("port"): + raise ValueError("Invalid subscribe token format") + return TokenPayload(sid=data["sid"], ip=data["ip"], port=data["port"]) + except Exception: + raise ValueError("Invalid subscribe token") + + +@dataclass +class SubscribeOptions: + token: str + on_remote_stream: Callable[[MediaStreamTrack], None] + + +class SubscribeClient: + def __init__(self, manager: WebRTCManager): + self._manager = manager + self._connection_callbacks: list[Callable[[ConnectionState], None]] = [] + self._error_callbacks: list[Callable[[Exception], None]] = [] + self._buffering = True + self._buffer: list[tuple[str, object]] = [] + + def _flush(self) -> None: + # Defer to next tick so caller can register handlers before buffered events fire + asyncio.get_running_loop().call_soon(self._do_flush) + + def _do_flush(self) -> None: + self._buffering = False + for event, data in self._buffer: + if event == "connection_change": + self._dispatch_connection_change(data) # type: ignore[arg-type] + elif event == "error": + self._dispatch_error(data) # type: ignore[arg-type] + self._buffer.clear() + + def _dispatch_connection_change(self, state: ConnectionState) -> None: + for callback in list(self._connection_callbacks): + try: + callback(state) + except Exception as e: + logger.exception(f"Error in subscribe connection_change callback: {e}") + + def _dispatch_error(self, error: Exception) -> None: + for callback in list(self._error_callbacks): + try: + callback(error) + except Exception as e: + logger.exception(f"Error in subscribe error callback: {e}") + + def _emit_connection_change(self, state: ConnectionState) -> None: + if self._buffering: + self._buffer.append(("connection_change", state)) + else: + self._dispatch_connection_change(state) + + def _emit_error(self, error: Exception) -> None: + if self._buffering: + self._buffer.append(("error", error)) + else: + self._dispatch_error(error) + + def is_connected(self) -> bool: + return self._manager.is_connected() + + def get_connection_state(self) -> ConnectionState: + return self._manager.get_connection_state() + + async def disconnect(self) -> None: + self._buffering = False + self._buffer.clear() + await self._manager.cleanup() + + def on(self, event: str, callback: Callable) -> None: + if event == "connection_change": + self._connection_callbacks.append(callback) + elif event == "error": + self._error_callbacks.append(callback) + + def off(self, event: str, callback: Callable) -> None: + if event == "connection_change": + try: + self._connection_callbacks.remove(callback) + except ValueError: + pass + elif event == "error": + try: + self._error_callbacks.remove(callback) + except ValueError: + pass diff --git a/decart/realtime/types.py b/decart/realtime/types.py index 51929b6..c388ada 100644 --- a/decart/realtime/types.py +++ b/decart/realtime/types.py @@ -9,7 +9,7 @@ MediaStreamTrack = None # type: ignore -ConnectionState = Literal["connecting", "connected", "disconnected"] +ConnectionState = Literal["connecting", "connected", "generating", "disconnected", "reconnecting"] @dataclass diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index e561c6b..a7ebd34 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -27,6 +27,7 @@ SetAvatarImageMessage, ErrorMessage, IceRestartMessage, + SessionIdMessage, OutgoingMessage, ) from .types import ConnectionState @@ -40,6 +41,7 @@ def __init__( on_remote_stream: Optional[Callable[[MediaStreamTrack], None]] = None, on_state_change: Optional[Callable[[ConnectionState], None]] = None, on_error: Optional[Callable[[Exception], None]] = None, + on_session_id: Optional[Callable[[SessionIdMessage], None]] = None, customize_offer: Optional[Callable] = None, ): self._pc: Optional[RTCPeerConnection] = None @@ -49,11 +51,14 @@ def __init__( self._on_remote_stream = on_remote_stream self._on_state_change = on_state_change self._on_error = on_error + self._on_session_id = on_session_id self._customize_offer = customize_offer self._ws_task: Optional[asyncio.Task] = None self._ice_candidates_queue: list[RTCIceCandidate] = [] self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {} self._pending_image_set: Optional[tuple[asyncio.Event, dict]] = None + self._local_track: Optional[MediaStreamTrack] = None + self._is_avatar_live: bool = False async def connect( self, @@ -66,11 +71,13 @@ async def connect( initial_prompt: Optional[dict] = None, ) -> None: try: + self._local_track = local_track + self._is_avatar_live = is_avatar_live + await self._set_state("connecting") ws_url = url.replace("https://", "wss://").replace("http://", "ws://") - # Add user agent as query parameter (browsers don't support WS headers) user_agent = build_user_agent(integration) separator = "&" if "?" in ws_url else "?" ws_url = f"{ws_url}{separator}user_agent={quote(user_agent)}" @@ -80,11 +87,9 @@ async def connect( self._ws_task = asyncio.create_task(self._receive_messages()) - # For avatar-live, send avatar image before WebRTC handshake if is_avatar_live and avatar_image_base64: await self._send_avatar_image_and_wait(avatar_image_base64) - # Send initial prompt before WebRTC handshake (if provided) if initial_prompt: await self._send_initial_prompt_and_wait(initial_prompt) @@ -94,7 +99,7 @@ async def connect( deadline = asyncio.get_event_loop().time() + timeout while asyncio.get_event_loop().time() < deadline: - if self._state == "connected": + if self._state in ("connected", "generating"): return await asyncio.sleep(0.1) @@ -189,18 +194,16 @@ async def on_connection_state_change(): await self._set_state("connected") elif self._pc.connectionState in ["failed", "closed"]: await self._set_state("disconnected") + # Keep "generating" sticky unless actually disconnected (matches JS SDK) @self._pc.on("iceconnectionstatechange") async def on_ice_connection_state_change(): logger.debug(f"ICE connection state: {self._pc.iceConnectionState}") - # For avatar-live, add recv-only video transceiver - if is_avatar_live: + if local_track is None: self._pc.addTransceiver("video", direction="recvonly") - logger.debug("Added video transceiver (recvonly) for avatar-live") - - # Add local audio track if provided - if local_track: + logger.debug("Added video transceiver (recvonly) for receive-only mode") + else: self._pc.addTrack(local_track) logger.debug("Added local track to peer connection") @@ -236,6 +239,9 @@ async def _receive_messages(self) -> None: logger.error(f"WebSocket receive error: {e}") if self._on_error: self._on_error(e) + finally: + # WS loop exited (clean close or error) — signal disconnected so manager can reconnect + await self._set_state("disconnected") async def _handle_message(self, data: dict) -> None: try: @@ -250,10 +256,14 @@ async def _handle_message(self, data: dict) -> None: await self._handle_ice_candidate(message.candidate) elif message.type == "session_id": logger.debug(f"Session ID: {message.session_id}") + if self._on_session_id: + self._on_session_id(message) elif message.type == "prompt_ack": self._handle_prompt_ack(message) elif message.type == "set_image_ack": self._handle_set_image_ack(message) + elif message.type == "generation_started": + await self._set_state("generating") elif message.type == "error": self._handle_error(message) elif message.type == "ready": @@ -320,9 +330,6 @@ async def _handle_ice_restart(self, message: IceRestartMessage) -> None: await self._setup_peer_connection_with_turn(turn_config) async def _setup_peer_connection_with_turn(self, turn_config) -> None: - """Re-setup peer connection with TURN server for ICE restart.""" - from aiortc import RTCConfiguration, RTCIceServer - ice_servers = [ RTCIceServer(urls=["stun:stun.l.google.com:19302"]), RTCIceServer( @@ -333,20 +340,33 @@ async def _setup_peer_connection_with_turn(self, turn_config) -> None: ] config = RTCConfiguration(iceServers=ice_servers) - # Close existing peer connection if self._pc: await self._pc.close() self._pc = RTCPeerConnection(configuration=config) logger.debug("Re-created peer connection with TURN server for ICE restart") - # Re-register callbacks @self._pc.on("track") def on_track(track: MediaStreamTrack): logger.debug(f"Received remote track: {track.kind}") if self._on_remote_stream: self._on_remote_stream(track) + @self._pc.on("icecandidate") + async def on_ice_candidate(candidate: RTCIceCandidate): + if candidate: + logger.debug(f"Local ICE candidate: {candidate.candidate}") + await self._send_message( + IceCandidateMessage( + type="ice-candidate", + candidate=IceCandidatePayload( + candidate=candidate.candidate, + sdpMLineIndex=candidate.sdpMLineIndex or 0, + sdpMid=candidate.sdpMid or "", + ), + ) + ) + @self._pc.on("connectionstatechange") async def on_connection_state_change(): logger.debug(f"Peer connection state: {self._pc.connectionState}") @@ -355,7 +375,13 @@ async def on_connection_state_change(): elif self._pc.connectionState in ["failed", "closed"]: await self._set_state("disconnected") - # Re-create and send offer + if self._local_track is None: + self._pc.addTransceiver("video", direction="recvonly") + logger.debug("Added video transceiver (recvonly) for receive-only ICE restart") + else: + self._pc.addTrack(self._local_track) + logger.debug("Re-added local track to peer connection for ICE restart") + await self._create_and_send_offer() def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: @@ -386,6 +412,8 @@ async def _send_message(self, message: OutgoingMessage) -> None: await self._ws.send_str(msg_json) async def _set_state(self, state: ConnectionState) -> None: + if self._state == "generating" and state not in ("disconnected", "generating"): + return if self._state != state: self._state = state logger.debug(f"Connection state changed to: {state}") diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index 20ed852..ad71b0b 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -12,12 +12,27 @@ ) from .webrtc_connection import WebRTCConnection -from .messages import OutgoingMessage +from .messages import OutgoingMessage, SessionIdMessage from .types import ConnectionState from ..types import ModelState logger = logging.getLogger(__name__) +PERMANENT_ERRORS = [ + "permission denied", + "not allowed", + "invalid session", + "401", + "invalid api key", + "unauthorized", +] + +CONNECTION_TIMEOUT = 60 * 5 # 5 minutes + +RETRY_MAX_ATTEMPTS = 5 +RETRY_MIN_WAIT = 1 +RETRY_MAX_WAIT = 10 + @dataclass class WebRTCConfiguration: @@ -28,27 +43,129 @@ class WebRTCConfiguration: on_remote_stream: Callable[[MediaStreamTrack], None] on_connection_state_change: Optional[Callable[[ConnectionState], None]] = None on_error: Optional[Callable[[Exception], None]] = None + on_session_id: Optional[Callable[[SessionIdMessage], None]] = None initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None integration: Optional[str] = None is_avatar_live: bool = False -def _is_retryable_error(exception: Exception) -> bool: - """Check if an error is retryable (not a permanent error).""" - permanent_errors = ["permission denied", "not allowed", "invalid session"] +def _is_permanent_error(exception: BaseException) -> bool: error_msg = str(exception).lower() - return not any(err in error_msg for err in permanent_errors) + return any(err in error_msg for err in PERMANENT_ERRORS) + + +def _is_retryable_error(exception: BaseException) -> bool: + if isinstance(exception, asyncio.CancelledError): + return False + return not _is_permanent_error(exception) class WebRTCManager: def __init__(self, configuration: WebRTCConfiguration): self._config = configuration - self._connection = self._create_connection() + self._connection: Optional[WebRTCConnection] = None + self._local_track: Optional[MediaStreamTrack] = None + self._subscribe_mode = False + self._manager_state: ConnectionState = "disconnected" + self._has_connected = False + self._is_reconnecting = False + self._intentional_disconnect = False + self._reconnect_generation = 0 + self._reconnect_task: Optional[asyncio.Task] = None + + def _get_connection(self) -> WebRTCConnection: + if self._connection is None: + raise RuntimeError("WebRTCManager not connected") + return self._connection + + def _emit_state(self, state: ConnectionState) -> None: + if self._manager_state != state: + self._manager_state = state + if state in ("connected", "generating"): + self._has_connected = True + if self._config.on_connection_state_change: + self._config.on_connection_state_change(state) + + def _handle_connection_state_change(self, state: ConnectionState) -> None: + if self._intentional_disconnect: + self._emit_state("disconnected") + return + + if self._is_reconnecting: + if state in ("connected", "generating"): + self._is_reconnecting = False + self._emit_state(state) + return + + # Unexpected disconnect after having been connected → trigger auto-reconnect + # _has_connected guards against triggering during initial connect (which has its own retry) + if state == "disconnected" and not self._intentional_disconnect and self._has_connected: + self._reconnect_task = asyncio.ensure_future(self._reconnect()) + return + + self._emit_state(state) + + async def _reconnect(self) -> None: + if self._is_reconnecting or self._intentional_disconnect: + return + if not self._subscribe_mode and not self._local_track: + return + + reconnect_generation = self._reconnect_generation + 1 + self._reconnect_generation = reconnect_generation + self._is_reconnecting = True + self._emit_state("reconnecting") + + try: + await self._retry_reconnect(reconnect_generation) + except asyncio.CancelledError: + # Task cancelled or intentional disconnect — don't emit error + pass + except Exception as error: + if self._intentional_disconnect or reconnect_generation != self._reconnect_generation: + return + self._emit_state("disconnected") + if self._config.on_error: + self._config.on_error( + error if isinstance(error, Exception) else Exception(str(error)) + ) + finally: + self._is_reconnecting = False + + async def _retry_reconnect(self, reconnect_generation: int) -> None: + @retry( + stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), + wait=wait_exponential(multiplier=1, min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT), + retry=retry_if_exception(_is_retryable_error), + before_sleep=before_sleep_log(logger, logging.WARNING), + reraise=True, + ) + async def _attempt(): + if self._intentional_disconnect or reconnect_generation != self._reconnect_generation: + raise asyncio.CancelledError("Reconnect cancelled") + + if self._connection is not None: + await self._connection.cleanup() + conn = self._create_connection() + self._connection = conn + await conn.connect( + url=self._config.webrtc_url, + local_track=self._local_track, + timeout=CONNECTION_TIMEOUT, + integration=self._config.integration, + is_avatar_live=self._config.is_avatar_live, + ) + + if self._intentional_disconnect or reconnect_generation != self._reconnect_generation: + await conn.cleanup() + raise asyncio.CancelledError("Reconnect cancelled") + + await _attempt() @retry( - stop=stop_after_attempt(5), - wait=wait_exponential(multiplier=1, min=1, max=10), + stop=stop_after_attempt(RETRY_MAX_ATTEMPTS), + wait=wait_exponential(multiplier=1, min=RETRY_MIN_WAIT, max=RETRY_MAX_WAIT), retry=retry_if_exception(_is_retryable_error), before_sleep=before_sleep_log(logger, logging.WARNING), reraise=True, @@ -59,12 +176,20 @@ async def connect( avatar_image_base64: Optional[str] = None, initial_prompt: Optional[dict] = None, ) -> bool: + self._local_track = local_track + self._subscribe_mode = local_track is None + self._intentional_disconnect = False + self._has_connected = False + self._is_reconnecting = False + self._reconnect_generation += 1 + self._connection = self._create_connection() + self._emit_state("connecting") + try: - timeout = 60 * 5 # 5 minutes await self._connection.connect( url=self._config.webrtc_url, local_track=local_track, - timeout=timeout, + timeout=CONNECTION_TIMEOUT, integration=self._config.integration, is_avatar_live=self._config.is_avatar_live, avatar_image_base64=avatar_image_base64, @@ -74,14 +199,15 @@ async def connect( except Exception as e: logger.error(f"Connection attempt failed: {e}") await self._connection.cleanup() - self._connection = self._create_connection() + self._connection = None raise def _create_connection(self) -> WebRTCConnection: return WebRTCConnection( on_remote_stream=self._config.on_remote_stream, - on_state_change=self._config.on_connection_state_change, + on_state_change=self._handle_connection_state_change, on_error=self._config.on_error, + on_session_id=self._config.on_session_id, customize_offer=self._config.customize_offer, ) @@ -95,7 +221,8 @@ async def set_image( opts = options or {} timeout = opts.get("timeout", 30.0) - event, result = self._connection.register_image_set_wait() + conn = self._get_connection() + event, result = conn.register_image_set_wait() try: message = SetAvatarImageMessage( @@ -107,7 +234,7 @@ async def set_image( if opts.get("enhance") is not None: message.enhance_prompt = opts["enhance"] - await self._connection.send(message) + await conn.send(message) try: await asyncio.wait_for(event.wait(), timeout=timeout) @@ -121,28 +248,37 @@ async def set_image( raise DecartSDKError(result.get("error") or "Failed to set image") finally: - self._connection.unregister_image_set_wait() + conn.unregister_image_set_wait() async def send_message(self, message: OutgoingMessage) -> None: - await self._connection.send(message) + await self._get_connection().send(message) async def cleanup(self) -> None: - await self._connection.cleanup() + self._intentional_disconnect = True + self._is_reconnecting = False + self._reconnect_generation += 1 + if self._reconnect_task and not self._reconnect_task.done(): + self._reconnect_task.cancel() + if self._connection: + await self._connection.cleanup() + self._connection = None + self._local_track = None + self._emit_state("disconnected") def is_connected(self) -> bool: - return self._connection.state == "connected" + return self._manager_state in ("connected", "generating") def get_connection_state(self) -> ConnectionState: - return self._connection.state + return self._manager_state def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]: - return self._connection.register_prompt_wait(prompt) + return self._get_connection().register_prompt_wait(prompt) def unregister_prompt_wait(self, prompt: str) -> None: - self._connection.unregister_prompt_wait(prompt) + self._get_connection().unregister_prompt_wait(prompt) def register_image_set_wait(self) -> tuple[asyncio.Event, dict]: - return self._connection.register_image_set_wait() + return self._get_connection().register_image_set_wait() def unregister_image_set_wait(self) -> None: - self._connection.unregister_image_set_wait() + self._get_connection().unregister_image_set_wait() diff --git a/decart/types.py b/decart/types.py index 7f66fd8..d1bf707 100644 --- a/decart/types.py +++ b/decart/types.py @@ -15,7 +15,7 @@ def read(self) -> Union[bytes, str]: ... class Prompt(BaseModel): text: str = Field(..., min_length=1) - enrich: bool = Field(default=True) + enhance: bool = Field(default=True) class ModelState(BaseModel): diff --git a/examples/realtime_file.py b/examples/realtime_file.py index 5e79fed..f58a745 100644 --- a/examples/realtime_file.py +++ b/examples/realtime_file.py @@ -85,7 +85,7 @@ def on_error(error): options=RealtimeConnectOptions( model=model, on_remote_stream=on_remote_stream, - initial_state=ModelState(prompt=Prompt(text="Lego World", enrich=True)), + initial_state=ModelState(prompt=Prompt(text="Lego World", enhance=True)), ), ) diff --git a/examples/realtime_synthetic.py b/examples/realtime_synthetic.py index b251814..0c3f45c 100644 --- a/examples/realtime_synthetic.py +++ b/examples/realtime_synthetic.py @@ -111,7 +111,7 @@ def on_error(error): options=RealtimeConnectOptions( model=model, on_remote_stream=on_remote_stream, - initial_state=ModelState(prompt=Prompt(text="Anime style", enrich=True)), + initial_state=ModelState(prompt=Prompt(text="Anime style", enhance=True)), ), ) diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index 1c0e4b0..b1221c4 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from unittest.mock import AsyncMock, MagicMock, patch from decart import DecartClient, models @@ -86,13 +88,32 @@ async def test_realtime_client_creation_with_mock(): options=RealtimeConnectOptions( model=models.realtime("mirage"), on_remote_stream=lambda t: None, - initial_state=ModelState(prompt=Prompt(text="Test", enrich=True)), + initial_state=ModelState(prompt=Prompt(text="Test", enhance=True)), ), ) assert realtime_client is not None - assert realtime_client.session_id assert realtime_client.is_connected() + assert realtime_client.session_id is None + assert realtime_client.subscribe_token is None + + call_args = mock_manager_class.call_args + config = call_args[0][0] if call_args[0] else call_args[1]["configuration"] + assert config.on_session_id is not None, "on_session_id callback must be wired" + + from decart.realtime.messages import SessionIdMessage + + config.on_session_id( + SessionIdMessage( + type="session_id", + session_id="test-session-123", + server_ip="1.2.3.4", + server_port=8080, + ) + ) + + assert realtime_client.session_id == "test-session-123" + assert realtime_client.subscribe_token is not None @pytest.mark.asyncio @@ -146,6 +167,43 @@ async def set_event(): mock_manager.unregister_prompt_wait.assert_called_with("New prompt") +@pytest.mark.asyncio +async def test_buffered_events_delivered_after_handler_registration(): + """Events emitted during connect() must be delivered to handlers registered after connect().""" + client = DecartClient(api_key="test-key") + + with patch("decart.realtime.client.WebRTCManager") as mock_manager_class: + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + ), + ) + + # Simulate events that were buffered during connect + realtime_client._buffer.append(("connection_change", "connecting")) + realtime_client._buffer.append(("connection_change", "connected")) + + received: list = [] + realtime_client.on("connection_change", lambda s: received.append(s)) + + # Yield to event loop — deferred flush fires and delivers buffered events + await asyncio.sleep(0) + + assert received == ["connecting", "connected"] + + @pytest.mark.asyncio async def test_realtime_events(): """Test event handling""" @@ -182,6 +240,9 @@ def on_error(error): realtime_client.on("connection_change", on_connection_change) realtime_client.on("error", on_error) + # Yield to event loop so deferred _do_flush fires (mirrors JS setTimeout(0)) + await asyncio.sleep(0) + realtime_client._emit_connection_change("connected") assert connection_states == ["connected"] @@ -397,41 +458,123 @@ async def test_avatar_live_set_image(): @pytest.mark.asyncio -async def test_set_image_only_for_avatar_live(): - """Test that set_image raises error for non-avatar-live models""" +async def test_set_image_works_for_any_model(): + """Test that set_image works for non-avatar-live models""" client = DecartClient(api_key="test-key") with ( patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, ): mock_manager = AsyncMock() mock_manager.connect = AsyncMock(return_value=True) + mock_manager.set_image = AsyncMock() mock_manager_class.return_value = mock_manager + mock_file_input.return_value = (b"image data", "image/png") + mock_session = MagicMock() mock_session.closed = False + mock_session.close = AsyncMock() mock_session_cls.return_value = mock_session mock_track = MagicMock() from decart.realtime.types import RealtimeConnectOptions - from decart.errors import InvalidInputError realtime_client = await RealtimeClient.connect( base_url=client.base_url, api_key=client.api_key, local_track=mock_track, options=RealtimeConnectOptions( - model=models.realtime("mirage"), # Not avatar-live + model=models.realtime("mirage"), on_remote_stream=lambda t: None, ), ) - with pytest.raises(InvalidInputError) as exc_info: - await realtime_client.set_image(b"test image") + await realtime_client.set_image(b"test image") + mock_manager.set_image.assert_called_once() + + +@pytest.mark.asyncio +async def test_set_image_null_clears_image(): + """Test that set_image(None) sends null to clear image""" + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.set_image = AsyncMock() + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + ), + ) + + await realtime_client.set_image(None) + mock_manager.set_image.assert_called_once() + assert mock_manager.set_image.call_args[0][0] is None + + +@pytest.mark.asyncio +async def test_set_image_with_prompt_and_enhance(): + """Test that set_image passes prompt and enhance options""" + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.file_input_to_bytes") as mock_file_input, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.set_image = AsyncMock() + mock_manager_class.return_value = mock_manager + + mock_file_input.return_value = (b"img", "image/png") + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + ), + ) - assert "avatar-live" in str(exc_info.value).lower() + await realtime_client.set_image(b"img", prompt="a dog", enhance=False) + opts = mock_manager.set_image.call_args[0][1] + assert opts["prompt"] == "a dog" + assert opts["enhance"] is False @pytest.mark.asyncio diff --git a/uv.lock b/uv.lock index 72ea5ca..f810957 100644 --- a/uv.lock +++ b/uv.lock @@ -597,7 +597,7 @@ wheels = [ [[package]] name = "decart" -version = "0.0.18" +version = "0.0.19" source = { editable = "." } dependencies = [ { name = "aiofiles" },