Skip to content

Commit f8c8a24

Browse files
feat: Implement turn sequence management and stale context discard
- Added `turn_seq` attribute to `_QueuedItem` for tracking user turns.
- Introduced `discard_stale` method in `ContextInjectionQueue` to remove queued items from older turns, preventing cross-turn carryover.
- Updated `SessionState` to include `user_turn_seq` and `turn_output_seen` for managing user activity and turn state.
- Enhanced `WebSocketHandler` to register user activity, manage turn sequences, and handle silent turns with a watchdog mechanism.
- Modified context injection methods to include `turn_seq` for accurate tracking.
- Added tests for stale context discard functionality and silent turn handling.
- Created a script to generate an animated GIF from the system architecture SVG for documentation purposes.
1 parent 0091b63 commit f8c8a24

9 files changed

Lines changed: 1886 additions & 8444 deletions

File tree

artifacts/e2e_multiturn/report.json

Lines changed: 1443 additions & 8421 deletions
Large diffs are not rendered by default.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import cairosvg
2+
import imageio.v2 as imageio
3+
import os
4+
import re
5+
6+
# Configuration
7+
# SVG file that create_gif() reads as the animation source.
INPUT_SVG = 'assets/system-architecture.svg'
8+
# Destination path of the animated GIF written by create_gif().
OUTPUT_GIF = 'assets/system-architecture.gif'
9+
# Number of frames rendered; each frame shifts the dash offset by a further -2.
FRAMES = 12
10+
DURATION = 0.1 # seconds per frame
11+
12+
# Attribute patterns used to rewrite the SVG text for each animation frame.
# Compiled once here instead of being re-resolved on every frame.
_MARKER_END_RE = re.compile(r'marker-end="url\(#[^\)]+\)"')
_DASHED_TAG_RE = re.compile(r'<[^<>]*\sstroke-dasharray="[^"]*"[^<>]*>')
_DASHARRAY_ATTR_RE = re.compile(r'stroke-dasharray="[^"]*"')
_DASHOFFSET_ATTR_RE = re.compile(r'stroke-dashoffset="[^"]*"')


def _set_dash_offset(svg_text, offset):
    """Return *svg_text* with ``stroke-dashoffset`` set on every dashed element.

    Works on whole opening tags (not just the ``stroke-dasharray`` attribute)
    so that an offset already present on the element is replaced rather than
    duplicated; a duplicated attribute would make the document invalid XML.

    Args:
        svg_text: SVG markup as a string.
        offset: Value written into ``stroke-dashoffset``.

    Returns:
        The rewritten markup. Tags without a double-quoted
        ``stroke-dasharray`` attribute are left untouched.
    """
    offset_attr = f'stroke-dashoffset="{offset}"'

    def rewrite_tag(match):
        tag = match.group(0)
        if _DASHOFFSET_ATTR_RE.search(tag):
            # Element already carries an offset: overwrite its value.
            return _DASHOFFSET_ATTR_RE.sub(lambda _m: offset_attr, tag)
        # Otherwise append the offset right after the dasharray attribute.
        return _DASHARRAY_ATTR_RE.sub(
            lambda m: f'{m.group(0)} {offset_attr}', tag, count=1,
        )

    return _DASHED_TAG_RE.sub(rewrite_tag, svg_text)


def create_gif():
    """Render INPUT_SVG as an animated GIF with "flowing" dashed lines.

    The SVG is read once, every ``marker-end="url(#...)"`` attribute is
    stripped, and FRAMES frames are rendered with cairosvg, each with a
    different ``stroke-dashoffset`` on the dashed elements. The frames are
    written to OUTPUT_GIF with DURATION seconds per frame.

    Rendering is best-effort: the first frame that fails to render stops the
    loop, and whatever frames were produced so far are still saved. Progress
    and errors are reported with ``print``.
    """
    print(f"Reading {INPUT_SVG}...")
    # The markup is re-encoded as UTF-8 for cairosvg below, so decode it as
    # UTF-8 as well instead of relying on the platform default encoding.
    with open(INPUT_SVG, 'r', encoding='utf-8') as f:
        svg_content = f.read()

    # FIX: Remove missing marker references that cause cairosvg to crash
    # The SVG references #arrowGreen but doesn't define it in <defs>
    svg_content = _MARKER_END_RE.sub('', svg_content)

    images = []

    # We animate 'stroke-dashoffset' to create a flowing effect on dashed lines.
    # The patterns are roughly length 6-8. A 24-unit shift covers multiples of
    # 3, 4, 6, 8 nicely, hence 12 frames shifting by -2 each.
    print("Generating frames...")
    for i in range(FRAMES):
        # Shift backwards so the dashes appear to flow forward.
        offset = i * -2
        frame_svg = _set_dash_offset(svg_content, offset)

        try:
            # Convert to PNG in memory
            png_data = cairosvg.svg2png(bytestring=frame_svg.encode('utf-8'))

            # Append to image list
            images.append(imageio.imread(png_data))
            print(f" - Frame {i+1}/{FRAMES} rendered")
        except Exception as e:
            # Script-level boundary: report and keep the frames rendered so far.
            print(f"Error rendering frame {i}: {e}")
            break

    if images:
        print(f"Saving GIF to {OUTPUT_GIF}...")
        imageio.mimsave(OUTPUT_GIF, images, duration=DURATION, loop=0)
        print("Done!")
    else:
        print("No frames generated.")
