Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions agent_debugger_sdk/core/context/trace_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from agent_debugger_sdk.core.emitter import EventBufferLike, EventEmitter
from agent_debugger_sdk.core.events import (
Checkpoint,
DriftDetectedEvent,
EventType,
Session,
SessionStatus,
Expand Down Expand Up @@ -188,6 +189,7 @@ async def restore(
ctx._restored_state = restored_state
ctx.replayed_events: list[dict[str, Any]] = []
ctx._drift_detector = None
ctx._drift_decision_index: int = 0
ctx._hook_errors: list[Exception] = []
ctx._restored_target: Any = None

Expand Down Expand Up @@ -274,9 +276,13 @@ async def restore(
if e.get("importance", 1.0) >= importance_threshold
]

# Seed drift detector with post-checkpoint events as baseline
# Seed drift detector with post-checkpoint decision events as baseline.
# Filtering to decisions aligns the decision-local index used in
# record_decision with the correct original event at each position.
if ctx._drift_detector is not None:
ctx._drift_detector.original_events = post_events.copy()
ctx._drift_detector.original_events = [
e for e in post_events if e.get("event_type") == "decision"
]

# Replay each event, honouring cancellation
for event in post_events:
Expand Down Expand Up @@ -571,6 +577,45 @@ def _check_entered(self) -> None:
"TraceContext has not been entered. Use 'async with TraceContext(...) as ctx:' to enter the context."
)

async def record_decision(
    self,
    reasoning: str,
    confidence: float,
    chosen_action: str,
    evidence: list[dict[str, Any]] | None = None,
    **kwargs: Any,
) -> str:
    """Record a decision and, when a drift detector is attached, check it for drift.

    The decision is always recorded through the parent implementation first.
    If ``self._drift_detector`` is set, the decision is then compared against
    the detector's baseline at the current decision position, and a
    ``DriftDetectedEvent`` is emitted when the detector reports a difference.

    Args:
        reasoning: Reasoning text forwarded to the parent ``record_decision``.
        confidence: Confidence forwarded unchanged to the parent; a copy
            clamped to ``[0.0, 1.0]`` is used for the drift comparison.
        chosen_action: Action forwarded to the parent and compared for drift.
        evidence: Optional evidence list forwarded to the parent.
        **kwargs: Extra keyword arguments forwarded to the parent.

    Returns:
        The event id returned by the parent ``record_decision``.
    """
    recorded_id = await super().record_decision(
        reasoning=reasoning,
        confidence=confidence,
        chosen_action=chosen_action,
        evidence=evidence,
        **kwargs,
    )

    detector = getattr(self, "_drift_detector", None)
    if detector is None:
        # No detector attached: nothing to compare against.
        return recorded_id

    # Decision-local position. The restore path seeds the detector baseline
    # with decision events only, so this counter lines up with
    # ``original_events`` positions.
    position = getattr(self, "_drift_decision_index", 0)
    self._drift_decision_index = position + 1

    # Clamp to [0.0, 1.0] so the compared value matches what
    # RecordingMixin.record_decision() persists; otherwise out-of-range
    # inputs would be reported as false confidence drift.
    replayed_decision = {
        "event_type": "decision",
        "data": {
            "chosen_action": chosen_action,
            "confidence": max(0.0, min(1.0, confidence)),
        },
    }

    mismatch = detector.compare(replayed_decision, position)
    if mismatch is None:
        return recorded_id

    await self._emit_event(
        DriftDetectedEvent(
            session_id=self.session_id,
            parent_id=self.get_current_parent(),
            description=mismatch.description,
            original_value=str(mismatch.original_value),
            restored_value=str(mismatch.restored_value),
            drift_event_type=mismatch.event_type,
            drift_index=mismatch.index,
            severity=mismatch.severity.value,
        )
    )
    return recorded_id

async def _emit_event(self, event: TraceEvent) -> None:
    """Forward ``event`` to this context's emitter and wait for it to be emitted.

    Args:
        event: The trace event handed to ``self._emitter.emit``.
    """
    emitter = self._emitter
    await emitter.emit(event)
3 changes: 3 additions & 0 deletions agent_debugger_sdk/core/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from agent_debugger_sdk.core.events.checkpoint import Checkpoint
from agent_debugger_sdk.core.events.decisions import DecisionEvent
from agent_debugger_sdk.core.events.drift import DriftDetectedEvent
from agent_debugger_sdk.core.events.errors import ErrorEvent
from agent_debugger_sdk.core.events.llm import LLMRequestEvent, LLMResponseEvent
from agent_debugger_sdk.core.events.registry import (
Expand Down Expand Up @@ -54,6 +55,7 @@
EventType.BEHAVIOR_ALERT: BehaviorAlertEvent,
EventType.ERROR: ErrorEvent,
EventType.REPAIR_ATTEMPT: RepairAttemptEvent,
EventType.DRIFT: DriftDetectedEvent,
}
)

