Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions sdk/adrian/ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,12 @@ def schedule_connect(self, loop: asyncio.AbstractEventLoop) -> None:
if self._connect_task is None or self._connect_task.done():
self._connect_task = loop.create_task(self.connect())

def _ensure_connect_task(self) -> None:
"""Start the initial/reconnect task if none is currently running."""
if self._connect_task is None or self._connect_task.done():
loop = asyncio.get_running_loop()
self._connect_task = loop.create_task(self.connect())

async def connect(self) -> None:
"""Establish the WebSocket with exponential-backoff retry.

Expand Down Expand Up @@ -575,6 +581,8 @@ async def _send_frame(self, frame: pb.ClientFrame) -> None:

if not self._connected.is_set() or self._replaying:
self._buffer_frame(frame_bytes)
if not self._replaying:
self._ensure_connect_task()
reason = "disconnected" if not self._connected.is_set() else "replaying"
logger.info(
"buffered for replay (session_id=%s, kind=%s, "
Expand All @@ -591,6 +599,8 @@ async def _send_frame(self, frame: pb.ClientFrame) -> None:

if ws is None:
self._buffer_frame(frame_bytes)
if not self._connected.is_set():
self._ensure_connect_task()

return

Expand Down Expand Up @@ -872,10 +882,7 @@ async def _handle_disconnect(self, reason: str) -> None:
if self._closing:
return

loop = asyncio.get_running_loop()

if self._connect_task is None or self._connect_task.done():
self._connect_task = loop.create_task(self.connect())
self._ensure_connect_task()

async def _fire_on_disconnect(self, reason: str) -> None:
"""Invoke the on_disconnect callback, catching any exception."""
Expand Down
35 changes: 35 additions & 0 deletions sdk/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import asyncio
import os
from collections.abc import Iterator
from pathlib import Path
Expand All @@ -10,6 +11,7 @@
import adrian
import pytest
from adrian.config import AdrianConfig, get_config, is_initialized
from adrian.proto import event_pb2 as pb
from langchain_core.callbacks.manager import CallbackManager
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.runnables.base import Runnable
Expand Down Expand Up @@ -66,6 +68,39 @@ def test_creates_jsonl_file(self, tmp_path: Path) -> None:

assert log.exists()

async def test_sync_init_first_async_send_starts_connect_task(self) -> None:
"""First async send should start connect when init() ran without a loop."""
adrian.init(
auto_instrument=False,
api_key="k",
ws_url="ws://127.0.0.1:9999/ws",
)

ws = adrian._ws_client
assert ws is not None
assert ws._connect_task is None

frame = pb.ClientFrame()
event = frame.paired_batch.events.add()
event.event_id = "evt-1"
event.invocation_id = "inv-1"
event.session_id = "sess-1"
event.pair_type = pb.PAIR_TYPE_TOOL
event.tool.tool_name = "demo"

connect_calls: list[int] = []

async def _fake_connect() -> None:
connect_calls.append(1)

with patch.object(ws, "connect", _fake_connect):
await ws._send_frame(frame) # pyright: ignore[reportPrivateUsage]
await asyncio.sleep(0)

assert connect_calls == [1]
assert ws._connect_task is not None
assert len(ws._replay_buffer) == 1


class TestShutdown:
"""Tests for adrian.shutdown()."""
Expand Down