62+
63+
# Generate the GIF only when this file is run as a script, not when imported.
if __name__ == "__main__":
64+
create_gif()

server.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ class _QueuedItem:
298298
text: str
299299
priority: int # lower = more important
300300
speak: bool
301+
turn_seq: int = 0
301302
enqueued_at: float = dc_field(default_factory=time.monotonic)
302303

303304

@@ -441,17 +442,54 @@ def enqueue(
441442
text: str,
442443
priority: int = 5,
443444
speak: bool = True,
445+
turn_seq: int = 0,
444446
) -> None:
445447
"""Queue a context injection. Always queues; never sends immediately."""
446448
self._queue[category] = _QueuedItem(
447-
category=category, text=text, priority=priority, speak=speak,
449+
category=category, text=text, priority=priority, speak=speak, turn_seq=turn_seq,
448450
)
449451
logger.info("Queued [%s] (priority=%d, speak=%s, queue_size=%d, state=%s)",
450452
category, priority, speak, len(self._queue), self._state.value)
451453
# Only schedule flush if IDLE and no timer pending
452454
if self._state == ModelState.IDLE and self._deferred_flush_handle is None:
453455
self._schedule_deferred_flush()
454456

457+
def discard_stale(
    self,
    *,
    min_turn_seq: int,
    categories: set[str] | None = None,
) -> list[str]:
    """Remove queued items that were enqueued for an earlier user turn.

    Dropping them keeps silent control messages and late sub-agent results
    from an old turn out of a newer user request (cross-turn carryover).

    Args:
        min_turn_seq: Items whose ``turn_seq`` is below this value are stale.
        categories: When given, only stale items in these categories are
            dropped; ``None`` makes every stale item eligible.

    Returns:
        The dropped category names in queue order, or ``[]`` when nothing
        was stale (the queue is then left untouched).
    """
    stale = [
        name
        for name, queued in self._queue.items()
        if queued.turn_seq < min_turn_seq
        and (categories is None or name in categories)
    ]
    if not stale:
        return []

    stale_names = set(stale)
    self._queue = {
        name: queued
        for name, queued in self._queue.items()
        if name not in stale_names
    }
    # Nothing left to send, so a pending batching timer has no work to do.
    if not self._queue:
        self._cancel_deferred_flush()

    logger.info(
        "Discarded %d stale queued item(s) for turn %d: %s",
        len(stale),
        min_turn_seq,
        ", ".join(sorted(stale)),
    )
    return stale
492+
455493
# -- Deferred flush (batching window) ------------------------------------
456494

457495
def _schedule_deferred_flush(self, delay: float = BATCH_WINDOW_SEC):

session_state.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class SessionState:
8383

8484
# -- User activity -------------------------------------------------------
8585
last_user_activity_at: float = field(default_factory=time.monotonic)
86+
user_turn_seq: int = 0
87+
turn_output_seen: bool = False
8688

8789
# -- Memory (Phase 4) ----------------------------------------------------
8890
memory_top3: list[str] = field(default_factory=list)
@@ -99,6 +101,7 @@ class SessionState:
99101
TRANSCRIPT_FLUSH_TIMEOUT_SEC: float = field(default=1.5, repr=False)
100102
CAMERA_GRACE_PERIOD_SEC: float = field(default=12.0, repr=False)
101103
FRAME_TO_GEMINI_INTERVAL: float = field(default=2.0, repr=False)
104+
USER_TURN_GAP_SEC: float = field(default=1.0, repr=False)
102105
SENTENCE_BOUNDARY_RE: re.Pattern = field(
103106
default_factory=lambda: re.compile(r"[。!?.!?\n]"), repr=False,
104107
)

tests/test_context_injection_queue.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,42 @@ def test_different_categories_both_kept(self, queue, lrq):
330330
assert "John detected" in text
331331

332332

