-
Notifications
You must be signed in to change notification settings - Fork 22
perf(metrics): batch tokenization with defer-to-flush drain #350
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
9ab6e83
b40f72f
82a12bc
8033c47
47c4f35
1315a73
6d227bf
0cca84a
443a923
aed6b78
700423e
9640bd7
a1b9386
34e0c48
033d724
6b70433
6361768
8321d83
f1ac948
70b39d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,119 @@ | ||
| # Metrics Aggregator Service — Design | ||
|
|
||
| The metrics aggregator is a subprocess (`python -m | ||
| inference_endpoint.async_utils.services.metrics_aggregator`) that subscribes | ||
| to the EventRecord stream, folds per-sample events into a `MetricsRegistry`, | ||
| and publishes `MetricsSnapshot` frames over IPC PUB at a fixed cadence. At | ||
| end-of-run it atomically writes `final_snapshot.json` — the **primary** source | ||
| for `Report`; the terminal pub/sub frame is only a TUI "run finished" signal. | ||
|
|
||
| ## Lifecycle | ||
|
|
||
| ``` | ||
| INITIALIZE ──STARTED──► LIVE ──ENDED──► DRAINING ──► COMPLETE | ||
| └──► INTERRUPTED (SIGTERM) | ||
| ``` | ||
|
|
||
| The ENDED path runs inside a finalization boundary: whatever the drain does — | ||
| finish, time out, or fail — `publish_final` and the shutdown signal always | ||
| run. A tokenizer failure can degrade the snapshot (see the `n_pending_tasks` | ||
| contract) but can never hang the subprocess. SIGTERM writes a best-effort | ||
| partial snapshot tagged `INTERRUPTED`. | ||
|
|
||
| ## Token metrics pipeline | ||
|
|
||
| ISL/OSL/TPOT require tokenizer passes per completed sample; at high completion | ||
| rates a per-event dispatch model accumulates an unbounded backlog. The | ||
| pipeline batches instead: **defer-to-flush** + **process-sharded encoding**. | ||
|
|
||
| ### Defer-to-flush (`TokenBatchQueue`) | ||
|
|
||
| Triggers do no work at event time — `fire()` appends `(text, on_count)` to a | ||
| buffer, O(1), no tasks. The buffer is cleared at two points: | ||
|
|
||
| 1. **Live loop** — `start_live(interval)` flushes periodically through the | ||
| tokenizer's in-process lane: `--tokenizer-workers` threads, rayon capped | ||
| to the same width, at most `_LIVE_FLUSH_MAX_ITEMS` per flush. Never | ||
| touches the shard processes. `0` disables mid-run tokenization. Failed or | ||
| cancelled live items are **re-queued** — the drain retries them. | ||
| 2. **End-of-run** — `flush_remaining(timeout)` stops the live loop and drains | ||
| everything left through every shard, bounded by the drain budget. | ||
|
|
||
| `flush()` serializes under an asyncio lock and detaches the buffer up front. | ||
| The text and chat-template phases fail independently; a raising recorder is | ||
| logged without aborting the batch. Drain failures are terminal — items stay | ||
| counted in `pending`. `flush_remaining` never raises. | ||
|
|
||
| ### Sharded batch encoding (`BatchTokenizer`) | ||
|
|
||
| The drain fans the whole buffer out across worker **processes**, one pinned | ||
| per `CORES_PER_WORKER` (8) core block. Each worker runs the raw `tokenizers` | ||
| backend's `encode_batch_fast` (Rust, rayon); a single BPE rayon pool | ||
| saturates ~8 cores, so disjoint pinned blocks are how the whole machine is | ||
| used. Workers are spawn-context, warmed in parallel at construction (bounded | ||
| — a hung load is a startup error), and ignore SIGINT. | ||
|
|
||
| The shard pool has no knob: it auto-sizes to one shard per 8-core block of | ||
| the allowed CPU universe. There is no fallback — no fast Rust backend, or a | ||
| failed/over-budget warmup, is a startup error, because an in-process slow | ||
| path cannot keep up and would surface much later as an incomplete drain. | ||
| Platforms without an affinity API (macOS) shard unpinned; each worker caps | ||
| its rayon pool to the block size instead. | ||
|
|
||
| Chat-template items (tool calls) run on the in-process thread lane — | ||
| `apply_chat_template` is Python/Jinja; sharding buys nothing. | ||
|
|
||
| ### CPU affinity: tokenize is post-run | ||
|
|
||
| The parent pins itself to the loadgen cores and children inherit that narrow | ||
| mask. `_setup_shards` probes the full allowed universe via | ||
| `expand_to_all_online_cpus()` (cgroup/Slurm-clamped) for the block math, | ||
| **then restores the inherited mask** — the aggregator stays where the parent | ||
| placed it; only the drain-phase shard children span the machine, and they | ||
| are idle until `ENDED`. | ||
|
|
||
| ### The `n_pending_tasks` contract | ||
|
|
||
| `TokenBatchQueue.pending` (enqueued-but-not-recorded) is surfaced on every | ||
| snapshot as `n_pending_tasks`. In the final snapshot: | ||
|
|
||
| - `state == complete && n_pending_tasks == 0` — clean run, exact series. | ||
| - `state == complete && n_pending_tasks > 0` — **incomplete drain** (budget | ||
| exhausted or tokenizer failed); `Report` renders a warning. Failed items | ||
| are deliberately not removed from the count — under-reporting would | ||
| rebadge an incomplete drain as clean. | ||
|
|
||
| ### Data flow | ||
|
|
||
| ``` | ||
| COMPLETE event ─► trigger.fire ─► queue.enqueue(text, on_count) [O(1)] | ||
| │ | ||
| live loop (publish cadence) ─ flush(live) ─► in-process threads (rayon-capped) | ||
| ENDED drain (budgeted) ────── flush() ─────► chunks ─► N pinned worker procs | ||
| └─► on_count(n) ─► registry.record() | ||
| ``` | ||
|
|
||
| ## CLI | ||
|
|
||
| | Flag | Default | Purpose | | ||
| | -------------------------------- | ----------------- | --------------------------------------------------- | | ||
| | `--socket-dir` / `--socket-name` | required | EventRecord SUB socket | | ||
| | `--metrics-socket` | required | Snapshot PUB socket name | | ||
| | `--metrics-output-dir` | required | Directory for `final_snapshot.json` | | ||
| | `--publish-interval` | 0.25 | Live snapshot cadence (seconds) | | ||
| | `--drain-timeout` | required (schema) | End-of-run tokenize budget (`0` = unlimited) | | ||
| | `--tokenizer` | none | HF name or local path; unset disables token metrics | | ||
| | `--tokenizer-workers` | required (schema) | Live in-process threads (`0` = defer all to drain) | | ||
| | `--streaming` | off | Register TTFT/chunk-delta/TPOT triggers | | ||
|
|
||
| `--drain-timeout` and `--tokenizer-workers` carry no service-side defaults: | ||
| the benchmark always forwards them from `config/schema.py` | ||
| (`--metrics-drain-timeout`, `--metrics-tokenizer-workers`), the single source | ||
| of truth for their values. | ||
|
|
||
| ## References | ||
|
|
||
| - [docs/async_utils/services/DESIGN.md](../DESIGN.md) — the EventRecord | ||
| pub/sub system this service subscribes to. | ||
| - [docs/PERF_ARCHITECTURE.md](../../../PERF_ARCHITECTURE.md) — CPU pinning | ||
| for the loadgen/worker hot path. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,7 @@ | |
| from .publisher import MetricsPublisher | ||
| from .registry import MetricsRegistry | ||
| from .snapshot import MetricsSnapshotCodec | ||
| from .token_metrics import TokenizePool | ||
| from .token_metrics import BatchTokenizer, TokenBatchQueue | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -44,6 +44,7 @@ def _make_sigterm_handler( | |
| registry: MetricsRegistry, | ||
| publisher: MetricsPublisher, | ||
| table: MetricsTable, | ||
| token_queue: TokenBatchQueue | None, | ||
| shutdown_event: asyncio.Event, | ||
| ) -> tuple[Callable[[], None], set[asyncio.Task]]: | ||
| """Build the SIGTERM handler that writes the INTERRUPTED final snapshot. | ||
|
|
@@ -75,7 +76,7 @@ async def _signal_finalize() -> None: | |
| ) | ||
| await publisher.publish_final( | ||
| registry, | ||
| n_pending_tasks=table.in_flight_tasks_count, | ||
| n_pending_tasks=token_queue.pending if token_queue is not None else 0, | ||
| interrupted=True, | ||
| ) | ||
| except Exception: # noqa: BLE001 — best-effort. | ||
|
|
@@ -132,13 +133,13 @@ async def main() -> None: | |
| parser.add_argument( | ||
| "--drain-timeout", | ||
| type=float, | ||
| default=60.0, | ||
| required=True, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be required? What about a default of 0 which is what most use-cases would expect. Enforcing a limit can unnecessarily cause runs to fail for hitting the timeout? |
||
| help=( | ||
| "Wall-clock budget (seconds) to wait for in-flight async tokenize " | ||
| "tasks to finish after ENDED before the aggregator cancels them " | ||
| "and emits the final snapshot with n_pending_tasks > 0 " | ||
| "(default: 60.0; 0 = wait indefinitely). Increase for long-context " | ||
| "/ low-worker-count tokenize workloads." | ||
| "Wall-clock budget (seconds) to finish tokenizing buffered samples " | ||
| "after ENDED before the aggregator emits the final snapshot with " | ||
| "n_pending_tasks > 0 (0 = wait indefinitely; the benchmark forwards " | ||
| "the schema default, see config/schema.py). Increase for very large " | ||
| "datasets where the end-of-run tokenize batch is big." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
|
|
@@ -162,8 +163,14 @@ async def main() -> None: | |
| parser.add_argument( | ||
| "--tokenizer-workers", | ||
| type=int, | ||
| default=2, | ||
| help="Number of tokenizer worker threads (default: 2)", | ||
| required=True, | ||
| help=( | ||
|
Comment on lines
+166
to
+167
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe a good default would be nice here (0?) - we might be overloading the user with too many required flags/options which are really tuning knobs. |
||
| "In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT " | ||
| "(0 = no mid-run tokenization, everything defers to the " | ||
| "end-of-run drain; the benchmark forwards the schema default, " | ||
| "see config/schema.py). The drain always uses the auto-sized " | ||
| "sharded pool — one worker process per 8-core block." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--streaming", | ||
|
|
@@ -186,6 +193,9 @@ async def main() -> None: | |
| args = parser.parse_args() | ||
| setup_logging(level="INFO") | ||
|
|
||
| if args.tokenizer_workers < 0: | ||
| raise SystemExit("FATAL: --tokenizer-workers must be >= 0") | ||
|
|
||
| # The parent owns directory setup — `commands/benchmark/execute.py` | ||
| # creates `<report_dir>/metrics/` and validates it before launching | ||
| # this subprocess. Validate here as a fail-fast contract check so a | ||
|
|
@@ -204,15 +214,23 @@ async def main() -> None: | |
| loop = LoopManager().default_loop | ||
|
|
||
| # Using ternary operator causes errors in MyPy object type coalescing | ||
| # (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]') | ||
| pool_cm: AbstractContextManager[TokenizePool | None] | ||
| # (coalesces to 'object' not 'AbstractContextManager[BatchTokenizer | None]') | ||
| tokenizer_cm: AbstractContextManager[BatchTokenizer | None] | ||
| if args.tokenizer: | ||
| pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers) | ||
| try: | ||
| tokenizer_cm = BatchTokenizer( | ||
| args.tokenizer, live_workers=args.tokenizer_workers | ||
| ) | ||
| except RuntimeError as exc: | ||
| # Fail-fast contract: a tokenizer environment that cannot shard | ||
| # must surface as a clear service-launch failure, not a silent | ||
| # slow path that cannot keep up with completions. | ||
| raise SystemExit(f"FATAL: {exc}") from exc | ||
| else: | ||
| pool_cm = nullcontext() | ||
| tokenizer_cm = nullcontext() | ||
|
|
||
| with ( | ||
| pool_cm as pool, | ||
| tokenizer_cm as tokenizer, | ||
| ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx, | ||
| ): | ||
| registry = MetricsRegistry() | ||
|
|
@@ -234,7 +252,10 @@ async def main() -> None: | |
| publish_interval_s=args.publish_interval, | ||
| sig_figs=args.hdr_sig_figs, | ||
| n_histogram_buckets=args.n_histogram_buckets, | ||
| tokenize_pool=pool, | ||
| tokenizer=tokenizer, | ||
| live_flush_interval_s=( | ||
| args.publish_interval if args.tokenizer_workers > 0 else None | ||
| ), | ||
| streaming=args.streaming, | ||
| shutdown_event=shutdown_event, | ||
| drain_timeout_s=None if args.drain_timeout == 0 else args.drain_timeout, | ||
|
|
@@ -268,7 +289,8 @@ async def main() -> None: | |
| loop=loop, | ||
| registry=registry, | ||
| publisher=publisher, | ||
| table=aggregator._table, | ||
| table=aggregator.table, | ||
| token_queue=aggregator.token_queue, | ||
| shutdown_event=shutdown_event, | ||
| ) | ||
| loop.add_signal_handler(signal.SIGTERM, on_sigterm) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we ignore SIGINT, there is no way to terminate/interrupt?