Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions decart/realtime/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from aiortc import MediaStreamTrack

from .webrtc_manager import WebRTCManager, WebRTCConfiguration
from .messages import PromptMessage, SwitchCameraMessage
from .messages import PromptMessage
from .types import ConnectionState, RealtimeConnectOptions
from ..errors import DecartSDKError, InvalidInputError, WebRTCError

Expand Down Expand Up @@ -59,8 +59,6 @@ async def connect(
options.initial_state.prompt.text,
enrich=options.initial_state.prompt.enrich,
)
if options.initial_state.mirror is not None:
await client.set_mirror(options.initial_state.mirror)
except Exception as e:
raise WebRTCError(str(e), cause=e)

Expand All @@ -85,12 +83,6 @@ async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
raise InvalidInputError("Prompt cannot be empty")
await self._manager.send_message(PromptMessage(type="prompt", prompt=prompt))

async def set_mirror(self, enabled: bool) -> None:
rotate_y = 2 if enabled else 0
await self._manager.send_message(
SwitchCameraMessage(type="switch_camera", rotateY=rotate_y)
)

def is_connected(self) -> bool:
return self._manager.is_connected()

Expand Down
9 changes: 1 addition & 8 deletions decart/realtime/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,8 @@ class PromptMessage(BaseModel):
prompt: str


class SwitchCameraMessage(BaseModel):
"""Switch camera/mirror message."""

type: Literal["switch_camera"]
rotateY: int


# Outgoing message union (no discriminator needed - we know what we're sending)
OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage, SwitchCameraMessage]
OutgoingMessage = Union[OfferMessage, IceCandidateMessage, PromptMessage]


def parse_incoming_message(data: dict) -> IncomingMessage:
Expand Down
1 change: 0 additions & 1 deletion decart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class Prompt(BaseModel):

class ModelState(BaseModel):
prompt: Optional[Prompt] = None
mirror: bool = Field(default=False)


class MotionTrajectoryInput(BaseModel):
Expand Down
4 changes: 1 addition & 3 deletions examples/realtime_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ def on_error(error):
options=RealtimeConnectOptions(
model=model,
on_remote_stream=on_remote_stream,
initial_state=ModelState(
prompt=Prompt(text="Lego World", enrich=True), mirror=False
),
initial_state=ModelState(prompt=Prompt(text="Lego World", enrich=True)),
),
)

Expand Down
4 changes: 1 addition & 3 deletions examples/realtime_synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,7 @@ def on_error(error):
options=RealtimeConnectOptions(
model=model,
on_remote_stream=on_remote_stream,
initial_state=ModelState(
prompt=Prompt(text="Anime style", enrich=True), mirror=False
),
initial_state=ModelState(prompt=Prompt(text="Anime style", enrich=True)),
),
)

Expand Down
35 changes: 1 addition & 34 deletions tests/test_realtime_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ async def test_realtime_client_creation_with_mock():
options=RealtimeConnectOptions(
model=models.realtime("mirage"),
on_remote_stream=lambda t: None,
initial_state=ModelState(prompt=Prompt(text="Test", enrich=True), mirror=False),
initial_state=ModelState(prompt=Prompt(text="Test", enrich=True)),
),
)

Expand Down Expand Up @@ -107,39 +107,6 @@ async def test_realtime_set_prompt_with_mock():
assert call_args.prompt == "New prompt"


@pytest.mark.asyncio
async def test_realtime_set_mirror_with_mock():
"""Test set_mirror with mocked WebRTC"""
client = DecartClient(api_key="test-key")

with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
mock_manager = AsyncMock()
mock_manager.connect = AsyncMock(return_value=True)
mock_manager.send_message = AsyncMock()
mock_manager_class.return_value = mock_manager

mock_track = MagicMock()

from decart.realtime.types import RealtimeConnectOptions

realtime_client = await RealtimeClient.connect(
base_url=client.base_url,
api_key=client.api_key,
local_track=mock_track,
options=RealtimeConnectOptions(
model=models.realtime("mirage"),
on_remote_stream=lambda t: None,
),
)

await realtime_client.set_mirror(True)

mock_manager.send_message.assert_called_once()
call_args = mock_manager.send_message.call_args[0][0]
assert call_args.type == "switch_camera"
assert call_args.rotateY == 2


@pytest.mark.asyncio
async def test_realtime_events():
"""Test event handling"""
Expand Down
Loading