333+
# ---------------------------------------------------------------------------
334+
# Tests: stale turn discard
335+
# ---------------------------------------------------------------------------
336+
337+
338+
class TestDiscardStale:
    """Checks for ``discard_stale`` on the context injection queue fixture."""

    def test_discards_only_matching_older_turns(self, queue, lrq):
        """Only items that are both stale and in the given categories go away."""
        # NOTE(review): presumably marking the model as speaking keeps the
        # items queued instead of being flushed -- confirm against the queue.
        queue.set_model_speaking(True)
        queued = (
            # (category, text, priority, speak, turn_seq)
            ("turn_boundary", "old boundary", 1, False, 1),  # stale + selected
            ("vision", "fresh vision", 5, True, 2),          # selected, not stale
            ("camera_toggle", "camera on", 4, True, 1),      # stale, not selected
        )
        for category, text, priority, speak, turn_seq in queued:
            queue.enqueue(
                category, text, priority=priority, speak=speak, turn_seq=turn_seq,
            )

        dropped = queue.discard_stale(
            min_turn_seq=2,
            categories={"turn_boundary", "vision"},
        )

        assert dropped == ["turn_boundary"]
        assert set(queue._queue) == {"vision", "camera_toggle"}
        # Discarding must not send anything.
        assert len(lrq.sent) == 0

    def test_cancels_deferred_flush_when_queue_becomes_empty(self, queue):
        """Emptying the queue also clears the pending deferred-flush handle."""
        queue.set_model_speaking(True)
        queue.enqueue(
            "turn_boundary", "old boundary", priority=1, speak=False, turn_seq=1,
        )
        # Stand-in for a scheduled flush timer.
        queue._deferred_flush_handle = MagicMock()

        dropped = queue.discard_stale(
            min_turn_seq=2,
            categories={"turn_boundary"},
        )

        assert dropped == ["turn_boundary"]
        assert queue._queue == {}
        assert queue._deferred_flush_handle is None
367+
368+
333369
# ---------------------------------------------------------------------------
334370
# Tests: priority sorting in merged output
335371
# ---------------------------------------------------------------------------

tests/test_navigation.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010

1111
import pytest
1212

13+
import tools.navigation as navigation_mod
1314
from tools.navigation import (
1415
NAVIGATION_FUNCTIONS,
1516
NAVIGATION_TOOL_DECLARATIONS,
17+
_get_client,
1618
_haversine_distance,
1719
_maneuver_to_description,
1820
_strip_html,
@@ -124,6 +126,21 @@ def test_rounding(self):
124126
assert result == "at 9 o'clock, 48 meters"
125127

126128

129+
class TestGoogleMapsClient:
    """Construction arguments passed to ``googlemaps.Client`` by ``_get_client``."""

    def test_get_client_uses_supported_timeout_arguments(self, monkeypatch):
        """The client gets split connect/read timeouts and no bare ``timeout``."""
        monkeypatch.setenv("GOOGLE_MAPS_API_KEY", "test-key")
        # Reset the module-level cached client so _get_client() builds a new one.
        monkeypatch.setattr(navigation_mod, "_client", None)

        with patch("tools.navigation.googlemaps.Client") as mock_client:
            _get_client()

        kwargs = mock_client.call_args.kwargs
        expected = {"key": "test-key", "connect_timeout": 5, "read_timeout": 5}
        for name, value in expected.items():
            assert kwargs[name] == value
        # NOTE(review): presumably the client rejects `timeout` combined with
        # connect/read timeouts -- confirm against the googlemaps docs.
        assert "timeout" not in kwargs
142+
143+
127144
# ---------------------------------------------------------------------------
128145
# Helper unit tests
129146
# ---------------------------------------------------------------------------

tests/test_websocket_handler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,18 @@ async def test_activity_start_while_model_speaking_triggers_interrupt(handler, f
655655
assert session_state.is_interrupted
656656

657657

658+
def test_register_user_activity_starts_new_turn_and_discards_stale_context(handler, mock_ctx_queue, session_state):
    """An explicit turn start bumps the turn counter and discards stale context."""
    # Last activity is 5 seconds in the past.
    session_state.last_user_activity_at = time.monotonic() - 5
    discard_spy = Mock()
    mock_ctx_queue.discard_stale = discard_spy

    started_new_turn = handler._register_user_activity(explicit_turn_start=True)

    assert started_new_turn is True
    assert session_state.user_turn_seq == 1
    # The stale-context discard is invoked once, keyed to the new turn number.
    mock_ctx_queue.discard_stale.assert_called_once()
    assert mock_ctx_queue.discard_stale.call_args.kwargs["min_turn_seq"] == 1
668+
669+
658670
@pytest.mark.websocket
659671
@pytest.mark.asyncio
660672
async def test_interrupt_debounce_rapid_interrupts(handler, fake_ws, session_state, mock_ctx_queue):
@@ -1183,6 +1195,16 @@ async def test_gemini_connection_drop_closes_queue():
11831195
assert fake_queue.closed
11841196

11851197

1198+
def test_silent_turn_reconnect_predicate(handler, session_state):
    """Reconnect is wanted for a started turn only while no output was seen."""
    session_state.user_turn_seq = 1

    # (turn_output_seen, expected predicate result), checked in this order.
    cases = ((False, True), (True, False))
    for output_seen, should_reconnect in cases:
        session_state.turn_output_seen = output_seen
        assert handler._should_reconnect_silent_turn() is should_reconnect
1206+
1207+
11861208
@pytest.mark.websocket
11871209
@pytest.mark.error_scenario
11881210
@pytest.mark.asyncio

tools/navigation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ def _get_client() -> googlemaps.Client:
4848
raise RuntimeError("GOOGLE_MAPS_API_KEY environment variable not set")
4949
_client = googlemaps.Client(
5050
key=api_key,
51-
timeout=5,
5251
connect_timeout=5,
5352
read_timeout=5,
5453
)

0 commit comments

Comments
 (0)