Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions decart/realtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
encode_subscribe_token,
decode_subscribe_token,
)
from .messages import GenerationTickMessage
from .types import RealtimeConnectOptions, ConnectionState, AvatarOptions

__all__ = [
Expand All @@ -14,6 +15,7 @@
"SubscribeOptions",
"encode_subscribe_token",
"decode_subscribe_token",
"GenerationTickMessage",
"RealtimeConnectOptions",
"ConnectionState",
"AvatarOptions",
Expand Down
26 changes: 25 additions & 1 deletion decart/realtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pydantic import BaseModel

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SessionIdMessage
from .messages import PromptMessage, SessionIdMessage, GenerationTickMessage
from .subscribe import (
SubscribeClient,
SubscribeOptions,
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
self._is_avatar_live = is_avatar_live
self._connection_callbacks: list[Callable[[ConnectionState], None]] = []
self._error_callbacks: list[Callable[[DecartSDKError], None]] = []
self._generation_tick_callbacks: list[Callable[[GenerationTickMessage], None]] = []
self._session_id: Optional[str] = None
self._subscribe_token: Optional[str] = None
self._buffering = True
Expand Down Expand Up @@ -134,6 +135,7 @@ async def connect(
config.on_connection_state_change = client._emit_connection_change
config.on_error = lambda error: client._emit_error(WebRTCError(str(error), cause=error))
config.on_session_id = client._handle_session_id
config.on_generation_tick = client._emit_generation_tick

try:
# For avatar-live, convert and send avatar image before WebRTC connection
Expand Down Expand Up @@ -226,6 +228,8 @@ def _do_flush(self) -> None:
self._dispatch_connection_change(data) # type: ignore[arg-type]
elif event == "error":
self._dispatch_error(data) # type: ignore[arg-type]
elif event == "generation_tick":
self._dispatch_generation_tick(data) # type: ignore[arg-type]
self._buffer.clear()

def _dispatch_connection_change(self, state: ConnectionState) -> None:
Expand Down Expand Up @@ -254,6 +258,19 @@ def _emit_error(self, error: DecartSDKError) -> None:
else:
self._dispatch_error(error)

def _dispatch_generation_tick(self, message: GenerationTickMessage) -> None:
for callback in list(self._generation_tick_callbacks):
try:
callback(message)
except Exception as e:
logger.exception(f"Error in generation_tick callback: {e}")

def _emit_generation_tick(self, message: GenerationTickMessage) -> None:
if self._buffering:
self._buffer.append(("generation_tick", message))
else:
self._dispatch_generation_tick(message)

async def set(self, input: SetInput) -> None:
if input.prompt is None and input.image is None:
raise InvalidInputError("At least one of 'prompt' or 'image' must be provided")
Expand Down Expand Up @@ -350,6 +367,8 @@ def on(self, event: str, callback: Callable) -> None:
self._connection_callbacks.append(callback)
elif event == "error":
self._error_callbacks.append(callback)
elif event == "generation_tick":
self._generation_tick_callbacks.append(callback)

def off(self, event: str, callback: Callable) -> None:
if event == "connection_change":
Expand All @@ -362,3 +381,8 @@ def off(self, event: str, callback: Callable) -> None:
self._error_callbacks.remove(callback)
except ValueError:
pass
elif event == "generation_tick":
try:
self._generation_tick_callbacks.remove(callback)
except ValueError:
pass
17 changes: 17 additions & 0 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,21 @@ class GenerationStartedMessage(BaseModel):
type: Literal["generation_started"]


class GenerationTickMessage(BaseModel):
"""Periodic billing update during generation."""

type: Literal["generation_tick"]
seconds: int


class GenerationEndedMessage(BaseModel):
"""Server signals that generation has ended. Not exposed publicly."""

type: Literal["generation_ended"]
seconds: int
reason: str


# Discriminated union for incoming messages
IncomingMessage = Annotated[
Union[
Expand All @@ -105,6 +120,8 @@ class GenerationStartedMessage(BaseModel):
ReadyMessage,
IceRestartMessage,
GenerationStartedMessage,
GenerationTickMessage,
GenerationEndedMessage,
],
Field(discriminator="type"),
]
Expand Down
11 changes: 11 additions & 0 deletions decart/realtime/webrtc_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
ErrorMessage,
IceRestartMessage,
SessionIdMessage,
GenerationTickMessage,
OutgoingMessage,
)
from .types import ConnectionState
Expand All @@ -42,6 +43,7 @@ def __init__(
on_state_change: Optional[Callable[[ConnectionState], None]] = None,
on_error: Optional[Callable[[Exception], None]] = None,
on_session_id: Optional[Callable[[SessionIdMessage], None]] = None,
on_generation_tick: Optional[Callable[[GenerationTickMessage], None]] = None,
customize_offer: Optional[Callable] = None,
):
self._pc: Optional[RTCPeerConnection] = None
Expand All @@ -52,6 +54,7 @@ def __init__(
self._on_state_change = on_state_change
self._on_error = on_error
self._on_session_id = on_session_id
self._on_generation_tick = on_generation_tick
self._customize_offer = customize_offer
self._ws_task: Optional[asyncio.Task] = None
self._ice_candidates_queue: list[RTCIceCandidate] = []
Expand Down Expand Up @@ -264,6 +267,14 @@ async def _handle_message(self, data: dict) -> None:
self._handle_set_image_ack(message)
elif message.type == "generation_started":
await self._set_state("generating")
elif message.type == "generation_tick":
if self._on_generation_tick:
self._on_generation_tick(message)
elif message.type == "generation_ended":
# Parsed but intentionally not exposed — unreliable (won't arrive on
# client disconnect/network drop), overlaps with connection_change
# "disconnected", and insufficient_credits is already covered by error event.
logger.debug(f"Generation ended: reason={message.reason}, seconds={message.seconds}")
elif message.type == "error":
self._handle_error(message)
elif message.type == "ready":
Expand Down
4 changes: 3 additions & 1 deletion decart/realtime/webrtc_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
)

from .webrtc_connection import WebRTCConnection
from .messages import OutgoingMessage, SessionIdMessage
from .messages import OutgoingMessage, SessionIdMessage, GenerationTickMessage
from .types import ConnectionState
from ..types import ModelState

Expand Down Expand Up @@ -44,6 +44,7 @@ class WebRTCConfiguration:
on_connection_state_change: Optional[Callable[[ConnectionState], None]] = None
on_error: Optional[Callable[[Exception], None]] = None
on_session_id: Optional[Callable[[SessionIdMessage], None]] = None
on_generation_tick: Optional[Callable[[GenerationTickMessage], None]] = None
initial_state: Optional[ModelState] = None
customize_offer: Optional[Callable] = None
integration: Optional[str] = None
Expand Down Expand Up @@ -208,6 +209,7 @@ def _create_connection(self) -> WebRTCConnection:
on_state_change=self._handle_connection_state_change,
on_error=self._config.on_error,
on_session_id=self._config.on_session_id,
on_generation_tick=self._config.on_generation_tick,
customize_offer=self._config.customize_offer,
)

Expand Down
Loading