diff --git a/agent_debugger_sdk/core/context/trace_context.py b/agent_debugger_sdk/core/context/trace_context.py index d742fb1..be6fd32 100644 --- a/agent_debugger_sdk/core/context/trace_context.py +++ b/agent_debugger_sdk/core/context/trace_context.py @@ -13,6 +13,7 @@ from agent_debugger_sdk.core.emitter import EventBufferLike, EventEmitter from agent_debugger_sdk.core.events import ( Checkpoint, + DriftDetectedEvent, EventType, Session, SessionStatus, @@ -188,6 +189,7 @@ async def restore( ctx._restored_state = restored_state ctx.replayed_events: list[dict[str, Any]] = [] ctx._drift_detector = None + ctx._drift_decision_index: int = 0 ctx._hook_errors: list[Exception] = [] ctx._restored_target: Any = None @@ -274,9 +276,13 @@ async def restore( if e.get("importance", 1.0) >= importance_threshold ] - # Seed drift detector with post-checkpoint events as baseline + # Seed drift detector with post-checkpoint decision events as baseline. + # Filtering to decisions aligns the decision-local index used in + # record_decision with the correct original event at each position. if ctx._drift_detector is not None: - ctx._drift_detector.original_events = post_events.copy() + ctx._drift_detector.original_events = [ + e for e in post_events if e.get("event_type") == "decision" + ] # Replay each event, honouring cancellation for event in post_events: @@ -571,6 +577,45 @@ def _check_entered(self) -> None: "TraceContext has not been entered. Use 'async with TraceContext(...) as ctx:' to enter the context." 
) + async def record_decision( + self, + reasoning: str, + confidence: float, + chosen_action: str, + evidence: list[dict[str, Any]] | None = None, + **kwargs: Any, + ) -> str: + event_id = await super().record_decision( + reasoning=reasoning, + confidence=confidence, + chosen_action=chosen_action, + evidence=evidence, + **kwargs, + ) + drift_detector = getattr(self, "_drift_detector", None) + if drift_detector is not None: + index = getattr(self, "_drift_decision_index", 0) + self._drift_decision_index = index + 1 + clamped_confidence = max(0.0, min(1.0, confidence)) + new_event_dict = { + "event_type": "decision", + "data": {"chosen_action": chosen_action, "confidence": clamped_confidence}, + } + drift = drift_detector.compare(new_event_dict, index) + if drift is not None: + drift_event = DriftDetectedEvent( + session_id=self.session_id, + parent_id=self.get_current_parent(), + description=drift.description, + original_value=str(drift.original_value), + restored_value=str(drift.restored_value), + drift_event_type=drift.event_type, + drift_index=drift.index, + severity=drift.severity.value, + ) + await self._emit_event(drift_event) + return event_id + async def _emit_event(self, event: TraceEvent) -> None: """Emit an event through the shared event emitter.""" await self._emitter.emit(event) diff --git a/agent_debugger_sdk/core/events/__init__.py b/agent_debugger_sdk/core/events/__init__.py index 632d80d..24c04b9 100644 --- a/agent_debugger_sdk/core/events/__init__.py +++ b/agent_debugger_sdk/core/events/__init__.py @@ -18,6 +18,7 @@ ) from agent_debugger_sdk.core.events.checkpoint import Checkpoint from agent_debugger_sdk.core.events.decisions import DecisionEvent +from agent_debugger_sdk.core.events.drift import DriftDetectedEvent from agent_debugger_sdk.core.events.errors import ErrorEvent from agent_debugger_sdk.core.events.llm import LLMRequestEvent, LLMResponseEvent from agent_debugger_sdk.core.events.registry import ( @@ -54,6 +55,7 @@ 
EventType.BEHAVIOR_ALERT: BehaviorAlertEvent, EventType.ERROR: ErrorEvent, EventType.REPAIR_ATTEMPT: RepairAttemptEvent, + EventType.DRIFT: DriftDetectedEvent, } ) @@ -81,6 +83,7 @@ "ErrorEvent", "RepairAttemptEvent", "RepairOutcome", + "DriftDetectedEvent", "Session", "Checkpoint", "EVENT_TYPE_REGISTRY", diff --git a/agent_debugger_sdk/core/events/base.py b/agent_debugger_sdk/core/events/base.py index c8a816c..90291dc 100644 --- a/agent_debugger_sdk/core/events/base.py +++ b/agent_debugger_sdk/core/events/base.py @@ -54,6 +54,7 @@ class EventType(StrEnum): AGENT_TURN = "agent_turn" BEHAVIOR_ALERT = "behavior_alert" REPAIR_ATTEMPT = "repair_attempt" + DRIFT = "drift" class SessionStatus(StrEnum): diff --git a/agent_debugger_sdk/core/events/drift.py b/agent_debugger_sdk/core/events/drift.py new file mode 100644 index 0000000..44efc2a --- /dev/null +++ b/agent_debugger_sdk/core/events/drift.py @@ -0,0 +1,30 @@ +"""Drift detection event emitted when replay diverges from original execution.""" + +from dataclasses import dataclass + +from .base import EventType, TraceEvent + +__all__ = ["DriftDetectedEvent"] + + +@dataclass(kw_only=True) +class DriftDetectedEvent(TraceEvent): + """Event emitted when a replayed decision diverges from the original execution. + + Attributes: + event_type: Always EventType.DRIFT + description: Human-readable summary of what drifted + original_value: The value from the original execution + restored_value: The value from the restored/replayed execution + drift_event_type: The event type where drift was detected (e.g. 
"decision") + drift_index: Index reported by the drift detector; its baseline holds only post-checkpoint decision events, so this is not a position in the full original event sequence + severity: "warning" or "critical" + """ + + event_type: EventType = EventType.DRIFT + description: str = "" + original_value: str = "" + restored_value: str = "" + drift_event_type: str = "" + drift_index: int = 0 + severity: str = "warning" diff --git a/agent_debugger_sdk/core/recorders.py b/agent_debugger_sdk/core/recorders.py index 224b1ec..768981f 100644 --- a/agent_debugger_sdk/core/recorders.py +++ b/agent_debugger_sdk/core/recorders.py @@ -98,8 +98,8 @@ async def record_decision( self, reasoning: str, confidence: float, - evidence: list[dict[str, Any]], chosen_action: str, + evidence: list[dict[str, Any]] | None = None, evidence_event_ids: list[str] | None = None, upstream_event_ids: list[str] | None = None, alternatives: list[dict[str, Any]] | None = None, @@ -114,7 +114,7 @@ async def record_decision( name=name, reasoning=reasoning, confidence=max(0.0, min(1.0, confidence)), - evidence=evidence, + evidence=evidence or [], evidence_event_ids=evidence_event_ids or [], alternatives=alternatives or [], chosen_action=chosen_action, diff --git a/tests/test_replay_depth_l3.py b/tests/test_replay_depth_l3.py index 69f03ef..e71ca17 100644 --- a/tests/test_replay_depth_l3.py +++ b/tests/test_replay_depth_l3.py @@ -774,9 +774,16 @@ async def capture_event(event): "importance": 0.5, } - # Original events show different action than what will be replayed + # Original events show different action than what will be replayed. + # Timestamp must be after the checkpoint timestamp so the post-checkpoint filter passes. 
mock_events = [ - {"id": "evt-2", "sequence": 2, "event_type": "decision", "data": {"chosen_action": "tool_a"}}, + { + "id": "evt-2", + "sequence": 2, + "event_type": "decision", + "timestamp": "2026-03-24T12:00:01Z", + "data": {"chosen_action": "tool_a"}, + }, ] with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get: @@ -785,8 +792,8 @@ def side_effect(url, *args, **kwargs): mock_response = MagicMock() if "checkpoints" in url: mock_response.json.return_value = mock_checkpoint_data - elif "events" in url: - mock_response.json.return_value = {"events": mock_events} + elif "traces" in url: + mock_response.json.return_value = {"traces": mock_events} mock_response.raise_for_status = MagicMock() return mock_response @@ -805,8 +812,9 @@ def side_effect(url, *args, **kwargs): chosen_action="tool_b", # Different from original "tool_a" ) - # Drift event should have been emitted - drift_events = [e for e in emitted_events if getattr(e, "event_type", None) == "drift"] + # Drift event should have been emitted into the context's event store + all_events = await ctx.get_events() + drift_events = [e for e in all_events if getattr(e, "event_type", None) == "drift"] assert len(drift_events) > 0 except (TypeError, ImportError, AttributeError) as e: pytest.skip(f"Drift event emission not yet implemented: {e}")