Open
20 commits
9ab6e83
perf(metrics): batch tokenization with defer-to-flush drain
viraatc Jun 9, 2026
b40f72f
fix(metrics): never skip finalize on tokenizer drain failure; cleanup
viraatc Jun 10, 2026
82a12bc
fix(metrics): tokenizer stage uses the whole machine; restore --token…
viraatc Jun 10, 2026
8033c47
chore(metrics): use pass bodies in TokenCounter protocol stubs
viraatc Jun 10, 2026
47c4f35
docs(metrics): add metrics-aggregator design doc; refresh services ov…
viraatc Jun 10, 2026
1315a73
fix(metrics): publish live snapshots through tokenizer failures; boun…
viraatc Jun 10, 2026
6d227bf
feat(metrics): no silent tokenizer fallbacks — shard or exit cleanly
viraatc Jun 10, 2026
0cca84a
fix(metrics): shard unpinned on platforms without CPU affinity (macOS)
viraatc Jun 10, 2026
443a923
refactor(metrics): queue-owned live flush lane; drop the pre_publish …
viraatc Jun 10, 2026
aed6b78
fix(metrics): bound live flushes; align defaults; audit-driven hardening
viraatc Jun 10, 2026
700423e
fix(metrics): call-shaped awaits for cancelled tasks; pin aggregator-…
viraatc Jun 10, 2026
9640bd7
fix(metrics): requeue messages on live-cancel; shrink the tokenizer API
viraatc Jun 10, 2026
a1b9386
chore(metrics): public read-only wiring surface on the aggregator
viraatc Jun 11, 2026
34e0c48
chore(metrics): drain-timeout default back to 60s (review feedback)
viraatc Jun 11, 2026
033d724
chore(metrics): drop dead local in the live-cancel handler
viraatc Jun 11, 2026
6b70433
refactor(metrics): single-source service defaults in schema; tighten …
viraatc Jun 11, 2026
6361768
chore(metrics): restore original comments; keep only default-related …
viraatc Jun 11, 2026
8321d83
fix(metrics): raise metrics drain-timeout default to 300s
viraatc Jun 11, 2026
f1ac948
test(metrics): pass now-required aggregator args in signal-handling test
viraatc Jun 16, 2026
70b39d9
chore(deps): bump aiohttp 3.14.0 -> 3.14.1 (fixes 8 CVEs)
viraatc Jun 16, 2026
4 changes: 2 additions & 2 deletions AGENTS.md
@@ -115,7 +115,7 @@ The aggregator is a separate process (`python -m inference_endpoint.async_utils.

- **Series storage**: each `SeriesSampler` keeps three parallel views: O(1) cheap rollups (count/total/min/max/sum_sq, exact), an HDR Histogram (cheap live percentiles), and an in-memory `array.array` of raw values (for exact percentiles in the `COMPLETE` snapshot). Hot path is `registry.record(name, value)` — no allocation, no I/O.
- **Counter API**: `registry.increment(name, delta=1)` for sample-event counters. `registry.set_counter(name, value)` only for the two duration counters (`total_duration_ns` max-of-elapsed, `tracked_duration_ns` sum-of-blocks).
- **Lifecycle**: `INITIALIZE` (constructed, awaiting first `STARTED`) → `LIVE` (run in progress, ticking every `--publish-interval` seconds) → `DRAINING` (set on `ENDED`; tick continues; bounded by `--drain-timeout` budget, default 60 s) → terminal: `COMPLETE` (clean end via `publish_final`, exact stats) **or** `INTERRUPTED` (signal-handler-triggered final via SIGTERM/SIGINT; best-effort partial stats). Drain timeout detected by consumers as `state == COMPLETE and n_pending_tasks > 0`; interrupted runs are detected as `state == INTERRUPTED` directly.
- **Lifecycle**: `INITIALIZE` (constructed, awaiting first `STARTED`) → `LIVE` (run in progress, ticking every `--publish-interval` seconds) → `DRAINING` (set on `ENDED`; tick continues; bounded by the `--drain-timeout` budget — schema default 300 s) → terminal: `COMPLETE` (clean end via `publish_final`, exact stats) **or** `INTERRUPTED` (signal-handler-triggered final via SIGTERM/SIGINT; best-effort partial stats). Drain timeout detected by consumers as `state == COMPLETE and n_pending_tasks > 0`; interrupted runs are detected as `state == INTERRUPTED` directly.
- **Final delivery is dual-path with separated concerns**: `publish_final` atomically writes `final_snapshot.json` (`tmp + fsync(file) + rename + fsync(parent_dir)`) — this is the **primary** Report source — AND emits the terminal-state snapshot over pub/sub as a TUI shutdown signal. Each path is wrapped in its own try/except so one failure cannot suppress the other. Main process consumer reads `final_snapshot.json` (via `json.loads` to dict, no Struct decode); falls back to the subscriber's `latest` live snapshot only if the file is missing (e.g. SIGKILL / OOM before the signal handler ran). The dict form is the canonical consumer contract (see `snapshot_to_dict`).
- **Histogram bucket edges are dynamic per snapshot**: log-spaced over the observed `[min, max]`. Bucket count is fixed at construction; consumers MUST re-render from the snapshot's `(lo, hi, count)` triples each frame and MUST NOT track bucket-by-index across snapshots.
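
The `tmp + fsync(file) + rename + fsync(parent_dir)` sequence named in the final-delivery bullet can be sketched as below. This is a minimal illustration under stated assumptions, not the PR's `publish_final`: the helper name `write_json_atomically` is hypothetical, and the directory fsync assumes a POSIX platform.

```python
import json
import os
import tempfile
from pathlib import Path


def write_json_atomically(path: Path, payload: dict) -> None:
    """Hypothetical helper: readers see the old file or the new one, never a partial write."""
    parent = path.parent
    # The temp file lives in the target directory so the rename stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=parent, prefix=path.name + ".", suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as tmp:
            json.dump(payload, tmp)
            tmp.flush()
            os.fsync(tmp.fileno())  # fsync(file): data reaches disk before the rename
        os.replace(tmp_name, path)  # rename: atomic on POSIX filesystems
    except BaseException:
        # Best-effort cleanup of the temp file if anything failed before the rename.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
    # fsync(parent_dir): makes the rename itself durable (POSIX only).
    dir_fd = os.open(parent, os.O_RDONLY)
    try:
        os.fsync(dir_fd)
    finally:
        os.close(dir_fd)
```

A consumer can then read the file with `json.loads(path.read_text())`, matching the dict-form contract described above.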

@@ -204,7 +204,7 @@ src/inference_endpoint/
│ │ ├── publisher.py # MetricsPublisher (tick task + atomic disk fallback)
│ │ ├── subscriber.py # MetricsSnapshotSubscriber (latest + COMPLETE snapshot capture)
│ │ ├── metrics_table.py # In-flight sample rows + trigger dispatch (TTFT/TPOT/ISL/OSL)
│ │ └── token_metrics.py # TokenizePool (HF tokenizer thread pool for ISL/OSL/TPOT)
│ │ └── token_metrics.py # BatchTokenizer (live thread lane + drain-only sharded pool) + TokenBatchQueue (defer-to-flush buffer, owns the live flush loop) for ISL/OSL/TPOT
│ └── transport/ # ZMQ-based IPC transport layer
│ ├── protocol.py # Transport protocols + TransportConfig + MessageCodec[T]
│ └── zmq/ # ZMQ implementation (context, pubsub, transport, ZMQTransportConfig)
6 changes: 3 additions & 3 deletions docs/async_utils/services/DESIGN.md
@@ -306,9 +306,9 @@ stateDiagram-v2

### 6.2 Metrics aggregator

- **Role**: Subscribes to EventRecords and derives real-time metrics (e.g. TTFT, sample latency, token counts). May use a tokenizer pool for token-based metrics. Shuts down on **session.ended**.
- **Outputs**: Planned is to push real time metrics to Prometheus via PushGateway. Currently, logging / writing final report to JSON is sufficient legacy behavior.
- **Process**: Run as a **subprocess**; given `--metrics-dir`, `--socket-dir`, `--socket-name`, and optional tokenizer options. Uses a dedicated event loop and `ManagedZMQContext.scoped(socket_dir=...)` so it can connect to the publisher's IPC address.
- **Role**: Subscribes to EventRecords and derives real-time metrics (e.g. TTFT, sample latency, token counts). Token metrics (ISL/OSL/TPOT) are computed by a batched tokenizer (in-process threads live; process-sharded end-of-run drain) — see [metrics_aggregator/DESIGN.md](metrics_aggregator/DESIGN.md). Shuts down on **session.ended**.
- **Outputs**: Live `MetricsSnapshot` frames over an IPC PUB socket, and an atomically written `final_snapshot.json` (the primary Report source). Pushing real-time metrics to Prometheus via PushGateway is planned.
- **Process**: Run as a **subprocess**; given `--metrics-output-dir`, `--socket-dir`, `--socket-name`, `--metrics-socket`, and optional tokenizer options. Uses a dedicated event loop and `ManagedZMQContext.scoped(socket_dir=...)` so it can connect to the publisher's IPC address.

---

119 changes: 119 additions & 0 deletions docs/async_utils/services/metrics_aggregator/DESIGN.md
@@ -0,0 +1,119 @@
# Metrics Aggregator Service — Design

The metrics aggregator is a subprocess (`python -m
inference_endpoint.async_utils.services.metrics_aggregator`) that subscribes
to the EventRecord stream, folds per-sample events into a `MetricsRegistry`,
and publishes `MetricsSnapshot` frames over IPC PUB at a fixed cadence. At
end-of-run it atomically writes `final_snapshot.json` — the **primary** source
for `Report`; the terminal pub/sub frame is only a TUI "run finished" signal.

## Lifecycle

```
INITIALIZE ──STARTED──► LIVE ──ENDED──► DRAINING ──► COMPLETE
└──► INTERRUPTED (SIGTERM)
```

The ENDED path runs inside a finalization boundary: whatever the drain does —
finish, time out, or fail — `publish_final` and the shutdown signal always
run. A tokenizer failure can degrade the snapshot (see the `n_pending_tasks`
contract) but can never hang the subprocess. SIGTERM writes a best-effort
partial snapshot tagged `INTERRUPTED`.
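
The finalization boundary described above can be pictured as a `try`/`finally` around the drain. The sketch below is an assumption about the shape, not the aggregator's code: `finalize_after_drain` and its callable parameters are hypothetical names, and `timeout_s=None` stands for an unlimited budget.

```python
import asyncio
from typing import Awaitable, Callable, Optional


async def finalize_after_drain(
    drain: Callable[[], Awaitable[None]],
    publish_final: Callable[[], Awaitable[None]],
    shutdown_event: asyncio.Event,
    timeout_s: Optional[float],
) -> str:
    """Run the drain, then always publish the final snapshot and signal shutdown."""
    outcome = "clean"
    try:
        await asyncio.wait_for(drain(), timeout=timeout_s)
    except asyncio.TimeoutError:
        outcome = "timed_out"  # budget exhausted; snapshot is degraded, not skipped
    except Exception:
        outcome = "failed"  # tokenizer failure; snapshot is degraded, not skipped
    finally:
        try:
            await publish_final()
        finally:
            # Set even if publishing raised, so the subprocess cannot hang.
            shutdown_event.set()
    return outcome
```

Whatever the outcome, the final publish and the shutdown signal both run, which is the property the paragraph above relies on.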

## Token metrics pipeline

ISL/OSL/TPOT require tokenizer passes per completed sample; at high completion
rates a per-event dispatch model accumulates an unbounded backlog. The
pipeline batches instead: **defer-to-flush** + **process-sharded encoding**.

### Defer-to-flush (`TokenBatchQueue`)

Triggers do no work at event time — `fire()` appends `(text, on_count)` to a
buffer, O(1), no tasks. The buffer is cleared at two points:

1. **Live loop** — `start_live(interval)` flushes periodically through the
tokenizer's in-process lane: `--tokenizer-workers` threads, rayon capped
to the same width, at most `_LIVE_FLUSH_MAX_ITEMS` per flush. Never
touches the shard processes. `0` disables mid-run tokenization. Failed or
cancelled live items are **re-queued** — the drain retries them.
2. **End-of-run** — `flush_remaining(timeout)` stops the live loop and drains
everything left through every shard, bounded by the drain budget.

`flush()` serializes under an asyncio lock and detaches the buffer up front.
The text and chat-template phases fail independently; a raising recorder is
logged without aborting the batch. Drain failures are terminal — items stay
counted in `pending`. `flush_remaining` never raises.
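
A simplified, synchronous sketch of the defer-to-flush idea follows. It is not the PR's `TokenBatchQueue`: the class name `DeferredTokenQueue`, the `count_batch` callable, and the `requeue_on_failure` flag are hypothetical, the asyncio lock and the separate chat-template phase are omitted, and counting an item as handled when its recorder raises is an assumption.

```python
from typing import Callable, List, Optional, Tuple

Item = Tuple[str, Callable[[int], None]]


class DeferredTokenQueue:
    """Hypothetical stand-in: enqueue is O(1), tokenization happens only at flush."""

    def __init__(self, count_batch: Callable[[List[str]], List[int]]) -> None:
        self._count_batch = count_batch
        self._buffer: List[Item] = []
        self.pending = 0  # enqueued but not yet recorded

    def enqueue(self, text: str, on_count: Callable[[int], None]) -> None:
        self._buffer.append((text, on_count))
        self.pending += 1

    def flush(self, max_items: Optional[int] = None, requeue_on_failure: bool = True) -> int:
        # Detach the batch up front so new enqueues never touch it.
        if max_items is None:
            batch, self._buffer = self._buffer, []
        else:
            batch, self._buffer = self._buffer[:max_items], self._buffer[max_items:]
        if not batch:
            return 0
        try:
            counts = self._count_batch([text for text, _ in batch])
        except Exception:
            if requeue_on_failure:
                self._buffer[:0] = batch  # live flush: a later flush retries these
            return 0  # drain: items stay counted in `pending`
        for (_, on_count), n in zip(batch, counts):
            try:
                on_count(n)
            except Exception:
                pass  # a raising recorder must not abort the batch
            self.pending -= 1  # assumption: handled even if the recorder raised
        return len(batch)
```

A live flush would call `flush(max_items=...)` with re-queueing enabled, while the end-of-run drain would call `flush(requeue_on_failure=False)` so that failures leave `pending` above zero.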

### Sharded batch encoding (`BatchTokenizer`)

The drain fans the whole buffer out across worker **processes**, one pinned
per `CORES_PER_WORKER` (8) core block. Each worker runs the raw `tokenizers`
backend's `encode_batch_fast` (Rust, rayon); a single BPE rayon pool
saturates ~8 cores, so disjoint pinned blocks are how the whole machine is
used. Workers are spawn-context, warmed in parallel at construction (bounded
— a hung load is a startup error), and ignore SIGINT.

Collaborator commented:

If we ignore SIGINT, there is no way to terminate/interrupt?


The shard pool has no knob: it auto-sizes to one shard per 8-core block of
the allowed CPU universe. There is no fallback: a missing fast Rust backend,
or a failed or over-budget warmup, is a startup error, because an in-process
slow path cannot keep up and would surface much later as an incomplete drain.
Platforms without an affinity API (macOS) shard unpinned; each worker caps
its rayon pool to the block size instead.
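
The block math can be illustrated with a short sketch. The helper names `allowed_cpus` and `plan_shard_blocks` are hypothetical, and two behaviours are assumptions rather than facts from the PR: a trailing partial block is dropped, and a machine with fewer than eight CPUs still gets one shard. The sketch reads the current affinity mask only; it does not reproduce the expand-then-restore probe that `_setup_shards` is described as doing.

```python
import os
from typing import Iterable, List

CORES_PER_WORKER = 8  # block size named in the text


def allowed_cpus() -> List[int]:
    """CPUs this process may run on; falls back to a plain count without an affinity API."""
    if hasattr(os, "sched_getaffinity"):  # Linux; absent on macOS
        return sorted(os.sched_getaffinity(0))
    return list(range(os.cpu_count() or 1))


def plan_shard_blocks(
    cpus: Iterable[int], cores_per_worker: int = CORES_PER_WORKER
) -> List[List[int]]:
    """Split the CPU set into disjoint blocks, one per shard worker."""
    ordered = sorted(cpus)
    full_blocks = len(ordered) // cores_per_worker
    if full_blocks == 0:
        return [ordered]  # assumption: small machines still get a single shard
    return [
        ordered[i * cores_per_worker : (i + 1) * cores_per_worker]
        for i in range(full_blocks)
    ]
```

On a platform with an affinity API, each block would be the pin set for one worker; without one, the block length would only serve as the per-worker rayon cap.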

Chat-template items (tool calls) run on the in-process thread lane —
`apply_chat_template` is Python/Jinja; sharding buys nothing.

### CPU affinity: tokenize is post-run

The parent pins itself to the loadgen cores and children inherit that narrow
mask. `_setup_shards` probes the full allowed universe via
`expand_to_all_online_cpus()` (cgroup/Slurm-clamped) for the block math,
**then restores the inherited mask** — the aggregator stays where the parent
placed it; only the drain-phase shard children span the machine, and they
are idle until `ENDED`.

### The `n_pending_tasks` contract

`TokenBatchQueue.pending` (enqueued-but-not-recorded) is surfaced on every
snapshot as `n_pending_tasks`. In the final snapshot:

- `state == complete && n_pending_tasks == 0` — clean run, exact series.
- `state == complete && n_pending_tasks > 0` — **incomplete drain** (budget
exhausted or tokenizer failed); `Report` renders a warning. Failed items
are deliberately not removed from the count — under-reporting would
rebadge an incomplete drain as clean.
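
A consumer-side check of this contract might look like the sketch below. It assumes the snapshot has already been loaded as a dict with `state` and `n_pending_tasks` keys; the function name `classify_final_snapshot` and the returned labels are hypothetical.

```python
def classify_final_snapshot(snapshot: dict) -> str:
    """Map a final snapshot dict onto the cases described above."""
    state = str(snapshot.get("state", "")).lower()
    pending = int(snapshot.get("n_pending_tasks", 0))
    if state == "interrupted":
        return "interrupted"  # best-effort partial stats
    if state == "complete":
        # Pending items mean the drain ran out of budget or the tokenizer failed.
        return "clean" if pending == 0 else "incomplete_drain"
    return "not_final"  # e.g. a live snapshot used as a fallback
```

A report renderer could emit its warning whenever the result is `"incomplete_drain"`.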

### Data flow

```
COMPLETE event ─► trigger.fire ─► queue.enqueue(text, on_count) [O(1)]
live loop (publish cadence) ─ flush(live) ─► in-process threads (rayon-capped)
ENDED drain (budgeted) ────── flush() ─────► chunks ─► N pinned worker procs
└─► on_count(n) ─► registry.record()
```

## CLI

| Flag | Default | Purpose |
| -------------------------------- | ----------------- | --------------------------------------------------- |
| `--socket-dir` / `--socket-name` | required | EventRecord SUB socket |
| `--metrics-socket` | required | Snapshot PUB socket name |
| `--metrics-output-dir` | required | Directory for `final_snapshot.json` |
| `--publish-interval` | 0.25 | Live snapshot cadence (seconds) |
| `--drain-timeout` | required (schema) | End-of-run tokenize budget (`0` = unlimited) |
| `--tokenizer` | none | HF name or local path; unset disables token metrics |
| `--tokenizer-workers` | required (schema) | Live in-process threads (`0` = defer all to drain) |
| `--streaming` | off | Register TTFT/chunk-delta/TPOT triggers |

`--drain-timeout` and `--tokenizer-workers` carry no service-side defaults:
the benchmark always forwards them from `config/schema.py`
(`--metrics-drain-timeout`, `--metrics-tokenizer-workers`), the single source
of truth for their values.
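
How the two schema-forwarded flags turn into runtime settings can be sketched with a reduced parser. Only three of the flags from the table are included, and `build_parser` and `derive_settings` are hypothetical names; the two mappings mirror the ones visible in the service's `main()` in this PR.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    parser.add_argument("--publish-interval", type=float, default=0.25)
    parser.add_argument("--drain-timeout", type=float, required=True)
    parser.add_argument("--tokenizer-workers", type=int, required=True)
    return parser


def derive_settings(argv: Optional[List[str]] = None) -> dict:
    args = build_parser().parse_args(argv)
    if args.tokenizer_workers < 0:
        raise SystemExit("FATAL: --tokenizer-workers must be >= 0")
    return {
        # 0 means "wait indefinitely", represented as None downstream.
        "drain_timeout_s": None if args.drain_timeout == 0 else args.drain_timeout,
        # With 0 live workers there is no live flush loop; everything defers to the drain.
        "live_flush_interval_s": (
            args.publish_interval if args.tokenizer_workers > 0 else None
        ),
    }
```

For example, `derive_settings(["--drain-timeout", "0", "--tokenizer-workers", "0"])` yields `None` for both settings.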

## References

- [docs/async_utils/services/DESIGN.md](../DESIGN.md) — the EventRecord
pub/sub system this service subscribes to.
- [docs/PERF_ARCHITECTURE.md](../../../PERF_ARCHITECTURE.md) — CPU pinning
for the loadgen/worker hot path.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -112,7 +112,7 @@ test = [
"Pympler==1.1",
"scipy==1.17.1",
# HTTP server and client for mock server fixture
"aiohttp==3.14.0",
"aiohttp==3.14.1",
# Plotting for benchmark sweep mode
"matplotlib==3.10.8",
# Property-based testing (CLI fuzz)
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@
from .publisher import MetricsPublisher
from .registry import MetricsRegistry
from .snapshot import MetricsSnapshotCodec
from .token_metrics import TokenizePool
from .token_metrics import BatchTokenizer, TokenBatchQueue

logger = logging.getLogger(__name__)

@@ -44,6 +44,7 @@ def _make_sigterm_handler(
registry: MetricsRegistry,
publisher: MetricsPublisher,
table: MetricsTable,
token_queue: TokenBatchQueue | None,
shutdown_event: asyncio.Event,
) -> tuple[Callable[[], None], set[asyncio.Task]]:
"""Build the SIGTERM handler that writes the INTERRUPTED final snapshot.
@@ -75,7 +76,7 @@ async def _signal_finalize() -> None:
)
await publisher.publish_final(
registry,
n_pending_tasks=table.in_flight_tasks_count,
n_pending_tasks=token_queue.pending if token_queue is not None else 0,
interrupted=True,
)
except Exception: # noqa: BLE001 — best-effort.
@@ -132,13 +133,13 @@ async def main() -> None:
parser.add_argument(
"--drain-timeout",
type=float,
default=60.0,
required=True,

Collaborator commented:

Should this be required? What about a default of 0 which is what most use-cases would expect. Enforcing a limit can unnecessarily cause runs to fail for hitting the timeout?

help=(
"Wall-clock budget (seconds) to wait for in-flight async tokenize "
"tasks to finish after ENDED before the aggregator cancels them "
"and emits the final snapshot with n_pending_tasks > 0 "
"(default: 60.0; 0 = wait indefinitely). Increase for long-context "
"/ low-worker-count tokenize workloads."
"Wall-clock budget (seconds) to finish tokenizing buffered samples "
"after ENDED before the aggregator emits the final snapshot with "
"n_pending_tasks > 0 (0 = wait indefinitely; the benchmark forwards "
"the schema default, see config/schema.py). Increase for very large "
"datasets where the end-of-run tokenize batch is big."
),
)
parser.add_argument(
@@ -162,8 +163,14 @@ async def main() -> None:
parser.add_argument(
"--tokenizer-workers",
type=int,
default=2,
help="Number of tokenizer worker threads (default: 2)",
required=True,
help=(
Collaborator commented on lines +166 to +167:

Maybe a good default would be nice here (0?) - we might be overloading the user with too many required flags/options which are really tuning knobs.

"In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT "
"(0 = no mid-run tokenization, everything defers to the "
"end-of-run drain; the benchmark forwards the schema default, "
"see config/schema.py). The drain always uses the auto-sized "
"sharded pool — one worker process per 8-core block."
),
)
parser.add_argument(
"--streaming",
@@ -186,6 +193,9 @@ async def main() -> None:
args = parser.parse_args()
setup_logging(level="INFO")

if args.tokenizer_workers < 0:
raise SystemExit("FATAL: --tokenizer-workers must be >= 0")

# The parent owns directory setup — `commands/benchmark/execute.py`
# creates `<report_dir>/metrics/` and validates it before launching
# this subprocess. Validate here as a fail-fast contract check so a
@@ -204,15 +214,23 @@ async def main() -> None:
loop = LoopManager().default_loop

# Using ternary operator causes errors in MyPy object type coalescing
# (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]')
pool_cm: AbstractContextManager[TokenizePool | None]
# (coalesces to 'object' not 'AbstractContextManager[BatchTokenizer | None]')
tokenizer_cm: AbstractContextManager[BatchTokenizer | None]
if args.tokenizer:
pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers)
try:
tokenizer_cm = BatchTokenizer(
args.tokenizer, live_workers=args.tokenizer_workers
)
except RuntimeError as exc:
# Fail-fast contract: a tokenizer environment that cannot shard
# must surface as a clear service-launch failure, not a silent
# slow path that cannot keep up with completions.
raise SystemExit(f"FATAL: {exc}") from exc
else:
pool_cm = nullcontext()
tokenizer_cm = nullcontext()

with (
pool_cm as pool,
tokenizer_cm as tokenizer,
ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx,
):
registry = MetricsRegistry()
@@ -234,7 +252,10 @@ async def main() -> None:
publish_interval_s=args.publish_interval,
sig_figs=args.hdr_sig_figs,
n_histogram_buckets=args.n_histogram_buckets,
tokenize_pool=pool,
tokenizer=tokenizer,
live_flush_interval_s=(
args.publish_interval if args.tokenizer_workers > 0 else None
),
streaming=args.streaming,
shutdown_event=shutdown_event,
drain_timeout_s=None if args.drain_timeout == 0 else args.drain_timeout,
@@ -268,7 +289,8 @@ async def main() -> None:
loop=loop,
registry=registry,
publisher=publisher,
table=aggregator._table,
table=aggregator.table,
token_queue=aggregator.token_queue,
shutdown_event=shutdown_event,
)
loop.add_signal_handler(signal.SIGTERM, on_sigterm)