Expand Down Expand Up @@ -81,6 +83,7 @@
"ErrorEvent",
"RepairAttemptEvent",
"RepairOutcome",
"DriftDetectedEvent",
"Session",
"Checkpoint",
"EVENT_TYPE_REGISTRY",
Expand Down
1 change: 1 addition & 0 deletions agent_debugger_sdk/core/events/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class EventType(StrEnum):
AGENT_TURN = "agent_turn"
BEHAVIOR_ALERT = "behavior_alert"
REPAIR_ATTEMPT = "repair_attempt"
DRIFT = "drift"


class SessionStatus(StrEnum):
Expand Down
30 changes: 30 additions & 0 deletions agent_debugger_sdk/core/events/drift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Drift detection event emitted when replay diverges from original execution."""

from dataclasses import dataclass

from .base import EventType, TraceEvent

__all__ = ["DriftDetectedEvent"]


@dataclass(kw_only=True)
class DriftDetectedEvent(TraceEvent):
    """Event emitted when a replayed decision diverges from the original execution.

    All fields are keyword-only (``kw_only=True``) and have defaults, so an
    instance can be built from whichever drift details are available.

    Attributes:
        event_type: Defaults to EventType.DRIFT.
        description: Human-readable summary of what drifted.
        original_value: The value from the original execution, as a string.
        restored_value: The value from the restored/replayed execution, as a string.
        drift_event_type: The event type where drift was detected (e.g. "decision").
        drift_index: Index reported by the drift detector for the drifted event.
            NOTE(review): the replay path seeds the detector baseline with
            decision events only, so this presumably indexes that decision-only
            baseline rather than the full original event sequence -- confirm.
        severity: "warning" or "critical"; defaults to "warning".
    """

    event_type: EventType = EventType.DRIFT
    description: str = ""
    original_value: str = ""
    restored_value: str = ""
    drift_event_type: str = ""
    drift_index: int = 0
    severity: str = "warning"
4 changes: 2 additions & 2 deletions agent_debugger_sdk/core/recorders.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ async def record_decision(
self,
reasoning: str,
confidence: float,
evidence: list[dict[str, Any]],
chosen_action: str,
evidence: list[dict[str, Any]] | None = None,
evidence_event_ids: list[str] | None = None,
upstream_event_ids: list[str] | None = None,
alternatives: list[dict[str, Any]] | None = None,
Expand All @@ -114,7 +114,7 @@ async def record_decision(
name=name,
reasoning=reasoning,
confidence=max(0.0, min(1.0, confidence)),
evidence=evidence,
evidence=evidence or [],
evidence_event_ids=evidence_event_ids or [],
alternatives=alternatives or [],
chosen_action=chosen_action,
Expand Down
20 changes: 14 additions & 6 deletions tests/test_replay_depth_l3.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,9 +774,16 @@ async def capture_event(event):
"importance": 0.5,
}

# Original events show different action than what will be replayed
# Original events show different action than what will be replayed.
# Timestamp must be after the checkpoint timestamp so the post-checkpoint filter passes.
mock_events = [
{"id": "evt-2", "sequence": 2, "event_type": "decision", "data": {"chosen_action": "tool_a"}},
{
"id": "evt-2",
"sequence": 2,
"event_type": "decision",
"timestamp": "2026-03-24T12:00:01Z",
"data": {"chosen_action": "tool_a"},
},
]

with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
Expand All @@ -785,8 +792,8 @@ def side_effect(url, *args, **kwargs):
mock_response = MagicMock()
if "checkpoints" in url:
mock_response.json.return_value = mock_checkpoint_data
elif "events" in url:
mock_response.json.return_value = {"events": mock_events}
elif "traces" in url:
mock_response.json.return_value = {"traces": mock_events}
mock_response.raise_for_status = MagicMock()
return mock_response

Expand All @@ -805,8 +812,9 @@ def side_effect(url, *args, **kwargs):
chosen_action="tool_b", # Different from original "tool_a"
)

# Drift event should have been emitted
drift_events = [e for e in emitted_events if getattr(e, "event_type", None) == "drift"]
# Drift event should have been emitted into the context's event store
all_events = await ctx.get_events()
drift_events = [e for e in all_events if getattr(e, "event_type", None) == "drift"]
assert len(drift_events) > 0
except (TypeError, ImportError, AttributeError) as e:
pytest.skip(f"Drift event emission not yet implemented: {e}")
Expand Down