diff --git a/scripts/visualize_subtasks.py b/scripts/visualize_subtasks.py index c826ec8..644e628 100644 --- a/scripts/visualize_subtasks.py +++ b/scripts/visualize_subtasks.py @@ -47,6 +47,13 @@ {"task_finished", "task_failed", "iteration_finished"} ) +# add_subtask is the only planner tool that creates a root subtask. Each call +# yields the next sequential root id (max_root + 1, per _next_task_id in +# contractor/tools/tasks.py), so chronological order of *successful* calls maps +# directly onto root task_ids "1", "2", "3", … decompose creates children +# directly (no add_subtask call), so those don't enter this counter. +_ADD_SUBTASK_TOOL_NAME = "add_subtask" + _FILENAME_SAFE = re.compile(r"[^a-zA-Z0-9._-]+") @@ -57,6 +64,15 @@ def _sanitize(name: str) -> str: # ─── Data extraction ───────────────────────────────────────────────────────── +@dataclass +class AddSubtaskAction: + """A single successful add_subtask tool call, in chronological order.""" + + title: str + description: str + session_id: str | None + + @dataclass class SubtaskRun: """One reconstructed task run with its subtask pool.""" @@ -65,7 +81,9 @@ class SubtaskRun: task_name: str task_id: Any template_key: str | None + session_id: str | None = None records: list[dict[str, Any]] = field(default_factory=list) + add_actions: list[AddSubtaskAction] = field(default_factory=list) def _load_jsonl(path: Path) -> list[dict[str, Any]]: @@ -101,16 +119,88 @@ def _extract_records(event: dict[str, Any]) -> list[dict[str, Any]]: return [] +def _event_session_id(event: dict[str, Any]) -> str | None: + sid = event.get("session_id") + if sid: + return str(sid) + # task_failed wraps the iteration's session under last_result; task_finished + # always carries session_id at the top level, so this branch is just for the + # one nested case. + nested = event.get("last_result") + if isinstance(nested, dict): + nested_sid = nested.get("session_id") + if nested_sid: + return str(nested_sid) + return None + + +def _extract_add_actions_by_session( + events: Iterable[dict[str, Any]], +) -> dict[str, list[AddSubtaskAction]]: + """Collect successful add_subtask calls, grouped by session_id. + + A call is considered successful when its tool_result event does NOT carry + ``result_error=True``. Calls without a matching result are assumed to have + succeeded (the run may have been cut short before the result was emitted). + """ + # First pass: index tool_result.result_error by tool_call_id so we can + # filter out failed add_subtask calls (e.g. task limit reached). + failed_call_ids: set[str] = set() + for event in events: + if event.get("type") != "tool_result": + continue + if event.get("tool_name") != _ADD_SUBTASK_TOOL_NAME: + continue + if not event.get("result_error"): + continue + call_id = event.get("tool_call_id") + if isinstance(call_id, str): + failed_call_ids.add(call_id) + + by_session: dict[str, list[AddSubtaskAction]] = {} + for event in events: + if event.get("type") != "tool_call": + continue + if event.get("tool_name") != _ADD_SUBTASK_TOOL_NAME: + continue + call_id = event.get("tool_call_id") + if isinstance(call_id, str) and call_id in failed_call_ids: + continue + + args = event.get("arguments") or {} + if not isinstance(args, dict): + continue + title = str(args.get("title") or "").strip() + description = str(args.get("description") or "").strip() + + session_id = event.get("session_id") + bucket_key = str(session_id) if session_id else "" + by_session.setdefault(bucket_key, []).append( + AddSubtaskAction( + title=title, + description=description, + session_id=str(session_id) if session_id else None, + ) + ) + return by_session + + def extract_runs(events: Iterable[dict[str, Any]]) -> list[SubtaskRun]: """Build one SubtaskRun per recorded task attempt with non-empty records. Prefers task_finished/task_failed over iteration_finished for the same (task_name, task_id) so we only render the final, authoritative pool. + Each run is also annotated with the add_subtask tool calls from the same + session, so subtasks added but never executed/decomposed/skipped still + show up in the graph as "new" nodes. """ + # The events iterable may be a generator; materialise so we can scan twice. + events_list = list(events) + by_key: dict[tuple[str, Any], SubtaskRun] = {} priority = {"task_finished": 2, "task_failed": 2, "iteration_finished": 1} - for event in events: + for event in events_list: etype = event.get("type") if etype not in _RECORD_EVENT_TYPES: continue @@ -127,6 +217,7 @@ def extract_runs(events: Iterable[dict[str, Any]]) -> list[SubtaskRun]: task_name=task_name, task_id=task_id, template_key=event.get("template_key"), + session_id=_event_session_id(event), records=records, ) @@ -134,6 +225,11 @@ def extract_runs(events: Iterable[dict[str, Any]]) -> list[SubtaskRun]: if existing is None or priority[etype] >= priority[existing.event_type]: by_key[key] = new_run + add_by_session = _extract_add_actions_by_session(events_list) + for run in by_key.values(): + if run.session_id and run.session_id in add_by_session: + run.add_actions = add_by_session[run.session_id] + return list(by_key.values()) @@ -156,11 +252,17 @@ def _parent_id(task_id: str) -> str | None: return task_id.rsplit(".", 1)[0] -def build_tree(records: list[dict[str, Any]]) -> tuple[dict[str, Node], list[str]]: +def build_tree( + records: list[dict[str, Any]], + add_actions: list[AddSubtaskAction] | None = None, +) -> tuple[dict[str, Node], list[str]]: """Build node table + ordered list of roots from a record pool. Records share their schema with Subtask.model_dump() plus execution - result fields; we only need task_id/title/status here. + result fields; we only need task_id/title/status here. ``add_actions`` + (chronological successful add_subtask tool calls) backfill any root + subtask that was added but never reached a terminal record — those + appear as "new" nodes. """ nodes: dict[str, Node] = {} @@ -178,6 +280,24 @@ def build_tree(records: list[dict[str, Any]]) -> tuple[dict[str, Node], list[str status=status, ) + # Backfill "new" roots from add_subtask calls. Each successful call + # claims max(existing_roots) + 1 (see _next_task_id), so the chronological + # call order maps directly onto sequential root ids starting at 1. + if add_actions: + for idx, action in enumerate(add_actions, start=1): + tid = str(idx) + if tid in nodes: + # Records already capture this root's final state; keep it but + # restore the original title if the record dropped it. + if not nodes[tid].title and action.title: + nodes[tid].title = action.title + continue + nodes[tid] = Node( + task_id=tid, + title=action.title, + status="new", + ) + # Synthesize any missing ancestors as 'unknown' so the tree is connected. for tid in list(nodes): cur = tid @@ -411,7 +531,7 @@ def main() -> None: seen: dict[str, int] = {} for run in runs: - nodes, roots = build_tree(run.records) + nodes, roots = build_tree(run.records, run.add_actions) if not nodes: continue diff --git a/tests/units/contractor_tests/test_visualize_subtasks.py b/tests/units/contractor_tests/test_visualize_subtasks.py index 3022a9d..493eaed 100644 --- a/tests/units/contractor_tests/test_visualize_subtasks.py +++ b/tests/units/contractor_tests/test_visualize_subtasks.py @@ -115,6 +115,105 @@ def test_build_tree_sorts_children_numerically(vs): assert nodes["1"].children == ["1.1", "1.2", "1.10"] +def test_extract_runs_collects_add_subtask_calls_by_session(vs): + events = [ + { + "type": "tool_call", + "tool_name": "add_subtask", + "session_id": "sess-A", + "tool_call_id": "c1", + "arguments": {"title": "first", "description": "do thing 1"}, + }, + { + "type": "tool_call", + "tool_name": "add_subtask", + "session_id": "sess-A", + "tool_call_id": "c2", + "arguments": {"title": "second", "description": "do thing 2"}, + }, + { + "type": "tool_call", + "tool_name": "add_subtask", + "session_id": "sess-other", + "tool_call_id": "c3", + "arguments": {"title": "stray", "description": "unrelated"}, + }, + { + "type": "task_finished", + "task_name": "build", + "task_id": 0, + "session_id": "sess-A", + "records": [{"task_id": "1", "title": "first", "status": "done"}], + }, + ] + runs = vs.extract_runs(events) + assert len(runs) == 1 + assert [a.title for a in runs[0].add_actions] == ["first", "second"] + + +def test_extract_runs_drops_add_subtask_calls_that_errored(vs): + events = [ + { + "type": "tool_call", + "tool_name": "add_subtask", + "session_id": "sess", + "tool_call_id": "c1", + "arguments": {"title": "ok", "description": "ok"}, + }, + { + "type": "tool_call", + "tool_name": "add_subtask", + "session_id": "sess", + "tool_call_id": "c2", + "arguments": {"title": "limit-hit", "description": "boom"}, + }, + { + "type": "tool_result", + "tool_name": "add_subtask", + "session_id": "sess", + "tool_call_id": "c2", + "result_error": True, + }, + { + "type": "task_finished", + "task_name": "x", + "task_id": 0, + "session_id": "sess", + "records": [{"task_id": "1", "title": "ok", "status": "done"}], + }, + ] + runs = vs.extract_runs(events) + assert [a.title for a in runs[0].add_actions] == ["ok"] + + +def test_build_tree_backfills_new_subtask_from_add_action(vs): + records = [{"task_id": "1", "title": "executed", "status": "done"}] + adds = [ + vs.AddSubtaskAction(title="executed", description="", session_id="s"), + vs.AddSubtaskAction(title="never ran", description="", session_id="s"), + ] + nodes, roots = vs.build_tree(records, adds) + assert roots == ["1", "2"] + assert nodes["2"].status == "new" + assert nodes["2"].title == "never ran" + # Existing record's status must not be overwritten. + assert nodes["1"].status == "done" + + +def test_build_tree_ignores_add_actions_when_record_count_matches(vs): + records = [ + {"task_id": "1", "title": "a", "status": "done"}, + {"task_id": "2", "title": "b", "status": "done"}, + ] + adds = [ + vs.AddSubtaskAction(title="a", description="", session_id="s"), + vs.AddSubtaskAction(title="b", description="", session_id="s"), + ] + nodes, roots = vs.build_tree(records, adds) + assert roots == ["1", "2"] + assert all(n.status == "done" for n in nodes.values()) + + def test_render_tree_writes_png(vs, tmp_path: Path): records = [ {"task_id": "1", "title": "do thing", "status": "done"},