Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 75 additions & 14 deletions .map/scripts/map_step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,29 @@ def _int(key: str) -> int:
}


def _coerce_token_int(value: object) -> int:
"""Best-effort int from a token field that may be int / float / str / None."""
if isinstance(value, bool):
return 0
if isinstance(value, (int, float)):
return int(value)
if isinstance(value, str):
try:
return int(value)
except ValueError:
return 0
return 0


def _usage_token_total(usage: Mapping[str, object]) -> int:
"""Sum of the four token fields for one usage record.

Used to pick the most complete copy of a turn when the transcript repeats a
msg_id with diverging usage (a streaming partial vs the final line).
"""
return sum(_coerce_token_int(usage.get(field, 0)) for field in _TOKEN_FIELDS)


def _iter_new_usage(
transcript_path: Path, seen_ids: set[str], start_offset: int = 0
) -> tuple[list[dict[str, object]], int]:
Expand All @@ -601,10 +624,17 @@ def _iter_new_usage(
JSONL) so a repeatedly-firing Stop/SubagentStop hook does not re-parse the
whole multi-MB file each turn. Returns ``(usages, new_offset)`` where
``new_offset`` advances only past the last COMPLETE line — a partial line
from a concurrent append is left for the next call. ``msg_id`` dedup against
``seen_ids`` is kept as a safety net (e.g. if the file is rotated and the
offset resets). Entries with an empty msg_id or malformed JSON are skipped;
a missing/unreadable transcript returns ``([], start_offset)``.
from a concurrent append is left for the next call.

A single assistant turn is written to the transcript as SEVERAL JSONL lines
(one per content / tool_use block) that all share the same ``message.id``
and the same cumulative ``usage``. Results are deduped by msg_id WITHIN this
read window — keeping the copy with the most total tokens — so a turn's
usage is logged exactly once; without it est_cost roughly doubles. The
persisted ``seen_ids`` is the cross-call safety net (e.g. if the file is
rotated and the offset resets, or a turn straddles two windows). Entries
with an empty msg_id or malformed JSON are skipped; a missing/unreadable
transcript returns ``([], start_offset)``.
"""
path = Path(transcript_path)
try:
Expand All @@ -629,7 +659,8 @@ def _iter_new_usage(
complete = chunk[: last_newline + 1]
new_offset = offset + len(complete)

out: list[dict[str, object]] = []
by_mid: dict[str, dict[str, object]] = {}
order: list[str] = []
for raw in complete.decode("utf-8", errors="replace").splitlines():
raw = raw.strip()
if not raw:
Expand All @@ -644,8 +675,14 @@ def _iter_new_usage(
mid = str(usage["msg_id"])
if not mid or mid in seen_ids:
continue
out.append(usage)
return out, new_offset
prev = by_mid.get(mid)
if prev is None:
order.append(mid)
by_mid[mid] = usage
elif _usage_token_total(usage) > _usage_token_total(prev):
# Same turn repeated in this window — keep the most complete copy.
by_mid[mid] = usage
return [by_mid[mid] for mid in order], new_offset


def _token_meter_cache_path(branch_name: str) -> Path:
Expand Down Expand Up @@ -789,7 +826,11 @@ def _rebuild_token_accounting(branch: Optional[str] = None) -> dict[str, object]

Groups by subtask, agent, and phase, plus an aggregate carrying
``cache_hit_ratio`` (cache_read / (input + cache_read)) and
``est_cost_usd``. Returns the written payload.
``est_cost_usd``. Rows are deduped by msg_id (keeping the most complete
copy) before rollup, so a log written by an older runner — one assistant
turn split across several rows — still produces a correct total instead of
a doubled one. ``event_count`` is therefore the number of distinct turns.
Returns the written payload.
"""
branch_name = _sanitize_branch(branch) if branch else get_branch_name()
log_path = get_branch_dir(branch_name) / TOKEN_LOG_NAME
Expand All @@ -805,6 +846,14 @@ def _rebuild_token_accounting(branch: Optional[str] = None) -> dict[str, object]
lines = log_path.read_text(encoding="utf-8").splitlines()
except (OSError, UnicodeDecodeError):
lines = []
# One assistant turn can occupy several token_log rows (Claude Code
# writes one JSONL line per content/tool_use block, all sharing a
# msg_id). Logs written before the write-time dedup landed still hold
# those repeats, so collapse by msg_id here too — keep the row with the
# most total tokens (the figure the API bills) — and stay correct.
deduped: dict[str, dict[str, object]] = {}
order: list[str] = []
anon = 0
for raw in lines:
raw = raw.strip()
if not raw:
Expand All @@ -815,14 +864,26 @@ def _rebuild_token_accounting(branch: Optional[str] = None) -> dict[str, object]
continue
if not isinstance(row, dict):
continue
mid = str(row.get("msg_id") or "")
if not mid:
key = f"__anon_{anon}"
anon += 1
else:
key = mid
prev = deduped.get(key)
if prev is None:
order.append(key)
deduped[key] = row
elif _usage_token_total(row) > _usage_token_total(prev):
deduped[key] = row

for key in order:
row = deduped[key]
event_count += 1
model = str(row.get("model") or "")
usage: dict[str, int] = {}
for field in _TOKEN_FIELDS:
try:
usage[field] = int(row.get(field, 0) or 0)
except (TypeError, ValueError):
usage[field] = 0
usage: dict[str, int] = {
field: _coerce_token_int(row.get(field, 0)) for field in _TOKEN_FIELDS
}
row_cost = _token_cost(usage, model)
total_cost += row_cost
for dim_key, dim in (
Expand Down
Loading
Loading