diff --git a/decart/realtime/__init__.py b/decart/realtime/__init__.py index 7488eb6..cd81dc6 100644 --- a/decart/realtime/__init__.py +++ b/decart/realtime/__init__.py @@ -5,6 +5,7 @@ encode_subscribe_token, decode_subscribe_token, ) +from .messages import GenerationTickMessage from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions __all__ = [ @@ -14,6 +15,7 @@ "SubscribeOptions", "encode_subscribe_token", "decode_subscribe_token", + "GenerationTickMessage", "RealtimeConnectOptions", "ConnectionState", "AvatarOptions", diff --git a/decart/realtime/client.py b/decart/realtime/client.py index 3f338a8..4ce97a0 100644 --- a/decart/realtime/client.py +++ b/decart/realtime/client.py @@ -9,7 +9,7 @@ from pydantic import BaseModel from .webrtc_manager import WebRTCManager, WebRTCConfiguration -from .messages import PromptMessage, SessionIdMessage +from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage from .subscribe import ( SubscribeClient, SubscribeOptions, @@ -73,6 +73,7 @@ def __init__( self._is_avatar_live = is_avatar_live self._connection_callbacks: list[Callable[[ConnectionState], None]] = [] self._error_callbacks: list[Callable[[DecartSDKError], None]] = [] + self._generation_tick_callbacks: list[Callable[[GenerationTickMessage], None]] = [] self._session_id: Optional[str] = None self._subscribe_token: Optional[str] = None self._buffering = True @@ -134,6 +135,7 @@ async def connect( config.on_connection_state_change = client._emit_connection_change config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error)) config.on_session_id = client._handle_session_id + config.on_generation_tick = client._emit_generation_tick try: # For avatar-live, convert and send avatar image before WebRTC connection @@ -226,6 +228,8 @@ def _do_flush(self) -> None: self._dispatch_connection_change(data) # type: ignore[arg-type] elif event == "error": self._dispatch_error(data) # type: ignore[arg-type] + elif event == "generation_tick": + self._dispatch_generation_tick(data) # type: ignore[arg-type] self._buffer.clear() def _dispatch_connection_change(self, state: ConnectionState) -> None: @@ -254,6 +258,19 @@ def _emit_error(self, error: DecartSDKError) -> None: else: self._dispatch_error(error) + def _dispatch_generation_tick(self, message: GenerationTickMessage) -> None: + for callback in list(self._generation_tick_callbacks): + try: + callback(message) + except Exception as e: + logger.exception(f"Error in generation_tick callback: {e}") + + def _emit_generation_tick(self, message: GenerationTickMessage) -> None: + if self._buffering: + self._buffer.append(("generation_tick", message)) + else: + self._dispatch_generation_tick(message) + async def set(self, input: SetInput) -> None: if input.prompt is None and input.image is None: raise InvalidInputError("At least one of 'prompt' or 'image' must be provided") @@ -350,6 +367,8 @@ def on(self, event: str, callback: Callable) -> None: self._connection_callbacks.append(callback) elif event == "error": self._error_callbacks.append(callback) + elif event == "generation_tick": + self._generation_tick_callbacks.append(callback) def off(self, event: str, callback: Callable) -> None: if event == "connection_change": @@ -362,3 +381,8 @@ def off(self, event: str, callback: Callable) -> None: self._error_callbacks.remove(callback) except ValueError: pass + elif event == "generation_tick": + try: + self._generation_tick_callbacks.remove(callback) + except ValueError: + pass diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index f3c3dc6..19a2612 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -93,6 +93,21 @@ class GenerationStartedMessage(BaseModel): type: Literal["generation_started"] +class GenerationTickMessage(BaseModel): + """Periodic billing update during generation.""" + + type: Literal["generation_tick"] + seconds: int + + +class GenerationEndedMessage(BaseModel): + """Server signals that generation has ended. Not exposed publicly.""" + + type: Literal["generation_ended"] + seconds: int + reason: str + + # Discriminated union for incoming messages IncomingMessage = Annotated[ Union[ @@ -105,6 +120,8 @@ class GenerationStartedMessage(BaseModel): ReadyMessage, IceRestartMessage, GenerationStartedMessage, + GenerationTickMessage, + GenerationEndedMessage, ], Field(discriminator="type"), ] diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index a7ebd34..ed09afd 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -28,6 +28,7 @@ ErrorMessage, IceRestartMessage, SessionIdMessage, + GenerationTickMessage, OutgoingMessage, ) from .types import ConnectionState @@ -42,6 +43,7 @@ def __init__( on_state_change: Optional[Callable[[ConnectionState], None]] = None, on_error: Optional[Callable[[Exception], None]] = None, on_session_id: Optional[Callable[[SessionIdMessage], None]] = None, + on_generation_tick: Optional[Callable[[GenerationTickMessage], None]] = None, customize_offer: Optional[Callable] = None, ): self._pc: Optional[RTCPeerConnection] = None @@ -52,6 +54,7 @@ def __init__( self._on_state_change = on_state_change self._on_error = on_error self._on_session_id = on_session_id + self._on_generation_tick = on_generation_tick self._customize_offer = customize_offer self._ws_task: Optional[asyncio.Task] = None self._ice_candidates_queue: list[RTCIceCandidate] = [] @@ -264,6 +267,14 @@ async def _handle_message(self, data: dict) -> None: self._handle_set_image_ack(message) elif message.type == "generation_started": await self._set_state("generating") + elif message.type == "generation_tick": + if self._on_generation_tick: + self._on_generation_tick(message) + elif message.type == "generation_ended": + # Parsed but intentionally not exposed — unreliable (won't arrive on + # client disconnect/network drop), overlaps with connection_change + # "disconnected", and insufficient_credits is already covered by error event. + logger.debug(f"Generation ended: reason={message.reason}, seconds={message.seconds}") elif message.type == "error": self._handle_error(message) elif message.type == "ready": diff --git a/decart/realtime/webrtc_manager.py b/decart/realtime/webrtc_manager.py index ad71b0b..1ce56cd 100644 --- a/decart/realtime/webrtc_manager.py +++ b/decart/realtime/webrtc_manager.py @@ -12,7 +12,7 @@ ) from .webrtc_connection import WebRTCConnection -from .messages import OutgoingMessage, SessionIdMessage +from .messages import OutgoingMessage, SessionIdMessage, GenerationTickMessage from .types import ConnectionState from ..types import ModelState @@ -44,6 +44,7 @@ class WebRTCConfiguration: on_connection_state_change: Optional[Callable[[ConnectionState], None]] = None on_error: Optional[Callable[[Exception], None]] = None on_session_id: Optional[Callable[[SessionIdMessage], None]] = None + on_generation_tick: Optional[Callable[[GenerationTickMessage], None]] = None initial_state: Optional[ModelState] = None customize_offer: Optional[Callable] = None integration: Optional[str] = None @@ -208,6 +209,7 @@ def _create_connection(self) -> WebRTCConnection: on_state_change=self._handle_connection_state_change, on_error=self._config.on_error, on_session_id=self._config.on_session_id, + on_generation_tick=self._config.on_generation_tick, customize_offer=self._config.customize_offer, )