diff --git a/AGENTS.md b/AGENTS.md index 79bc5ded4..56bd7b1d1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -115,7 +115,7 @@ The aggregator is a separate process (`python -m inference_endpoint.async_utils. - **Series storage**: each `SeriesSampler` keeps three parallel views: O(1) cheap rollups (count/total/min/max/sum_sq, exact), an HDR Histogram (cheap live percentiles), and an in-memory `array.array` of raw values (for exact percentiles in the `COMPLETE` snapshot). Hot path is `registry.record(name, value)` — no allocation, no I/O. - **Counter API**: `registry.increment(name, delta=1)` for sample-event counters. `registry.set_counter(name, value)` only for the two duration counters (`total_duration_ns` max-of-elapsed, `tracked_duration_ns` sum-of-blocks). -- **Lifecycle**: `INITIALIZE` (constructed, awaiting first `STARTED`) → `LIVE` (run in progress, ticking every `--publish-interval` seconds) → `DRAINING` (set on `ENDED`; tick continues; bounded by `--drain-timeout` budget, default 60 s) → terminal: `COMPLETE` (clean end via `publish_final`, exact stats) **or** `INTERRUPTED` (signal-handler-triggered final via SIGTERM/SIGINT; best-effort partial stats). Drain timeout detected by consumers as `state == COMPLETE and n_pending_tasks > 0`; interrupted runs are detected as `state == INTERRUPTED` directly. +- **Lifecycle**: `INITIALIZE` (constructed, awaiting first `STARTED`) → `LIVE` (run in progress, ticking every `--publish-interval` seconds) → `DRAINING` (set on `ENDED`; tick continues; bounded by the `--drain-timeout` budget — schema default 300 s) → terminal: `COMPLETE` (clean end via `publish_final`, exact stats) **or** `INTERRUPTED` (signal-handler-triggered final via SIGTERM/SIGINT; best-effort partial stats). Drain timeout detected by consumers as `state == COMPLETE and n_pending_tasks > 0`; interrupted runs are detected as `state == INTERRUPTED` directly. - **Final delivery is dual-path with separated concerns**: `publish_final` atomically writes `final_snapshot.json` (`tmp + fsync(file) + rename + fsync(parent_dir)`) — this is the **primary** Report source — AND emits the terminal-state snapshot over pub/sub as a TUI shutdown signal. Each path is wrapped in its own try/except so one failure cannot suppress the other. Main process consumer reads `final_snapshot.json` (via `json.loads` to dict, no Struct decode); falls back to the subscriber's `latest` live snapshot only if the file is missing (e.g. SIGKILL / OOM before the signal handler ran). The dict form is the canonical consumer contract (see `snapshot_to_dict`). - **Histogram bucket edges are dynamic per snapshot**: log-spaced over the observed `[min, max]`. Bucket count is fixed at construction; consumers MUST re-render from the snapshot's `(lo, hi, count)` triples each frame and MUST NOT track bucket-by-index across snapshots. @@ -204,7 +204,7 @@ src/inference_endpoint/ │ │ ├── publisher.py # MetricsPublisher (tick task + atomic disk fallback) │ │ ├── subscriber.py # MetricsSnapshotSubscriber (latest + COMPLETE snapshot capture) │ │ ├── metrics_table.py # In-flight sample rows + trigger dispatch (TTFT/TPOT/ISL/OSL) -│ │ └── token_metrics.py # TokenizePool (HF tokenizer thread pool for ISL/OSL/TPOT) +│ │ └── token_metrics.py # BatchTokenizer (live thread lane + drain-only sharded pool) + TokenBatchQueue (defer-to-flush buffer, owns the live flush loop) for ISL/OSL/TPOT │ └── transport/ # ZMQ-based IPC transport layer │ ├── protocol.py # Transport protocols + TransportConfig + MessageCodec[T] │ └── zmq/ # ZMQ implementation (context, pubsub, transport, ZMQTransportConfig) diff --git a/docs/async_utils/services/DESIGN.md b/docs/async_utils/services/DESIGN.md index a26f13783..e013b4ea1 100644 --- a/docs/async_utils/services/DESIGN.md +++ b/docs/async_utils/services/DESIGN.md @@ -306,9 +306,9 @@ stateDiagram-v2 ### 6.2 Metrics aggregator -- **Role**: Subscribes to EventRecords and derives real-time metrics (e.g. TTFT, sample latency, token counts). May use a tokenizer pool for token-based metrics. Shuts down on **session.ended**. -- **Outputs**: Planned is to push real time metrics to Prometheus via PushGateway. Currently, logging / writing final report to JSON is sufficient legacy behavior. -- **Process**: Run as a **subprocess**; given `--metrics-dir`, `--socket-dir`, `--socket-name`, and optional tokenizer options. Uses a dedicated event loop and `ManagedZMQContext.scoped(socket_dir=...)` so it can connect to the publisher's IPC address. +- **Role**: Subscribes to EventRecords and derives real-time metrics (e.g. TTFT, sample latency, token counts). Token metrics (ISL/OSL/TPOT) are computed by a batched tokenizer (in-process threads live; process-sharded end-of-run drain) — see [metrics_aggregator/DESIGN.md](metrics_aggregator/DESIGN.md). Shuts down on **session.ended**. +- **Outputs**: Live `MetricsSnapshot` frames over an IPC PUB socket, and an atomically written `final_snapshot.json` (the primary Report source). Planned is to push real time metrics to Prometheus via PushGateway. +- **Process**: Run as a **subprocess**; given `--metrics-output-dir`, `--socket-dir`, `--socket-name`, `--metrics-socket`, and optional tokenizer options. Uses a dedicated event loop and `ManagedZMQContext.scoped(socket_dir=...)` so it can connect to the publisher's IPC address. --- diff --git a/docs/async_utils/services/metrics_aggregator/DESIGN.md b/docs/async_utils/services/metrics_aggregator/DESIGN.md new file mode 100644 index 000000000..42929f5a4 --- /dev/null +++ b/docs/async_utils/services/metrics_aggregator/DESIGN.md @@ -0,0 +1,119 @@ +# Metrics Aggregator Service — Design + +The metrics aggregator is a subprocess (`python -m +inference_endpoint.async_utils.services.metrics_aggregator`) that subscribes +to the EventRecord stream, folds per-sample events into a `MetricsRegistry`, +and publishes `MetricsSnapshot` frames over IPC PUB at a fixed cadence. At +end-of-run it atomically writes `final_snapshot.json` — the **primary** source +for `Report`; the terminal pub/sub frame is only a TUI "run finished" signal. + +## Lifecycle + +``` +INITIALIZE ──STARTED──► LIVE ──ENDED──► DRAINING ──► COMPLETE + └──► INTERRUPTED (SIGTERM) +``` + +The ENDED path runs inside a finalization boundary: whatever the drain does — +finish, time out, or fail — `publish_final` and the shutdown signal always +run. A tokenizer failure can degrade the snapshot (see the `n_pending_tasks` +contract) but can never hang the subprocess. SIGTERM writes a best-effort +partial snapshot tagged `INTERRUPTED`. + +## Token metrics pipeline + +ISL/OSL/TPOT require tokenizer passes per completed sample; at high completion +rates a per-event dispatch model accumulates an unbounded backlog. The +pipeline batches instead: **defer-to-flush** + **process-sharded encoding**. + +### Defer-to-flush (`TokenBatchQueue`) + +Triggers do no work at event time — `fire()` appends `(text, on_count)` to a +buffer, O(1), no tasks. The buffer is cleared at two points: + +1. **Live loop** — `start_live(interval)` flushes periodically through the + tokenizer's in-process lane: `--tokenizer-workers` threads, rayon capped + to the same width, at most `_LIVE_FLUSH_MAX_ITEMS` per flush. Never + touches the shard processes. `0` disables mid-run tokenization. Failed or + cancelled live items are **re-queued** — the drain retries them. +2. **End-of-run** — `flush_remaining(timeout)` stops the live loop and drains + everything left through every shard, bounded by the drain budget. + +`flush()` serializes under an asyncio lock and detaches the buffer up front. +The text and chat-template phases fail independently; a raising recorder is +logged without aborting the batch. Drain failures are terminal — items stay +counted in `pending`. `flush_remaining` never raises. + +### Sharded batch encoding (`BatchTokenizer`) + +The drain fans the whole buffer out across worker **processes**, one pinned +per `CORES_PER_WORKER` (8) core block. Each worker runs the raw `tokenizers` +backend's `encode_batch_fast` (Rust, rayon); a single BPE rayon pool +saturates ~8 cores, so disjoint pinned blocks are how the whole machine is +used. Workers are spawn-context, warmed in parallel at construction (bounded +— a hung load is a startup error), and ignore SIGINT. + +The shard pool has no knob: it auto-sizes to one shard per 8-core block of +the allowed CPU universe. There is no fallback — no fast Rust backend, or a +failed/over-budget warmup, is a startup error, because an in-process slow +path cannot keep up and would surface much later as an incomplete drain. +Platforms without an affinity API (macOS) shard unpinned; each worker caps +its rayon pool to the block size instead. + +Chat-template items (tool calls) run on the in-process thread lane — +`apply_chat_template` is Python/Jinja; sharding buys nothing. + +### CPU affinity: tokenize is post-run + +The parent pins itself to the loadgen cores and children inherit that narrow +mask. `_setup_shards` probes the full allowed universe via +`expand_to_all_online_cpus()` (cgroup/Slurm-clamped) for the block math, +**then restores the inherited mask** — the aggregator stays where the parent +placed it; only the drain-phase shard children span the machine, and they +are idle until `ENDED`. + +### The `n_pending_tasks` contract + +`TokenBatchQueue.pending` (enqueued-but-not-recorded) is surfaced on every +snapshot as `n_pending_tasks`. In the final snapshot: + +- `state == complete && n_pending_tasks == 0` — clean run, exact series. +- `state == complete && n_pending_tasks > 0` — **incomplete drain** (budget + exhausted or tokenizer failed); `Report` renders a warning. Failed items + are deliberately not removed from the count — under-reporting would + rebadge an incomplete drain as clean. + +### Data flow + +``` +COMPLETE event ─► trigger.fire ─► queue.enqueue(text, on_count) [O(1)] + │ + live loop (publish cadence) ─ flush(live) ─► in-process threads (rayon-capped) + ENDED drain (budgeted) ────── flush() ─────► chunks ─► N pinned worker procs + └─► on_count(n) ─► registry.record() +``` + +## CLI + +| Flag | Default | Purpose | +| -------------------------------- | ----------------- | --------------------------------------------------- | +| `--socket-dir` / `--socket-name` | required | EventRecord SUB socket | +| `--metrics-socket` | required | Snapshot PUB socket name | +| `--metrics-output-dir` | required | Directory for `final_snapshot.json` | +| `--publish-interval` | 0.25 | Live snapshot cadence (seconds) | +| `--drain-timeout` | required (schema) | End-of-run tokenize budget (`0` = unlimited) | +| `--tokenizer` | none | HF name or local path; unset disables token metrics | +| `--tokenizer-workers` | required (schema) | Live in-process threads (`0` = defer all to drain) | +| `--streaming` | off | Register TTFT/chunk-delta/TPOT triggers | + +`--drain-timeout` and `--tokenizer-workers` carry no service-side defaults: +the benchmark always forwards them from `config/schema.py` +(`--metrics-drain-timeout`, `--metrics-tokenizer-workers`), the single source +of truth for their values. + +## References + +- [docs/async_utils/services/DESIGN.md](../DESIGN.md) — the EventRecord + pub/sub system this service subscribes to. +- [docs/PERF_ARCHITECTURE.md](../../../PERF_ARCHITECTURE.md) — CPU pinning + for the loadgen/worker hot path. diff --git a/pyproject.toml b/pyproject.toml index 0b0f67a86..4a7655021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ test = [ "Pympler==1.1", "scipy==1.17.1", # HTTP server and client for mock server fixture - "aiohttp==3.14.0", + "aiohttp==3.14.1", # Plotting for benchmark sweep mode "matplotlib==3.10.8", # Property-based testing (CLI fuzz) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py index 2231d6dc8..70633fb6d 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py @@ -33,7 +33,7 @@ from .publisher import MetricsPublisher from .registry import MetricsRegistry from .snapshot import MetricsSnapshotCodec -from .token_metrics import TokenizePool +from .token_metrics import BatchTokenizer, TokenBatchQueue logger = logging.getLogger(__name__) @@ -44,6 +44,7 @@ def _make_sigterm_handler( registry: MetricsRegistry, publisher: MetricsPublisher, table: MetricsTable, + token_queue: TokenBatchQueue | None, shutdown_event: asyncio.Event, ) -> tuple[Callable[[], None], set[asyncio.Task]]: """Build the SIGTERM handler that writes the INTERRUPTED final snapshot. @@ -75,7 +76,7 @@ async def _signal_finalize() -> None: ) await publisher.publish_final( registry, - n_pending_tasks=table.in_flight_tasks_count, + n_pending_tasks=token_queue.pending if token_queue is not None else 0, interrupted=True, ) except Exception: # noqa: BLE001 — best-effort. @@ -132,13 +133,13 @@ async def main() -> None: parser.add_argument( "--drain-timeout", type=float, - default=60.0, + required=True, help=( - "Wall-clock budget (seconds) to wait for in-flight async tokenize " - "tasks to finish after ENDED before the aggregator cancels them " - "and emits the final snapshot with n_pending_tasks > 0 " - "(default: 60.0; 0 = wait indefinitely). Increase for long-context " - "/ low-worker-count tokenize workloads." + "Wall-clock budget (seconds) to finish tokenizing buffered samples " + "after ENDED before the aggregator emits the final snapshot with " + "n_pending_tasks > 0 (0 = wait indefinitely; the benchmark forwards " + "the schema default, see config/schema.py). Increase for very large " + "datasets where the end-of-run tokenize batch is big." ), ) parser.add_argument( @@ -162,8 +163,14 @@ async def main() -> None: parser.add_argument( "--tokenizer-workers", type=int, - default=2, - help="Number of tokenizer worker threads (default: 2)", + required=True, + help=( + "In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT " + "(0 = no mid-run tokenization, everything defers to the " + "end-of-run drain; the benchmark forwards the schema default, " + "see config/schema.py). The drain always uses the auto-sized " + "sharded pool — one worker process per 8-core block." + ), ) parser.add_argument( "--streaming", @@ -186,6 +193,9 @@ async def main() -> None: args = parser.parse_args() setup_logging(level="INFO") + if args.tokenizer_workers < 0: + raise SystemExit("FATAL: --tokenizer-workers must be >= 0") + # The parent owns directory setup — `commands/benchmark/execute.py` # creates `/metrics/` and validates it before launching # this subprocess. Validate here as a fail-fast contract check so a @@ -204,15 +214,23 @@ async def main() -> None: loop = LoopManager().default_loop # Using ternary operator causes errors in MyPy object type coalescing - # (coalesces to 'object' not 'AbstractContextManager[TokenizePool | None]') - pool_cm: AbstractContextManager[TokenizePool | None] + # (coalesces to 'object' not 'AbstractContextManager[BatchTokenizer | None]') + tokenizer_cm: AbstractContextManager[BatchTokenizer | None] if args.tokenizer: - pool_cm = TokenizePool(args.tokenizer, n_workers=args.tokenizer_workers) + try: + tokenizer_cm = BatchTokenizer( + args.tokenizer, live_workers=args.tokenizer_workers + ) + except RuntimeError as exc: + # Fail-fast contract: a tokenizer environment that cannot shard + # must surface as a clear service-launch failure, not a silent + # slow path that cannot keep up with completions. + raise SystemExit(f"FATAL: {exc}") from exc else: - pool_cm = nullcontext() + tokenizer_cm = nullcontext() with ( - pool_cm as pool, + tokenizer_cm as tokenizer, ManagedZMQContext.scoped(socket_dir=args.socket_dir) as zmq_ctx, ): registry = MetricsRegistry() @@ -234,7 +252,10 @@ async def main() -> None: publish_interval_s=args.publish_interval, sig_figs=args.hdr_sig_figs, n_histogram_buckets=args.n_histogram_buckets, - tokenize_pool=pool, + tokenizer=tokenizer, + live_flush_interval_s=( + args.publish_interval if args.tokenizer_workers > 0 else None + ), streaming=args.streaming, shutdown_event=shutdown_event, drain_timeout_s=None if args.drain_timeout == 0 else args.drain_timeout, @@ -268,7 +289,8 @@ async def main() -> None: loop=loop, registry=registry, publisher=publisher, - table=aggregator._table, + table=aggregator.table, + token_queue=aggregator.token_queue, shutdown_event=shutdown_event, ) loop.add_signal_handler(signal.SIGTERM, on_sigterm) diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py index f01c9753c..e7e8daf20 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py @@ -47,7 +47,7 @@ from .publisher import MetricsPublisher from .registry import MetricsRegistry from .snapshot import SessionState -from .token_metrics import TokenizePool +from .token_metrics import BatchTokenizer, TokenBatchQueue logger = logging.getLogger(__name__) @@ -96,8 +96,6 @@ class MetricCounterKey(str, Enum): _TOKEN_HDR_LOW: Final[int] = 1 _TOKEN_HDR_HIGH: Final[int] = 10_000_000 # 10M tokens -_DEFAULT_DRAIN_TIMEOUT_S: Final[float] = 60.0 - class MetricsAggregatorService(ZmqMessageSubscriber[EventRecord]): """Subscribes to EventRecords and computes per-sample metrics in real time. @@ -117,23 +115,33 @@ def __init__( publish_interval_s: float, sig_figs: int, n_histogram_buckets: int, - tokenize_pool: TokenizePool | None = None, + tokenizer: BatchTokenizer | None = None, + live_flush_interval_s: float | None = None, streaming: bool = False, shutdown_event: asyncio.Event | None = None, - drain_timeout_s: float | None = _DEFAULT_DRAIN_TIMEOUT_S, + drain_timeout_s: float | None, **kwargs, ): # drain_timeout_s is injected (not derived) because the right # value is workload-dependent: long-context tokenize-heavy runs - # need more headroom than the default 60 s, and the aggregator - # itself can't measure that ahead of time. Keeping it as an arg - # lets the __main__ CLI flag plumb the user's choice through - # without coupling this class to argparse. + # need more headroom than the schema default 300 s, and the + # aggregator itself can't measure that ahead of time. Keeping it + # as an arg lets the __main__ CLI flag plumb the user's choice + # through without coupling this class to argparse. super().__init__(EventRecordCodec(), *args, **kwargs) self._registry = registry self._publisher = publisher self._publish_interval_s = publish_interval_s - self._tokenize_pool = tokenize_pool + # Token triggers enqueue onto this queue; it is flushed by the + # queue's own live loop (start_live) and by the end-of-run drain. + # None when no tokenizer is set (token metrics disabled), in which + # case those triggers are no-ops. + self._token_queue: TokenBatchQueue | None = ( + TokenBatchQueue(tokenizer, self.loop) if tokenizer is not None else None + ) + # Cadence of the queue's live flush loop (None = no mid-run + # tokenization; everything defers to the end-of-run drain). + self._live_flush_interval_s = live_flush_interval_s self._streaming = streaming self._shutdown_event = shutdown_event self._shutdown_received = False @@ -223,21 +231,33 @@ def _register_triggers(self, streaming: bool) -> None: """ table = self._table registry = self._registry - pool = self._tokenize_pool - loop = self.loop + queue = self._token_queue # Always registered - table.add_trigger(SampleField.ISSUED_NS, IslTrigger(registry, pool, loop)) + table.add_trigger(SampleField.ISSUED_NS, IslTrigger(registry, queue)) table.add_trigger(SampleField.COMPLETE_NS, SampleLatencyTrigger(registry)) - table.add_trigger(SampleField.COMPLETE_NS, OslTrigger(registry, pool, loop)) + table.add_trigger(SampleField.COMPLETE_NS, OslTrigger(registry, queue)) # Streaming-only if streaming: table.add_trigger(SampleField.RECV_FIRST_NS, TtftTrigger(registry)) table.add_trigger(SampleField.LAST_RECV_NS, ChunkDeltaTrigger(registry)) - table.add_trigger( - SampleField.COMPLETE_NS, TpotTrigger(registry, pool, loop) - ) + table.add_trigger(SampleField.COMPLETE_NS, TpotTrigger(registry, queue)) + + @property + def table(self) -> MetricsTable: + """The per-sample metrics table (read-only; for service wiring).""" + return self._table + + @property + def token_queue(self) -> TokenBatchQueue | None: + """The token batch queue, if token metrics are enabled.""" + return self._token_queue + + @property + def pending_tokens(self) -> int: + """Enqueued tokenizations not yet recorded (the snapshot n_pending_tasks).""" + return self._token_queue.pending if self._token_queue is not None else 0 # ------------------------------------------------------------------ # Event processing @@ -311,9 +331,16 @@ async def process(self, records: list[EventRecord]) -> None: self._publish_interval_s, get_runtime_state=lambda: ( self._session_state, - table.in_flight_tasks_count, + self.pending_tokens, ), ) + if ( + self._token_queue is not None + and self._live_flush_interval_s is not None + ): + self._token_queue.start_live( + self._live_flush_interval_s + ) table.handle_session_event(record) if ev == SessionEventType.STOP_PERFORMANCE_TRACKING: registry.set_counter( @@ -367,41 +394,47 @@ async def process(self, records: list[EventRecord]) -> None: # ENDED has been observed; transition to DRAINING so any tick # that fires before publish_final reflects the new state. self._session_state = SessionState.DRAINING - logger.info("Draining %d async tasks...", table.in_flight_tasks_count) - # drain_tasks owns the timeout + cancel-and-await sequence so - # the pending count is captured BEFORE done-callbacks empty - # the in-flight set. Reading in_flight_tasks_count out here - # would always be 0 (see drain_tasks docstring). - n_pending = await table.drain_tasks(timeout=self._drain_timeout_s) - if n_pending > 0: - timeout_str = ( - f"{self._drain_timeout_s:.1f}s" - if self._drain_timeout_s is not None - else "unlimited" + logger.info("Draining %d pending tokenizations...", self.pending_tokens) + # The drain and final publish are wrapped together so the aggregator + # ALWAYS reaches _finalize (which sets the shutdown event); a + # tokenizer failure during the drain must not skip publish_final and + # leave main()'s `await shutdown_event.wait()` hanging. + n_pending = self.pending_tokens + try: + # flush_remaining tokenizes the whole buffer in one batched pass, + # bounded by the drain budget, and never raises: it returns the + # count it could not finish (timeout or failure), which becomes + # the snapshot's n_pending_tasks so Report flags an incomplete drain. + if self._token_queue is not None: + n_pending = await self._token_queue.flush_remaining( + self._drain_timeout_s + ) + if n_pending > 0: + budget = ( + f"{self._drain_timeout_s:.1f}s" + if self._drain_timeout_s is not None + else "unlimited" + ) + logger.warning( + "tokenizer drain incomplete (budget %s); %d tokenizations " + "did not complete", + budget, + n_pending, + ) + logger.info( + "Tokenizations drained (n_pending_tasks=%d at finalize)", n_pending ) - logger.warning( - "drain_tasks timed out after %s; %d async tasks " - "did not complete and were cancelled", - timeout_str, - n_pending, + registry.set_counter( + MetricCounterKey.TRACKED_DURATION_NS.value, + table.total_tracked_duration_ns, ) - logger.info( - "Async tasks drained (n_pending_tasks=%d at finalize)", n_pending - ) - registry.set_counter( - MetricCounterKey.TRACKED_DURATION_NS.value, - table.total_tracked_duration_ns, - ) - try: await self._publisher.publish_final(registry, n_pending_tasks=n_pending) finally: - # Whatever happens above, the aggregator MUST close the - # publisher and signal shutdown — otherwise the main() - # entry point's `await shutdown_event.wait()` hangs - # forever and the subprocess never exits cleanly. Each - # cleanup step is independently wrapped: a failure in - # aclose must not prevent _finalize, since _finalize is - # what sets the shutdown event. + # The aggregator MUST close the publisher and signal shutdown even + # if the drain/publish above failed — otherwise main()'s + # `await shutdown_event.wait()` hangs forever. aclose is + # independently wrapped: its failure must not prevent _finalize, + # which is what sets the shutdown event. try: await self._publisher.aclose() except Exception: # noqa: BLE001 — best-effort cleanup. diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py index 46a17e92f..88d2693ee 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py @@ -17,9 +17,9 @@ from __future__ import annotations -import asyncio import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass from enum import Enum from typing import TYPE_CHECKING, Any @@ -33,7 +33,8 @@ MetricsRegistry, ) from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( - TokenizePool, + MessageParts, + TokenBatchQueue, ) from inference_endpoint.core.record import EventRecord @@ -146,8 +147,13 @@ def fire( ev_rec: EventRecord, row: SampleRow, pre_change: dict[str, Any], - ) -> asyncio.Task | None: - """Must be non-blocking. Return a Task if async work was scheduled.""" + ) -> None: + """Must be non-blocking. + + Sync triggers record into the registry directly. Token triggers + enqueue onto the shared ``TokenBatchQueue`` for batched tokenization + at the next flush; neither path schedules per-event tasks. + """ raise NotImplementedError() @@ -173,32 +179,30 @@ def fire(self, ev_rec, row, pre_change): baseline = pre_change.get(self._delta_start_fieldname) if baseline is not None: self.registry.record(self.metric_name, ev_rec.timestamp_ns - baseline) - return None -class AsyncTokenTrigger(EmitTrigger): - """Base for triggers that need async tokenization. +class TokenTrigger(EmitTrigger): + """Base for triggers whose metric needs tokenization. - Subclasses implement ``_extract_text()`` to pull the text to tokenize - from the event record. If text is returned, an async task is created - to tokenize and emit. Subclasses can also override ``_extract_message()`` - to return (content, reasoning, tool_calls) for chat-template–aware tokenization - when tool calls are present. Subclasses can override ``_compute_value()`` to - transform the token count before storing. + Subclasses implement ``_extract_text()`` to pull the text to tokenize from + the event record, and may override ``_extract_message()`` to return + (content, reasoning, tool_calls) for chat-template–aware tokenization when + tool calls are present. ``fire()`` does not tokenize inline — it enqueues + the work plus a recorder callback onto the shared ``TokenBatchQueue``, which + the aggregator flushes in batches. ``_compute_value()`` can transform the + token count before it is recorded. """ def __init__( self, metric_name: str, registry: MetricsRegistry, - tokenize_pool: TokenizePool | None, - loop: asyncio.AbstractEventLoop | None, + queue: TokenBatchQueue | None, requires: tuple[str, ...] = (), dtype: type = int, ): super().__init__(metric_name, registry, requires=requires, dtype=dtype) - self._pool = tokenize_pool - self._loop = loop + self._queue = queue @abstractmethod def _extract_text( @@ -209,11 +213,11 @@ def _extract_text( def _extract_message( self, ev_rec: EventRecord, row: SampleRow, pre_change: dict[str, Any] - ) -> tuple[str, str | None, tuple[dict[str, Any], ...] | None] | None: - """Return (content, reasoning, tool_calls) for message-aware tokenization, or None. + ) -> MessageParts | None: + """Return (content, reasoning, tool_calls) for message-aware tokenization. - When non-None is returned, ``token_count_message_async`` is used instead of - ``token_count_async``. Default returns None (use text path). + When non-None, the message (chat-template) path is used instead of the + plain-text path. Default returns None (use text path). """ return None @@ -223,48 +227,32 @@ def _compute_value( """Transform token count into the metric value. Default: count as-is.""" return token_count - def fire(self, ev_rec, row, pre_change): - if self._pool is None or self._loop is None: - return None + def _make_recorder( + self, ev_rec: EventRecord, pre_change: dict[str, Any] + ) -> Callable[[int], None]: + """Build the callback the queue runs once the token count is known.""" + registry, name = self.registry, self.metric_name - message_parts = self._extract_message(ev_rec, row, pre_change) - if message_parts is not None: - content, reasoning, tool_calls = message_parts - pool, loop = self._pool, self._loop - registry, name = self.registry, self.metric_name - uuid = row.sample_uuid - - async def _tokenize_message_and_emit() -> None: - try: - count = await pool.token_count_message_async( - content, reasoning, tool_calls, loop - ) - value = self._compute_value(count, ev_rec, pre_change) - if value is not None: - registry.record(name, value) - except Exception: - logger.exception("%s tokenization failed for %s", name, uuid) + def record(count: int) -> None: + value = self._compute_value(count, ev_rec, pre_change) + if value is not None: + registry.record(name, value) - return loop.create_task(_tokenize_message_and_emit()) + return record + def fire(self, ev_rec, row, pre_change): + if self._queue is None: + return + message_parts = self._extract_message(ev_rec, row, pre_change) + if message_parts is not None: + self._queue.enqueue_message( + message_parts, self._make_recorder(ev_rec, pre_change) + ) + return text = self._extract_text(ev_rec, row, pre_change) if not text: - return None - - pool, loop = self._pool, self._loop - registry, name = self.registry, self.metric_name - uuid = row.sample_uuid - - async def _tokenize_and_emit() -> None: - try: - count = await pool.token_count_async(text, loop) - value = self._compute_value(count, ev_rec, pre_change) - if value is not None: - registry.record(name, value) - except Exception: - logger.exception("%s tokenization failed for %s", name, uuid) - - return loop.create_task(_tokenize_and_emit()) + return + self._queue.enqueue_text(text, self._make_recorder(ev_rec, pre_change)) # --------------------------------------------------------------------------- @@ -309,29 +297,28 @@ def __init__(self, registry: MetricsRegistry): # --------------------------------------------------------------------------- -# Token triggers (async) +# Token triggers (batched) # --------------------------------------------------------------------------- -class IslTrigger(AsyncTokenTrigger): - """ISL from PromptData: len(token_ids) sync, or token_count(text) async.""" +class IslTrigger(TokenTrigger): + """ISL from PromptData: ``len(token_ids)`` or the tokenized prompt text.""" def __init__( self, registry: MetricsRegistry, - tokenize_pool: TokenizePool | None, - loop: asyncio.AbstractEventLoop | None, + queue: TokenBatchQueue | None, ): - super().__init__(MetricSeriesKey.ISL, registry, tokenize_pool, loop) + super().__init__(MetricSeriesKey.ISL, registry, queue) def fire(self, ev_rec, row, pre_change): # Sync fast path: any backend that pre-populates token_ids (e.g. SGLang). if isinstance(ev_rec.data, PromptData) and ev_rec.data.token_ids is not None: self.registry.record(self.metric_name, len(ev_rec.data.token_ids)) - return None - # Async path: tokenize raw text — used when token_ids are unavailable - # (e.g. OpenAI-compatible endpoints). Handled by the base class. - return super().fire(ev_rec, row, pre_change) + return + # Text path: tokenize raw prompt text — used when token_ids are + # unavailable (e.g. OpenAI-compatible endpoints). Enqueued by the base. + super().fire(ev_rec, row, pre_change) def _extract_text(self, ev_rec, row, pre_change): if isinstance(ev_rec.data, PromptData) and ev_rec.data.text is not None: @@ -339,16 +326,15 @@ def _extract_text(self, ev_rec, row, pre_change): return None -class OslTrigger(AsyncTokenTrigger): +class OslTrigger(TokenTrigger): """OSL = token_count(full output text) from COMPLETE event data.""" def __init__( self, registry: MetricsRegistry, - tokenize_pool: TokenizePool | None, - loop: asyncio.AbstractEventLoop | None, + queue: TokenBatchQueue | None, ): - super().__init__(MetricSeriesKey.OSL, registry, tokenize_pool, loop) + super().__init__(MetricSeriesKey.OSL, registry, queue) def _extract_text(self, ev_rec, row, pre_change): if isinstance(ev_rec.data, TextModelOutput): @@ -365,32 +351,24 @@ def _extract_message(self, ev_rec, row, pre_change): return None -class TpotTrigger(AsyncTokenTrigger): - """TPOT = (complete_ns - recv_first_ns) / token_count(text_after_first_chunk). +class TpotTrigger(TokenTrigger): + """TPOT = (complete_ns - recv_first_ns) / output token count. - Only registered when streaming mode is enabled. - - # NOTE(agents): This trigger tokenizes text_after_first_chunk independently - # from OslTrigger, which tokenizes the full output. This means the output is - # tokenized twice at COMPLETE time for streaming samples. This is intentional: - # OSL is always required (non-streaming and streaming), while TPOT is - # streaming-only. Keeping them as separate triggers allows conditional - # registration via the streaming flag. If tokenization throughput becomes a - # bottleneck, consider merging OSL and TPOT into a single trigger that - # tokenizes once and derives both metrics. + Streaming-only. Tokenizes the post-first-chunk output independently of + ``OslTrigger`` (full output), so streaming samples are tokenized twice — + intentional: OSL is always required, TPOT is conditional on the streaming + flag. """ def __init__( self, registry: MetricsRegistry, - tokenize_pool: TokenizePool | None, - loop: asyncio.AbstractEventLoop | None, + queue: TokenBatchQueue | None, ): super().__init__( MetricSeriesKey.TPOT_NS, registry, - tokenize_pool, - loop, + queue, requires=(SampleField.RECV_FIRST_NS,), dtype=float, ) @@ -444,7 +422,6 @@ def __init__(self, registry: MetricsRegistry) -> None: self._registry = registry self._in_flight: dict[str, SampleRow] = {} self._triggers: dict[str, list[EmitTrigger]] = {} - self._in_flight_tasks: set[asyncio.Task] = set() # Session-level state self.is_tracking: bool = False @@ -538,45 +515,6 @@ def set_field( self._update_tracked_block(row, ev_rec.timestamp_ns) self._in_flight.pop(sample_uuid, None) - # --- Task draining --- - - @property - def in_flight_tasks_count(self) -> int: - """Number of async trigger tasks currently in flight.""" - return len(self._in_flight_tasks) - - async def drain_tasks(self, *, timeout: float | None = None) -> int: - """Await in-flight async trigger tasks. - - With ``timeout``, the pending set at the timeout boundary is - cancelled and awaited; the count of those pending tasks is - returned (>0 indicates the drain timed out). Without - ``timeout``, blocks indefinitely and returns 0 on clean drain. - - The pending count must be captured BEFORE the cancel-and-await - step: each task's ``add_done_callback(_in_flight_tasks.discard)`` - empties ``_in_flight_tasks`` as cancellation propagates, so - reading ``in_flight_tasks_count`` after this method returns - would always be 0 — making a drain timeout indistinguishable - from a clean run. - """ - if not self._in_flight_tasks: - return 0 - if timeout is None: - await asyncio.gather(*self._in_flight_tasks, return_exceptions=True) - self._in_flight_tasks.clear() - return 0 - _, still_pending = await asyncio.wait( - list(self._in_flight_tasks), timeout=timeout - ) - n_pending = len(still_pending) - if still_pending: - for t in still_pending: - t.cancel() - await asyncio.gather(*still_pending, return_exceptions=True) - self._in_flight_tasks.clear() - return n_pending - # --- Internal --- def _create_row(self, sample_uuid: str) -> SampleRow: @@ -595,10 +533,7 @@ def _fire_triggers( ) -> None: for trigger in self._triggers.get(field_name, ()): pre_change = {attr: getattr(row, attr) for attr in trigger.requires} - task = trigger.fire(ev_rec, row, pre_change) - if task is not None: - self._in_flight_tasks.add(task) - task.add_done_callback(self._in_flight_tasks.discard) + trigger.fire(ev_rec, row, pre_change) def _update_tracked_block(self, row: SampleRow, complete_ns: int) -> None: """Extend the sample's tracked block duration and increment count.""" diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/publisher.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/publisher.py index d21973a3f..578e47198 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/publisher.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/publisher.py @@ -88,7 +88,7 @@ def __init__( self._final_snapshot_path = final_snapshot_path self._tick_task: asyncio.Task | None = None self._closed = False - # publish_final is idempotent: the SIGTERM/SIGINT handler in + # publish_final is idempotent: the SIGTERM handler in # __main__.py and the aggregator's ENDED-driven path can both # call it; the second call must not re-publish or re-write. self._finalized = False @@ -107,10 +107,10 @@ def start( ``get_runtime_state`` returns ``(state, n_pending_tasks)`` for the current moment: the aggregator's session state (``LIVE`` or - ``DRAINING``) and the count of in-flight async tokenize tasks. The - callable is invoked once per tick and the values are plumbed into - the published snapshot. ``COMPLETE`` is emitted only by - ``publish_final``, never by the tick task. + ``DRAINING``) and the count of pending tokenizations. The callable is + invoked once per tick and the values are plumbed into the published + snapshot. ``COMPLETE`` is emitted only by ``publish_final``, never by + the tick task. Idempotent on the tick-task slot: a second call (e.g. from a spurious duplicate ``STARTED`` event or a buggy replay producer) @@ -159,12 +159,12 @@ async def publish_final( ) -> None: """Write the final snapshot to disk and signal pub/sub consumers. - ``n_pending_tasks`` is the count of in-flight async tokenize tasks - at finalization time. Drain timeout is detected by Report consumers - as ``state == COMPLETE and n_pending_tasks > 0``. + ``n_pending_tasks`` is the count of buffered tokenizations not yet + recorded at finalization time. An incomplete drain is detected by + Report consumers as ``state == COMPLETE and n_pending_tasks > 0``. ``interrupted=True`` is set by the signal handler in __main__.py - when SIGTERM/SIGINT triggers shutdown before ``ENDED`` arrived; + when SIGTERM triggers shutdown before ``ENDED`` arrived; the resulting snapshot is tagged ``state=INTERRUPTED`` so Report can distinguish "user killed the run mid-execution" from a clean end. Stats in an INTERRUPTED snapshot are best-effort partial @@ -190,7 +190,7 @@ async def publish_final( of the terminal state as the last message). Idempotent: only the first call writes/publishes; subsequent - calls early-return. The SIGTERM/SIGINT handler relies on this to + calls early-return. The SIGTERM handler relies on this to race safely with the ENDED-driven path. """ if self._finalized: diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/snapshot.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/snapshot.py index 95c68ab16..8046e1704 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/snapshot.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/snapshot.py @@ -44,8 +44,8 @@ class SessionState(str, Enum): state to carry). LIVE → run in progress; tick task publishing live HDR-derived stats. DRAINING → ``SessionEventType.ENDED`` has been received; the aggregator - is awaiting the in-flight async tokenize tasks (bounded by - the ``--drain-timeout`` budget, default 60 s). Tick task + is tokenizing the buffered samples (bounded by the + ``--drain-timeout`` budget — schema default 300 s). Tick task continues at this stage, still HDR-derived; no new events will arrive. COMPLETE → terminal clean state. The ``publish_final()`` snapshot @@ -149,13 +149,14 @@ class MetricsSnapshot( ``INTERRUPTED``) mark the last snapshot of the run; for ``COMPLETE`` snapshots percentiles and histograms are exact, otherwise HDR-derived. - n_pending_tasks: Count of in-flight async tokenize tasks at snapshot - composition time. ``> 0`` during normal load (ISL/ - OSL/TPOT post-processing in flight) and during the - drain phase. **Drain timeout is detected as** - ``state == COMPLETE and n_pending_tasks > 0``: the - aggregator gave up draining; some async-only series - are missing samples that were still being tokenized. + n_pending_tasks: Count of buffered tokenizations not yet recorded at + snapshot composition time. ``> 0`` during normal + load (ISL/OSL/TPOT buffered between publish-tick + flushes) and during the drain phase. **An + incomplete drain is detected as** ``state == + COMPLETE and n_pending_tasks > 0``: the end-of-run + flush timed out or failed; the token-derived series + are missing those samples. metrics: Tagged union of ``CounterStat`` and ``SeriesStat``, ordered counters-first then series, registration order within each. diff --git a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py index 3411d5061..60a75bdb6 100644 --- a/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py +++ b/src/inference_endpoint/async_utils/services/metrics_aggregator/token_metrics.py @@ -13,25 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization utilities for metrics aggregation.""" +"""Tokenization for ISL/OSL/TPOT metrics. + +``BatchTokenizer`` tokenizes whole batches at once, sharded across worker +processes each pinned to a block of ``CORES_PER_WORKER`` cores (a single BPE +rayon pool is memory-bound and saturates ~8 cores). The aggregator buffers +per-sample text. The sharded pool is the drain-phase accelerator and is +auto-sized (one shard per core block); live mid-run flushes run on a small +in-process thread pool (``--tokenizer-workers``, default 2) owned by the +queue's live loop. A tokenizer without a fast (Rust) backend is a startup +error, never a silent slow path. Platforms without CPU affinity (e.g. macOS) +shard unpinned at full speed; only cache/NUMA locality is lost. +""" from __future__ import annotations import asyncio import json import logging -import threading -from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +import multiprocessing +import os +import signal +import time +from collections.abc import Callable +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from typing import TYPE_CHECKING, Any, Protocol, cast import msgspec +from inference_endpoint.endpoint_client.cpu_affinity import ( + expand_to_all_online_cpus, +) from transformers import AutoTokenizer +from transformers.utils import logging as transformers_logging + +# A single rayon pool peaks at ~8 cores for BPE (memory-bound; more threads +# oversubscribe and, on multi-socket Grace, cross the NUMA boundary). Sharding +# across processes pinned to disjoint 8-core blocks is how the whole machine is +# used. Measured on GB200: ~16k texts/s at 18 blocks vs ~1.5k single-process. +CORES_PER_WORKER = 8 + +# Budget for the parallel shard warmup (spawn + transformers import + +# tokenizer load per worker). A hung load (e.g. a stuck network filesystem) +# must become a bounded startup error, not wedge service startup — and the +# error must fire before the parent's 30 s service-launch budget kills the +# subprocess, so the diagnostic wins the race. +_SHARD_WARMUP_TIMEOUT_S = 25.0 + +# Per-flush ceiling for the LIVE lane. Bounds three things at once: how long +# the queue lock is held mid-run, how much work an unstoppable in-flight +# thread encode can hold after a drain-start cancellation, and how much the +# drain re-encodes for items the cancelled flush gave back. The drain has no +# ceiling — it always takes the whole buffer. +_LIVE_FLUSH_MAX_ITEMS = 1024 # Minimal user message used to satisfy chat templates that reject assistant-only # message lists. Its token count is subtracted so only the assistant payload is # measured. _PREFIX_USER_MSG: dict[str, str] = {"role": "user", "content": ""} +logger = logging.getLogger(__name__) + def _normalize_tool_calls_for_template( tool_calls: tuple[dict[str, Any], ...] | list[dict[str, Any]], @@ -60,140 +101,328 @@ def _normalize_tool_calls_for_template( return normalized +# --------------------------------------------------------------------------- +# Process-worker entry points (module-level so ProcessPoolExecutor can pickle +# them by name). Each worker holds one raw tokenizers backend, pinned to a +# fixed core block. +# --------------------------------------------------------------------------- + +_WORKER_BACKEND: Any = None + + +def _init_worker(tokenizer_name: str, core_set: list[int]) -> None: + """Pin this worker to ``core_set``, then load the raw tokenizers backend. + + Affinity is set before the first encode so the Rust rayon pool sizes itself + to the pinned core count (num_cpus respects sched_getaffinity on Linux). + """ + # Ctrl-C sends SIGINT to the whole foreground process group; the parent + # drives worker shutdown, so a worker dying mid-drain would break the pool + # and lose the buffered tokenizations it was counting. + signal.signal(signal.SIGINT, signal.SIG_IGN) + if core_set: + # Size the rayon pool to the block explicitly: the parent process caps + # its own pool for the live lane, and spawn children inherit that env — + # without the override every shard would run at the live-lane width. + os.environ["RAYON_NUM_THREADS"] = str(len(core_set)) + try: + os.sched_setaffinity(0, set(core_set)) + except (OSError, AttributeError): + # No pinning (e.g. macOS): the rayon cap above still keeps + # unpinned shards from oversubscribing each other. + logger.debug("could not pin tokenizer worker to %s", core_set) + transformers_logging.set_verbosity_error() + tok = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True) + global _WORKER_BACKEND + _WORKER_BACKEND = getattr(tok, "backend_tokenizer", None) + if _WORKER_BACKEND is not None: + _WORKER_BACKEND.encode("warmup", add_special_tokens=False) + + +def _encode_batch_lengths(backend: Any, texts: list[str]) -> list[int]: + """Per-text token counts via the raw tokenizers backend, one rayon call.""" + encode_batch = getattr(backend, "encode_batch_fast", None) or backend.encode_batch + return [len(e.ids) for e in encode_batch(texts, add_special_tokens=False)] + + +def _worker_encode_lengths(texts: list[str]) -> list[int]: + """Per-text token counts for a shard, in one rayon-parallel call.""" + backend = _WORKER_BACKEND + if backend is None: + raise RuntimeError("tokenizer worker backend unavailable") + return _encode_batch_lengths(backend, texts) + + +def _worker_ready(_: int) -> bool: + """Warmup probe: returns once the worker's backend is loaded.""" + return _WORKER_BACKEND is not None + + +def _terminate_procs(procs: list[ProcessPoolExecutor]) -> None: + """Best-effort immediate stop: cancel queued work and SIGTERM workers. + + ``shutdown(wait=False)`` alone leaves an in-flight encode running, and the + non-daemon worker would still be joined at interpreter exit — so a drain + timeout could stall process shutdown until the chunk finished. + """ + for ex in procs: + ex.shutdown(wait=False, cancel_futures=True) + workers = getattr(ex, "_processes", None) or {} # CPython impl detail. + for p in workers.values(): + try: + p.terminate() + except Exception: # noqa: BLE001 — already-dead workers are fine. + pass + + if TYPE_CHECKING: from transformers import PreTrainedTokenizerBase -logger = logging.getLogger(__name__) +def _even_chunks(items: list[str], n: int) -> list[list[str]]: + """Split ``items`` into at most ``n`` near-equal contiguous chunks.""" + if n <= 1 or len(items) <= 1: + return [items] + size = (len(items) + n - 1) // n + return [items[i : i + size] for i in range(0, len(items), size)] -class TokenizePool: - """A pool of worker threads, each with its own HuggingFace AutoTokenizer. - Uses multi-threading (not multiprocessing) because HuggingFace tokenizers - use a Rust backend that releases the GIL during tokenization, so threads - can run tokenization in parallel without GIL contention. Multiprocessing - would add process spawn overhead and per-process tokenizer memory and - IPC latency. +class BatchTokenizer: + """Counts tokens for batches of text, sharded across pinned CPU cores. - Thread-safety notes: - - The ThreadPoolExecutor itself is thread-safe (submit/shutdown are synchronized). - - Each worker thread has its own tokenizer via thread-local storage, so there - is no shared mutable state during tokenization. - - The blocking `token_count()` method is safe to call from multiple threads - concurrently. - - In an async context, use `token_count_async` to avoid blocking the event loop. + ``count_texts_async`` tokenizes a whole list in one sharded call. The + chat-template ``token_count_message_async`` path runs on a small in-process + thread — rare (tool calls) relative to the batched OSL/ISL/TPOT flush. """ - def __init__(self, tokenizer_name: str, n_workers: int) -> None: - if n_workers < 1: - raise ValueError("n_workers must be at least 1") + def __init__( + self, + tokenizer_name: str, + *, + live_workers: int, + cores_per_worker: int = CORES_PER_WORKER, + n_workers: int = -1, + ) -> None: self._tokenizer_name = tokenizer_name - self._n_workers = n_workers - self._thread_local = threading.local() + # The live lane runs in-process: cap this process's rayon pool so a + # mid-run batched encode uses ~live_workers cores, not the whole + # machine. Must be set before the first encode initializes the pool; + # setdefault lets an operator-exported RAYON_NUM_THREADS win. + os.environ.setdefault("RAYON_NUM_THREADS", str(max(1, live_workers))) + self._live_workers = live_workers self._fallback_warned: set[str] = set() - self._executor: ThreadPoolExecutor | None = ThreadPoolExecutor( - max_workers=n_workers, - thread_name_prefix="TokenizePool", + self._tokenizer: PreTrainedTokenizerBase | None = None + self._prefix_len = 0 + self._baseline = 0 + # In-process threads: the live token-metric lane plus the + # chat-template path. + self._thread: ThreadPoolExecutor | None = ThreadPoolExecutor( + max_workers=max(1, live_workers), thread_name_prefix="tok-thread" ) - # Pre-load a tokenizer on every worker thread so the first real - # token_count call doesn't pay the AutoTokenizer.from_pretrained cost. - # Submitting n_workers tasks is guaranteed to hit every thread because - # AutoTokenizer.from_pretrained blocks long enough that no thread - # completes before all tasks are submitted. - # **IMPORTANT**: This is not a guarantee - for instance when using a mock - # object in tests for the tokenizer, the mock object *must* block in the 100ms - # range to simulate proper .from_pretrained behavior. - # It is not super impactful if a thread is not pre-initialized - it will just - # have to pay the cost of .from_pretrained on the first pool.token_count call - # for that thread. - futures = [ - self._executor.submit(self._get_thread_tokenizer) for _ in range(n_workers) - ] + self._load_tokenizer() # also computes the chat-template baseline + # Process shards for the batched text path. Empty only when + # in-process mode was explicitly requested (n_workers=0 or + # cores_per_worker<=0; ctor overrides used primarily by tests — + # production wiring passes live_workers only and shards auto-size). + self._procs: list[ProcessPoolExecutor] = [] + self._setup_shards(cores_per_worker, n_workers) + + # -- setup -------------------------------------------------------------- + + def _load_tokenizer(self) -> None: + tok = AutoTokenizer.from_pretrained( + self._tokenizer_name, trust_remote_code=True + ) + self._tokenizer = tok + # Baseline = tokens from a [user, empty-assistant] pair minus the [user] + # prefix alone, so the assistant frame is subtracted from message counts. try: - for f in futures: - f.result() - except Exception: - self._executor.shutdown(wait=False) - self._executor = None - raise - - def _get_thread_tokenizer(self) -> PreTrainedTokenizerBase: - """Return the tokenizer for the current thread, loading it if needed.""" - if getattr(self._thread_local, "tokenizer", None) is None: - self._thread_local.tokenizer = AutoTokenizer.from_pretrained( - self._tokenizer_name, trust_remote_code=True + prefix = cast( + str, + tok.apply_chat_template( + [_PREFIX_USER_MSG], tokenize=False, add_generation_prompt=False + ), ) - # Baseline = tokens contributed by a [user, empty-assistant] pair minus - # the [user] prefix alone. Some templates (Qwen3-Coder, etc.) reject - # assistant-only message lists, so a user prefix is required; we - # subtract it out so the baseline reflects only the assistant frame. - try: - tok = self._thread_local.tokenizer - prefix_rendered = tok.apply_chat_template( - [_PREFIX_USER_MSG], - tokenize=False, - add_generation_prompt=False, - ) - prefix_len = len(tok.tokenize(prefix_rendered)) - with_empty_assistant_rendered = tok.apply_chat_template( + self._prefix_len = len(tok.tokenize(prefix)) + with_assistant = cast( + str, + tok.apply_chat_template( [_PREFIX_USER_MSG, {"role": "assistant", "content": ""}], tokenize=False, add_generation_prompt=False, + ), + ) + self._baseline = len(tok.tokenize(with_assistant)) - self._prefix_len + except Exception: + self._prefix_len = 0 + self._baseline = 0 + logger.exception( + "Failed to compute chat-template baseline for %s; tool-call " + "token counts may be over-estimated", + self._tokenizer_name, + ) + + def _setup_shards(self, cores_per_worker: int, n_workers: int) -> None: + """Spawn one pinned single-worker process per core block. + + ``n_workers == 0`` explicitly selects in-process tokenization. Auto + (``< 0``) fits one shard per ``cores_per_worker`` block of this + process's affinity mask (or the online CPU count when the platform + has no affinity API — shards then run unpinned), always at least one; + an explicit count is clamped to that capacity. An environment that + cannot shard — no fast Rust backend, a warmup that fails or exceeds + its budget — raises instead of silently degrading to a slow path + that cannot keep up with completions. + """ + if cores_per_worker <= 0 or n_workers == 0: + logger.info("BatchTokenizer: in-process tokenization (explicit)") + return + if getattr(self._tokenizer, "backend_tokenizer", None) is None: + raise RuntimeError( + f"tokenizer {self._tokenizer_name!r} has no fast (Rust) " + "backend; token metrics require one to keep up with " + "completions. Use a fast tokenizer, or disable token metrics." + ) + # Probe the full allowed CPU universe (cgroup-clamped) for the shard + # block math, then restore this process's inherited mask: the + # aggregator's event loop, publisher, and live tokenizer threads stay + # exactly where the parent placed them (the loadgen mask on a pinned + # Linux run). Only the drain-phase shard processes, pinned to their + # own blocks, span the whole machine. + try: + original = os.sched_getaffinity(0) + except (OSError, AttributeError): + original = None + try: + available = sorted(expand_to_all_online_cpus()) + except Exception: # noqa: BLE001 — no affinity API (e.g. macOS). + # Shard unpinned: the OS scheduler spreads the workers; only + # cache/NUMA locality is lost. Workers cap their rayon pools to + # the block size instead (_init_worker). + available = list(range(os.cpu_count() or 1)) + logger.info("BatchTokenizer: CPU affinity unavailable; sharding unpinned") + else: + if original is not None: + try: + os.sched_setaffinity(0, original) + except OSError: + logger.warning( + "could not restore the aggregator's inherited CPU " + "mask; this process stays expanded to all CPUs" + ) + capacity = max(1, len(available) // cores_per_worker) + n = capacity if n_workers < 0 else min(n_workers, capacity) + t0 = time.perf_counter() + ctx = multiprocessing.get_context("spawn") + procs: list[ProcessPoolExecutor] = [] + try: + for i in range(n): + block = available[i * cores_per_worker : (i + 1) * cores_per_worker] + ex = ProcessPoolExecutor( + max_workers=1, + mp_context=ctx, + initializer=_init_worker, + initargs=(self._tokenizer_name, block), ) - with_empty_assistant_len = len( - tok.tokenize(with_empty_assistant_rendered) - ) - self._thread_local.prefix_len = prefix_len - self._thread_local.baseline = with_empty_assistant_len - prefix_len - except Exception: - self._thread_local.prefix_len = 0 - self._thread_local.baseline = 0 - logger.exception( - "Failed to compute chat-template baseline for %s; tool-call token counts may be over-estimated", - self._tokenizer_name, - ) - return self._thread_local.tokenizer + procs.append(ex) + # Force spawn + pin + tokenizer-load now (not on the first batch). + # Submit to every shard first so the loads run in parallel, then + # await — waiting on each before submitting the next would + # serialize P tokenizer loads and can exceed the launch budget. + # The wait is bounded: one hung load must not wedge startup. + ready = [ex.submit(_worker_ready, 0) for ex in procs] + deadline = time.monotonic() + _SHARD_WARMUP_TIMEOUT_S + for f in ready: + f.result(timeout=max(0.0, deadline - time.monotonic())) + except Exception as exc: + _terminate_procs(procs) + raise RuntimeError( + "tokenizer shard warmup failed; refusing to fall back to a " + "slow path that cannot keep up with completions. Fix the " + "environment (see the chained error)." + ) from exc + self._procs = procs + logger.info( + "BatchTokenizer: %d shards x %d cores (setup %.1fs)", + len(procs), + cores_per_worker, + time.perf_counter() - t0, + ) + + # -- batched text path -------------------------------------------------- + + def _encode_lengths_inproc(self, texts: list[str]) -> list[int]: + tok = self._tokenizer + backend = getattr(tok, "backend_tokenizer", None) + if backend is not None: + return _encode_batch_lengths(backend, texts) + return [len(tok.tokenize(t)) for t in texts] # type: ignore[union-attr] + + async def count_texts_async( + self, + texts: list[str], + loop: asyncio.AbstractEventLoop, + *, + live: bool = False, + ) -> list[int]: + """Per-text token counts for a whole batch without blocking the loop. + + ``live=True`` is the mid-run lane: it never touches the shard + processes — it runs on this process's small thread pool with a rayon + pool capped to ``live_workers`` cores. The default (drain) path fans + out across every shard; a worker-shard failure propagates and is + treated as an incomplete drain. + """ + if not texts: + return [] + if self._procs and not live: + return await self._fan_out(self._procs, texts) + if self._thread is None: + raise RuntimeError("BatchTokenizer is closed") + return await loop.run_in_executor( + self._thread, self._encode_lengths_inproc, texts + ) + + @staticmethod + async def _fan_out(procs: list[ProcessPoolExecutor], texts: list[str]) -> list[int]: + chunks = _even_chunks(texts, len(procs)) + futures = [ + asyncio.wrap_future(ex.submit(_worker_encode_lengths, chunk)) + for ex, chunk in zip(procs, chunks, strict=False) + ] + results = await asyncio.gather(*futures) + return [n for r in results for n in r] - def _token_count_worker(self, text: str) -> int: - """Worker entry: return the number of tokens in text.""" - tokenizer = self._get_thread_tokenizer() - return len(tokenizer.tokenize(text)) + # -- sync + chat-template paths (in-process thread) --------------------- - def _token_count_message_worker( + def _token_count_text(self, text: str) -> int: + return len(self._tokenizer.tokenize(text)) # type: ignore[union-attr] + + def _token_count_message( self, content: str, reasoning: str | None, tool_calls: tuple[dict[str, Any], ...] | None, ) -> int: - """Worker entry: tokenize a full assistant message using apply_chat_template. - - Falls back to whitespace-split tokenization if apply_chat_template raises - (e.g. the template does not support tool_calls or reasoning fields). - """ - tokenizer = self._get_thread_tokenizer() + tok = self._tokenizer msg: dict[str, Any] = {"role": "assistant", "content": content or ""} if reasoning: msg["reasoning_content"] = reasoning if tool_calls: msg["tool_calls"] = _normalize_tool_calls_for_template(tool_calls) try: - rendered = tokenizer.apply_chat_template( - [_PREFIX_USER_MSG, msg], - tokenize=False, - add_generation_prompt=False, + rendered = tok.apply_chat_template( # type: ignore[union-attr] + [_PREFIX_USER_MSG, msg], tokenize=False, add_generation_prompt=False ) - full = len(tokenizer.tokenize(rendered)) - prefix_len = getattr(self._thread_local, "prefix_len", 0) - baseline = getattr(self._thread_local, "baseline", 0) - return max(0, full - prefix_len - baseline) + full = len(tok.tokenize(rendered)) # type: ignore[union-attr] + return max(0, full - self._prefix_len - self._baseline) except Exception as exc: key = f"{self._tokenizer_name}:{type(exc).__name__}" if key not in self._fallback_warned: self._fallback_warned.add(key) logger.exception( "apply_chat_template failed for %s (%s); falling back to " - "whitespace tokenization. Tool-call OSL/TPOT may diverge " - "from server-side counts for this run.", + "whitespace tokenization. Tool-call OSL/TPOT may diverge.", self._tokenizer_name, type(exc).__name__, ) @@ -203,43 +432,65 @@ def _token_count_message_worker( parts = [ p for p in (content or None, reasoning or None, tool_calls_json) if p ] - fallback_text = "\n".join(parts) - return self._token_count_worker(fallback_text) - - def token_count(self, text: str) -> int: - """Return the number of tokens in the input string (blocking).""" - if self._executor is None: - raise RuntimeError("TokenizePool is closed") - future = self._executor.submit(self._token_count_worker, text) - return future.result() + return self._token_count_text("\n".join(parts)) - def token_count_message( + async def token_count_message_async( self, content: str, reasoning: str | None, tool_calls: tuple[dict[str, Any], ...] | None, + loop: asyncio.AbstractEventLoop, ) -> int: - """Return the token count for an assistant message (blocking).""" - if self._executor is None: - raise RuntimeError("TokenizePool is closed") - future = self._executor.submit( - self._token_count_message_worker, content, reasoning, tool_calls + """Chat-template message token count without blocking the loop.""" + if self._thread is None: + raise RuntimeError("BatchTokenizer is closed") + return await loop.run_in_executor( + self._thread, self._token_count_message, content, reasoning, tool_calls ) - return future.result() - async def token_count_async( - self, text: str, loop: asyncio.AbstractEventLoop - ) -> int: - """Return the number of tokens without blocking the event loop. + def close(self) -> None: + """Shut down all workers. Idempotent. - Submits directly to the TokenizePool's executor so tokenization runs - on a thread with a pre-loaded thread-local tokenizer instance. + Shards are stopped without waiting (a hung worker must not block + aggregator shutdown) and terminated so an in-flight encode cannot + stall interpreter exit after a drain timeout. """ - if self._executor is None: - raise RuntimeError("TokenizePool is closed") - return await loop.run_in_executor( - self._executor, self._token_count_worker, text - ) + _terminate_procs(self._procs) + self._procs = [] + if self._thread is not None: + self._thread.shutdown(wait=True) + self._thread = None + + def __enter__(self) -> BatchTokenizer: + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + self.close() + + +# Type alias for the (content, reasoning, tool_calls) tuple a message trigger +# enqueues for chat-template tokenization. +MessageParts = tuple[str, str | None, tuple[dict[str, Any], ...] | None] + + +class TokenCounter(Protocol): + """The async tokenization surface ``TokenBatchQueue`` depends on. + + ``BatchTokenizer`` satisfies this structurally; tests pass lightweight + stubs. Declared as a Protocol so the queue is decoupled from the concrete + tokenizer and test doubles type-check without inheritance. + """ + + async def count_texts_async( + self, + texts: list[str], + loop: asyncio.AbstractEventLoop, + /, + *, + live: bool = False, + ) -> list[int]: + """Per-text token counts (``live=True`` = the bounded mid-run lane).""" + raise NotImplementedError async def token_count_message_async( self, @@ -247,26 +498,189 @@ async def token_count_message_async( reasoning: str | None, tool_calls: tuple[dict[str, Any], ...] | None, loop: asyncio.AbstractEventLoop, + /, ) -> int: - """Return the token count for an assistant message without blocking the event loop.""" - if self._executor is None: - raise RuntimeError("TokenizePool is closed") - return await loop.run_in_executor( - self._executor, - self._token_count_message_worker, - content, - reasoning, - tool_calls, - ) + """Chat-template token count for one assistant message.""" + raise NotImplementedError - def close(self) -> None: - """Shut down the worker pool. Idempotent.""" - if self._executor is not None: - self._executor.shutdown(wait=True) - self._executor = None - def __enter__(self) -> TokenizePool: - return self +class TokenBatchQueue: + """Buffers per-sample tokenization work and clears it in batches. - def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: - self.close() + Triggers call ``enqueue_text`` / ``enqueue_message`` at event time with an + ``on_count`` callback that records the resulting metric. The queue owns + its own flush cadence: ``start_live`` begins a periodic flush through the + tokenizer's bounded live lane (so live ISL/OSL/TPOT stay current without + touching the benchmark's cores), and ``flush_remaining`` drains everything + left at end-of-run through every shard. + + ``pending`` counts enqueued-but-not-yet-recorded items; it is the + ``n_pending_tasks`` on the snapshot. A non-zero value in the final snapshot + means the end-of-run flush did not finish within the drain budget or failed. + """ + + def __init__( + self, tokenizer: TokenCounter, loop: asyncio.AbstractEventLoop + ) -> None: + self._tokenizer = tokenizer + self._loop = loop + self._text: list[tuple[str, Callable[[int], None]]] = [] + self._msg: list[tuple[MessageParts, Callable[[int], None]]] = [] + self._inflight = 0 + self._live_task: asyncio.Task | None = None + # Serializes flushes so the periodic live flush and the end-of-run + # flush never record the same item twice or race on the pending count. + self._lock = asyncio.Lock() + + def start_live(self, interval_s: float) -> None: + """Begin the periodic live flush (idempotent). + + Failures are logged once and never interrupt the loop — unflushed + items stay visible as ``pending`` and the end-of-run drain picks + them up. + """ + if self._live_task is not None: + return + self._live_task = self._loop.create_task(self._live_flush_loop(interval_s)) + + async def _live_flush_loop(self, interval_s: float) -> None: + failure_logged = False + while True: + await asyncio.sleep(interval_s) + try: + await self.flush(live=True) + except Exception: # noqa: BLE001 — keep live metrics flowing. + if not failure_logged: + failure_logged = True + logger.exception( + "live token flush failed; retrying each interval " + "(further failures logged at debug)" + ) + else: + logger.debug("live token flush failed again") + + @property + def pending(self) -> int: + """Enqueued items not yet tokenized-and-recorded.""" + return self._inflight + + def enqueue_text(self, text: str, on_count: Callable[[int], None]) -> None: + self._inflight += 1 + self._text.append((text, on_count)) + + def enqueue_message( + self, parts: MessageParts, on_count: Callable[[int], None] + ) -> None: + self._inflight += 1 + self._msg.append((parts, on_count)) + + async def flush(self, live: bool = False) -> None: + """Tokenize everything buffered so far and run each ``on_count``. + + ``live=True`` routes text batches through the tokenizer's bounded + live lane instead of the full shard pool, takes at most + ``_LIVE_FLUSH_MAX_ITEMS`` per kind (bounding lock-hold time and the + unstoppable in-flight encode a drain-start cancellation leaves + behind), and re-queues items on failure or cancellation so a mid-run + hiccup never loses samples — the end-of-run drain retries them. Drain-mode failures are terminal: the + un-recorded items stay counted in ``pending`` (``_inflight`` is + decremented only after a callback runs) and surface as an incomplete + drain, not as silently dropped samples. Items are detached from the + buffer up front so concurrent enqueues land in the next flush. + """ + async with self._lock: + if not (self._text or self._msg): + return + if live: + cap = _LIVE_FLUSH_MAX_ITEMS + text_items = self._text[:cap] + del self._text[:cap] # in-place: O(cap), not O(backlog). + msg_items = self._msg[:cap] + del self._msg[:cap] + else: + text_items, self._text = self._text, [] + msg_items, self._msg = self._msg, [] + # The text and message phases fail independently — they run on + # separate executors, so a dead text shard must not drop message + # items that would still succeed (and vice versa). The first + # failure is re-raised after both phases so callers still see it. + failure: Exception | None = None + if text_items: + try: + counts = await self._tokenizer.count_texts_async( + [t for t, _ in text_items], self._loop, live=live + ) + except asyncio.CancelledError: + if live: + self._text[:0] = text_items + self._msg[:0] = msg_items + raise + except Exception as exc: # noqa: BLE001 — isolate phases. + failure = exc + if live: + # A live hiccup must not lose samples: give the items + # back so the end-of-run drain (full pool) retries. + # Drain failures are terminal and stay pending-only. + self._text[:0] = text_items + else: + for (_, on_count), count in zip(text_items, counts, strict=True): + self._record(on_count, count) + for i, ((content, reasoning, tool_calls), on_count) in enumerate(msg_items): + try: + count = await self._tokenizer.token_count_message_async( + content, reasoning, tool_calls, self._loop + ) + except asyncio.CancelledError: + if live: + self._msg[:0] = msg_items[i:] + raise + except Exception as exc: # noqa: BLE001 — isolate items. + failure = failure or exc + if live: + self._msg.append(((content, reasoning, tool_calls), on_count)) + continue + self._record(on_count, count) + if failure is not None: + raise failure + + def _record(self, on_count: Callable[[int], None], count: int) -> None: + """Run one recorder callback; a raising recorder must not poison the + rest of the batch, and the item still counts as recorded.""" + try: + on_count(count) + except Exception: # noqa: BLE001 — per-item isolation. + logger.exception("token metric recorder failed") + finally: + self._inflight -= 1 + + async def flush_remaining(self, timeout: float | None) -> int: + """End-of-run flush, bounded by ``timeout`` seconds. + + Stops the live flush loop, then drains through the full shard pool. + Returns the number of items still un-tokenized — non-zero if the budget + was exhausted (``timeout`` reached) or tokenization failed. ``None`` + waits indefinitely. Never raises: a failure here must not stop the + aggregator from publishing the (incomplete) final snapshot. + """ + if self._live_task is not None: + self._live_task.cancel() + await asyncio.gather(self._live_task, return_exceptions=True) + self._live_task = None + if self._inflight == 0: + return 0 + try: + if timeout is None: + await self.flush() + else: + await asyncio.wait_for(self.flush(), timeout) + except TimeoutError: + logger.warning( + "tokenizer drain timed out after %.1fs; %d items not counted", + timeout, + self._inflight, + ) + except Exception: # noqa: BLE001 — drain must not block finalize. + logger.exception( + "tokenizer drain failed; %d items not counted", self._inflight + ) + return self._inflight diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 9226d7f85..19447dd2b 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -578,18 +578,17 @@ class DrainConfig(BaseModel): alias="--metrics-drain-timeout", help=( "Wall-clock budget (seconds) for the metrics aggregator to finish " - "in-flight async tokenize tasks after the run ends before cancelling " - "them. Set to 0 to wait indefinitely. Increase for large datasets or " - "long-context workloads where ISL/OSL/TPOT tokenization lags behind " - "request throughput." + "tokenizing buffered samples after the run ends. Set to 0 to wait " + "indefinitely. Increase for very large datasets where the end-of-run " + "tokenize batch is big." ), ), ] = Field( - 60.0, + 300.0, ge=0, description=( - "Wall-clock budget (seconds) for the metrics aggregator to drain " - "in-flight tokenize tasks after ENDED (default: 60.0; 0 = unlimited)." + "Wall-clock budget (seconds) to finish tokenizing buffered samples " + "after ENDED (default: 300.0; 0 = unlimited)." ), ) metrics_tokenizer_workers: Annotated[ @@ -597,15 +596,18 @@ class DrainConfig(BaseModel): cyclopts.Parameter( alias="--metrics-tokenizer-workers", help=( - "Number of tokenizer worker threads in the metrics aggregator. " - "Increase if ISL/OSL/TPOT tokenization can't keep up with request " - "throughput (symptoms: large drain timeout warning at run end)." + "In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT in " + "the metrics aggregator. 0 defers all tokenization to the " + "end-of-run drain, which always uses the auto-sized sharded pool." ), ), ] = Field( 2, - ge=1, - description="Number of tokenizer worker threads in the metrics aggregator (default: 2).", + ge=0, + description=( + "In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT " + "(default: 2; 0 = defer everything to the end-of-run drain)." + ), ) diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 38829f0f5..42c449d1d 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -79,8 +79,8 @@ settings: warmup_timeout_s: 240.0 # Warmup drain timeout in seconds (None = wait indefinitely) performance_timeout_s: 240.0 # Performance drain timeout in seconds (None = wait indefinitely) accuracy_timeout_s: null # Accuracy drain timeout in seconds (None = wait indefinitely) - metrics_drain_timeout_s: 60.0 # Wall-clock budget (seconds) for the metrics aggregator to drain in-flight tokenize tasks after ENDED (default: 60.0; 0 = unlimited). - metrics_tokenizer_workers: 2 # Number of tokenizer worker threads in the metrics aggregator (default: 2). + metrics_drain_timeout_s: 300.0 # Wall-clock budget (seconds) to finish tokenizing buffered samples after ENDED (default: 300.0; 0 = unlimited). + metrics_tokenizer_workers: 2 # In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT (default: 2; 0 = defer everything to the end-of-run drain). warmup: enabled: false # Enable warmup phase before performance run n_requests: null # Warmup request count (None = full dataset once) diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index c3454d5da..b5f4f5a23 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -79,8 +79,8 @@ settings: warmup_timeout_s: 240.0 # Warmup drain timeout in seconds (None = wait indefinitely) performance_timeout_s: 240.0 # Performance drain timeout in seconds (None = wait indefinitely) accuracy_timeout_s: null # Accuracy drain timeout in seconds (None = wait indefinitely) - metrics_drain_timeout_s: 60.0 # Wall-clock budget (seconds) for the metrics aggregator to drain in-flight tokenize tasks after ENDED (default: 60.0; 0 = unlimited). - metrics_tokenizer_workers: 2 # Number of tokenizer worker threads in the metrics aggregator (default: 2). + metrics_drain_timeout_s: 300.0 # Wall-clock budget (seconds) to finish tokenizing buffered samples after ENDED (default: 300.0; 0 = unlimited). + metrics_tokenizer_workers: 2 # In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT (default: 2; 0 = defer everything to the end-of-run drain). warmup: enabled: false # Enable warmup phase before performance run n_requests: null # Warmup request count (None = full dataset once) diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index 5bea95329..4271ff792 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -79,8 +79,8 @@ settings: warmup_timeout_s: 240.0 # Warmup drain timeout in seconds (None = wait indefinitely) performance_timeout_s: 240.0 # Performance drain timeout in seconds (None = wait indefinitely) accuracy_timeout_s: null # Accuracy drain timeout in seconds (None = wait indefinitely) - metrics_drain_timeout_s: 60.0 # Wall-clock budget (seconds) for the metrics aggregator to drain in-flight tokenize tasks after ENDED (default: 60.0; 0 = unlimited). - metrics_tokenizer_workers: 2 # Number of tokenizer worker threads in the metrics aggregator (default: 2). + metrics_drain_timeout_s: 300.0 # Wall-clock budget (seconds) to finish tokenizing buffered samples after ENDED (default: 300.0; 0 = unlimited). + metrics_tokenizer_workers: 2 # In-process tokenizer threads for live (mid-run) ISL/OSL/TPOT (default: 2; 0 = defer everything to the end-of-run drain). warmup: enabled: false # Enable warmup phase before performance run n_requests: null # Warmup request count (None = full dataset once) diff --git a/src/inference_endpoint/endpoint_client/cpu_affinity.py b/src/inference_endpoint/endpoint_client/cpu_affinity.py index 8972a59d9..0de6e39a4 100644 --- a/src/inference_endpoint/endpoint_client/cpu_affinity.py +++ b/src/inference_endpoint/endpoint_client/cpu_affinity.py @@ -317,6 +317,32 @@ def pin_loadgen( return None +@require_linux +def expand_to_all_online_cpus() -> set[int]: + """Reset the current process's affinity to every online CPU. + + Undoes a narrow mask inherited from a pinned parent (subprocesses spawned + after ``pin_loadgen`` inherit the loadgen mask). The kernel intersects the + request with the cgroup cpuset, so container/Slurm CPU limits still apply. + + Returns: + The effective CPU set after the reset. + + Raises: + UnsupportedPlatformError: If not running on Linux. + """ + online = _read_sysfs_cpulist(_SYSFS_CPU / "online") or set() + if online: + try: + os.sched_setaffinity(0, online) + except OSError as e: + logger.warning(f"Could not expand CPU affinity: {e}") + try: + return os.sched_getaffinity(0) + except OSError: + return online + + @require_linux def set_cpu_affinity(pid: int, cpus: set[int]) -> bool: """Set CPU affinity for a process. diff --git a/tests/integration/async_utils/services/metrics_aggregator/test_signal_handling.py b/tests/integration/async_utils/services/metrics_aggregator/test_signal_handling.py index 010536c09..62db80b04 100644 --- a/tests/integration/async_utils/services/metrics_aggregator/test_signal_handling.py +++ b/tests/integration/async_utils/services/metrics_aggregator/test_signal_handling.py @@ -64,6 +64,13 @@ def _spawn_aggregator( metrics_socket, "--metrics-output-dir", str(output_dir), + # Required by the entrypoint, but inert here: no tokenizer is + # configured (so no live tokenization) and the run is signalled + # rather than ENDED, so the drain budget is never reached. + "--drain-timeout", + "5", + "--tokenizer-workers", + "0", ], # New process group so we can signal it without disturbing the # test runner. diff --git a/tests/unit/async_utils/services/metrics_aggregator/conftest.py b/tests/unit/async_utils/services/metrics_aggregator/conftest.py index 7adbe0361..aae7a07ac 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/conftest.py +++ b/tests/unit/async_utils/services/metrics_aggregator/conftest.py @@ -49,24 +49,26 @@ from inference_endpoint.core.types import TextModelOutput # --------------------------------------------------------------------------- -# Mock TokenizePool — used by tests that exercise async triggers directly. +# Mock BatchTokenizer — whitespace token counts; matches the BatchTokenizer +# surface the TokenBatchQueue calls (count_texts_async + message path). # --------------------------------------------------------------------------- -class MockTokenizePool: - """Mock TokenizePool that splits on whitespace with artificial async delay.""" +class MockBatchTokenizer: + """Mock BatchTokenizer that splits on whitespace with optional async delay.""" - def __init__(self, delay: float = 0.01) -> None: + def __init__(self, delay: float = 0.0) -> None: self._delay = delay - def token_count(self, text: str) -> int: - return len(text.split()) - - async def token_count_async( - self, text: str, _loop: asyncio.AbstractEventLoop - ) -> int: - await asyncio.sleep(self._delay) - return len(text.split()) + async def count_texts_async( + self, + texts: list[str], + _loop: asyncio.AbstractEventLoop, + live: bool = False, + ) -> list[int]: + if self._delay: + await asyncio.sleep(self._delay) + return [len(t.split()) for t in texts] async def token_count_message_async( self, @@ -77,7 +79,8 @@ async def token_count_message_async( ) -> int: import msgspec - await asyncio.sleep(self._delay) + if self._delay: + await asyncio.sleep(self._delay) tool_calls_str = ( msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" ) @@ -164,9 +167,11 @@ def make_aggregator( loop: asyncio.AbstractEventLoop, socket_name: str, *, - tokenize_pool=None, + tokenizer=None, + live_flush_interval_s: float | None = None, streaming: bool = True, shutdown_event: asyncio.Event | None = None, + drain_timeout_s: float | None = None, ) -> tuple[MetricsAggregatorService, MetricsRegistry, MagicMock]: """Construct an aggregator wired to a real SUB socket and a mocked publisher. @@ -195,8 +200,10 @@ def make_aggregator( publish_interval_s=0.25, sig_figs=3, n_histogram_buckets=10, - tokenize_pool=tokenize_pool, + tokenizer=tokenizer, + live_flush_interval_s=live_flush_interval_s, streaming=streaming, shutdown_event=shutdown_event, + drain_timeout_s=drain_timeout_s, ) return agg, registry, publisher diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py index 9877aee5d..075e4a0d5 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py @@ -44,7 +44,7 @@ from inference_endpoint.core.types import ErrorData, PromptData, TextModelOutput from .conftest import ( - MockTokenizePool, + MockBatchTokenizer, make_aggregator, sample_event, session_event, @@ -312,10 +312,10 @@ async def test_chunk_deltas(self, tmp_path): async def test_non_streaming_latency_only(self, tmp_path): """Non-streaming: emits sample_latency_ns + OSL, no TTFT/chunk_delta/TPOT.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_non_streaming", tokenize_pool=pool + ctx, loop, "agg_non_streaming", tokenizer=tokenizer ) try: await agg.process( @@ -332,7 +332,7 @@ async def test_non_streaming_latency_only(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() # sample_latency = 3000-1000 = 2000 assert ( snapshot_series_total( @@ -380,7 +380,7 @@ async def test_chunk_delta_not_emitted_without_last_recv(self, tmp_path): # --------------------------------------------------------------------------- -# ISL (token_ids path -- sync, no tokenize_pool needed) +# ISL (token_ids path -- sync, no tokenizer needed) # --------------------------------------------------------------------------- @@ -766,7 +766,7 @@ async def test_total_vs_tracked_counters(self, tmp_path): # --------------------------------------------------------------------------- -# Async trigger tests (with mock TokenizePool and real event loop) +# Token trigger tests (with mock BatchTokenizer and real event loop) # --------------------------------------------------------------------------- @@ -776,10 +776,10 @@ class TestAsyncTriggers: async def test_isl_text_path_async(self, tmp_path): """ISL with text prompt triggers async tokenization.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.01) + tokenizer = MockBatchTokenizer(delay=0.01) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_isl_text_async", tokenize_pool=pool + ctx, loop, "agg_isl_text_async", tokenizer=tokenizer ) try: await agg.process( @@ -796,7 +796,7 @@ async def test_isl_text_path_async(self, tmp_path): ] ) # ISL task is in-flight; drain it - await agg._table.drain_tasks() + await agg._token_queue.flush() assert snapshot_series_total(registry, MetricSeriesKey.ISL.value) == 4 finally: agg.close() @@ -805,10 +805,10 @@ async def test_isl_text_path_async(self, tmp_path): async def test_osl_emitted_on_complete(self, tmp_path): """OSL is emitted via async tokenization when COMPLETE carries text.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.01) + tokenizer = MockBatchTokenizer(delay=0.01) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_osl_complete", tokenize_pool=pool + ctx, loop, "agg_osl_complete", tokenizer=tokenizer ) try: await agg.process( @@ -825,7 +825,7 @@ async def test_osl_emitted_on_complete(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() # sample_latency_ns = 5000-1000 = 4000 assert ( snapshot_series_total( @@ -842,10 +842,10 @@ async def test_osl_emitted_on_complete(self, tmp_path): async def test_tpot_emitted_for_streaming(self, tmp_path): """TPOT is emitted for streaming responses using text_after_first_chunk.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_tpot_streaming", tokenize_pool=pool + ctx, loop, "agg_tpot_streaming", tokenizer=tokenizer ) try: await agg.process( @@ -864,7 +864,7 @@ async def test_tpot_emitted_for_streaming(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() # OSL = "hello world foo" = 3 tokens assert snapshot_series_total(registry, MetricSeriesKey.OSL.value) == 3 # tpot = (5000 - 2000) / token_count("world foo") = 3000 / 2 = 1500 @@ -878,10 +878,10 @@ async def test_tpot_emitted_for_streaming(self, tmp_path): async def test_tpot_skipped_when_single_chunk(self, tmp_path): """TPOT is not emitted when there are no tokens after the first chunk.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_tpot_single_chunk", tokenize_pool=pool + ctx, loop, "agg_tpot_single_chunk", tokenizer=tokenizer ) try: await agg.process( @@ -900,7 +900,7 @@ async def test_tpot_skipped_when_single_chunk(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() assert snapshot_series_total(registry, MetricSeriesKey.OSL.value) == 1 assert ( snapshot_series_count(registry, MetricSeriesKey.TPOT_NS.value) == 0 @@ -914,13 +914,13 @@ async def test_tpot_not_emitted_without_streaming_flag(self, tmp_path): registered at all — the aggregator's snapshot has no entry for them. """ loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( ctx, loop, "agg_tpot_no_streaming", - tokenize_pool=pool, + tokenizer=tokenizer, streaming=False, ) try: @@ -939,7 +939,7 @@ async def test_tpot_not_emitted_without_streaming_flag(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() # sample_latency / OSL still emitted in non-streaming mode. assert ( snapshot_series_total( @@ -959,10 +959,10 @@ async def test_tpot_not_emitted_without_streaming_flag(self, tmp_path): async def test_tpot_non_streaming_output_skipped(self, tmp_path): """TPOT is not emitted for non-streaming (str) TextModelOutput.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_tpot_str_output", tokenize_pool=pool + ctx, loop, "agg_tpot_str_output", tokenizer=tokenizer ) try: await agg.process( @@ -981,7 +981,7 @@ async def test_tpot_non_streaming_output_skipped(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() assert snapshot_series_total(registry, MetricSeriesKey.OSL.value) == 3 assert ( snapshot_series_count(registry, MetricSeriesKey.TPOT_NS.value) == 0 @@ -990,13 +990,36 @@ async def test_tpot_non_streaming_output_skipped(self, tmp_path): agg.close() @pytest.mark.asyncio - async def test_drain_tasks_awaits_in_flight(self, tmp_path): - """drain_tasks() properly awaits all in-flight async trigger tasks.""" + async def test_started_arms_the_live_flush_loop(self, tmp_path): + """STARTED starts the queue's live loop when an interval is set.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.05) + with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: + agg, _, _ = make_aggregator( + ctx, + loop, + "agg_live_arm", + tokenizer=MockBatchTokenizer(), + live_flush_interval_s=0.01, + ) + try: + await agg.process([session_event(SessionEventType.STARTED, ts=0)]) + assert agg._token_queue is not None + assert agg._token_queue._live_task is not None + await agg.process([session_event(SessionEventType.ENDED, ts=100)]) + assert ( + agg._token_queue._live_task is None + ), "drain must stop the live loop" + finally: + agg.close() + + @pytest.mark.asyncio + async def test_flush_records_buffered_tokenizations(self, tmp_path): + """fire() buffers tokenization; flush() tokenizes the batch and records.""" + loop = asyncio.get_event_loop() + tokenizer = MockBatchTokenizer() with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_drain_in_flight", tokenize_pool=pool + ctx, loop, "agg_flush_records", tokenizer=tokenizer ) try: await agg.process( @@ -1012,23 +1035,24 @@ async def test_drain_tasks_awaits_in_flight(self, tmp_path): ), ] ) - # Tasks are in-flight but not yet complete - assert agg._table.in_flight_tasks_count > 0 + assert agg._token_queue is not None + # Enqueued by fire(), not yet tokenized (no tick/drain flush). + assert agg._token_queue.pending > 0 - await agg._table.drain_tasks() - assert agg._table.in_flight_tasks_count == 0 + await agg._token_queue.flush() + assert agg._token_queue.pending == 0 assert snapshot_series_total(registry, MetricSeriesKey.ISL.value) == 5 finally: agg.close() @pytest.mark.asyncio - async def test_shutdown_drains_async_tasks(self, tmp_path): - """ENDED drains in-flight async tasks before finalizing.""" + async def test_shutdown_flushes_buffered_tokenizations(self, tmp_path): + """ENDED flushes buffered tokenizations before finalizing.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.02) + tokenizer = MockBatchTokenizer(delay=0.02) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, publisher = make_aggregator( - ctx, loop, "agg_shutdown_drain", tokenize_pool=pool + ctx, loop, "agg_shutdown_drain", tokenizer=tokenizer ) try: await agg.process( @@ -1045,16 +1069,56 @@ async def test_shutdown_drains_async_tasks(self, tmp_path): session_event(SessionEventType.ENDED, ts=2000), ] ) - # After ENDED, drain_tasks ran inside process() — ISL emitted. + # After ENDED, flush_remaining ran inside process() — ISL emitted. assert snapshot_series_total(registry, MetricSeriesKey.ISL.value) == 3 publisher.publish_final.assert_awaited_once() finally: agg.close() - # NOTE(agents): Trigger exception handling (logger.exception paths) is not - # exercised here. Adding a MockTokenizePool that raises on - # token_count_async would let us assert no metric is emitted, the - # aggregator does not crash, and the task set is cleaned up. + @pytest.mark.asyncio + async def test_drain_failure_reports_pending_and_finalizes(self, tmp_path): + """A tokenizer error during the ENDED drain must not skip finalize. + + flush_remaining swallows non-timeout failures and returns the stuck + count, so publish_final still runs with n_pending_tasks > 0 (incomplete + drain) instead of the error escaping process() and hanging main(). + """ + loop = asyncio.get_event_loop() + + class FailingBatchTokenizer: + async def count_texts_async(self, texts, _loop, live=False): + raise RuntimeError("tokenizer backend died") + + async def token_count_message_async(self, *args): + raise RuntimeError("tokenizer backend died") + + with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: + agg, _, publisher = make_aggregator( + ctx, loop, "agg_drain_failure", tokenizer=FailingBatchTokenizer() + ) + try: + await agg.process( + [ + session_event( + SessionEventType.START_PERFORMANCE_TRACKING, ts=0 + ), + sample_event( + SampleEventType.ISSUED, + "s1", + ts=1000, + data=PromptData(text="some text to tokenize"), + ), + ] + ) + assert agg._token_queue is not None + assert agg._token_queue.pending > 0 + await agg.process([session_event(SessionEventType.ENDED, ts=2000)]) + + publisher.publish_final.assert_awaited_once() + assert publisher.publish_final.await_args.kwargs["n_pending_tasks"] > 0 + publisher.aclose.assert_awaited_once() + finally: + agg.close() @pytest.mark.asyncio async def test_drain_timeout_reports_pending_count(self, tmp_path): @@ -1068,29 +1132,21 @@ async def test_drain_timeout_reports_pending_count(self, tmp_path): """ loop = asyncio.get_event_loop() - class BlockingTokenizePool: - async def token_count_async(self, text, _loop): + class BlockingBatchTokenizer: + async def count_texts_async(self, texts, _loop, live=False): await asyncio.sleep(10.0) # exceeds drain timeout - return 0 + return [0] * len(texts) - def token_count(self, text): + async def token_count_message_async(self, *args): + await asyncio.sleep(10.0) return 0 - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, *args): - self.close() - with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, _, publisher = make_aggregator( ctx, loop, "agg_drain_timeout", - tokenize_pool=BlockingTokenizePool(), + tokenizer=BlockingBatchTokenizer(), ) agg._drain_timeout_s = 0.05 try: @@ -1107,9 +1163,10 @@ def __exit__(self, *args): ), ] ) + assert agg._token_queue is not None assert ( - agg._table.in_flight_tasks_count > 0 - ), "precondition: ISL task must be in-flight before ENDED" + agg._token_queue.pending > 0 + ), "precondition: ISL must be buffered before ENDED" await agg.process([session_event(SessionEventType.ENDED, ts=2000)]) publisher.publish_final.assert_awaited_once() @@ -1125,10 +1182,10 @@ def __exit__(self, *args): async def test_tpot_osl_for_tool_call_complete(self, tmp_path): """OSL and TPOT use message-path tokenization when COMPLETE carries tool_calls.""" loop = asyncio.get_event_loop() - pool = MockTokenizePool(delay=0.0) + tokenizer = MockBatchTokenizer(delay=0.0) with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: agg, registry, _ = make_aggregator( - ctx, loop, "agg_tpot_osl_tool_call", tokenize_pool=pool + ctx, loop, "agg_tpot_osl_tool_call", tokenizer=tokenizer ) try: tool_call = { @@ -1151,7 +1208,7 @@ async def test_tpot_osl_for_tool_call_complete(self, tmp_path): ), ] ) - await agg._table.drain_tasks() + await agg._token_queue.flush() # OSL = token_count("ok" + tool_calls_json) = 2 assert snapshot_series_total(registry, MetricSeriesKey.OSL.value) == 2 # tpot = (5000 - 2000) / token_count(tool_calls_json) = 3000 / 1 = 3000 diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_error_handler.py b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_error_handler.py index 4e8222c48..40e6eb91b 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_error_handler.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_aggregator_error_handler.py @@ -82,6 +82,7 @@ def _make_aggregator( sig_figs=3, n_histogram_buckets=10, streaming=streaming, + drain_timeout_s=None, ) return agg, registry, publisher diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_main_signal_handler.py b/tests/unit/async_utils/services/metrics_aggregator/test_main_signal_handler.py index 550a4863c..13fb1f40b 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_main_signal_handler.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_main_signal_handler.py @@ -27,6 +27,7 @@ import asyncio import gc import weakref +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import pytest @@ -50,7 +51,7 @@ async def test_sigterm_handler_holds_strong_reference_to_finalize_task(): registry = MagicMock() table = MagicMock() table.total_tracked_duration_ns = 0 - table.in_flight_tasks_count = 0 + token_queue = SimpleNamespace(pending=0) # publish_final blocks on an event so we can observe the task # mid-execution and exercise the strong-ref contract. @@ -69,6 +70,7 @@ async def _slow_publish(*args, **kwargs): registry=registry, publisher=publisher, table=table, + token_queue=token_queue, shutdown_event=shutdown_event, ) @@ -122,7 +124,7 @@ async def test_sigterm_handler_refreshes_tracked_duration(): registry = MagicMock() table = MagicMock() table.total_tracked_duration_ns = 12345 - table.in_flight_tasks_count = 3 + token_queue = SimpleNamespace(pending=3) publisher = MagicMock() publisher.publish_final = AsyncMock() @@ -134,6 +136,7 @@ async def test_sigterm_handler_refreshes_tracked_duration(): registry=registry, publisher=publisher, table=table, + token_queue=token_queue, shutdown_event=shutdown_event, ) on_sigterm() diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py index 077923ff8..4ed957a98 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_metrics_table.py @@ -34,6 +34,9 @@ from inference_endpoint.async_utils.services.metrics_aggregator.registry import ( MetricsRegistry, ) +from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( + TokenBatchQueue, +) from inference_endpoint.core.record import ( EventRecord, SampleEventType, @@ -294,13 +297,13 @@ async def test_osl_with_tool_calls_uses_message_path(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import MockTokenizePool, snapshot_series_count + from .conftest import MockBatchTokenizer, snapshot_series_count registry = MetricsRegistry() registry.register_series("osl", hdr_low=1, hdr_high=100_000) loop = asyncio.get_running_loop() - pool = MockTokenizePool(delay=0) - trigger = OslTrigger(registry, pool, loop) + queue = TokenBatchQueue(MockBatchTokenizer(), loop) + trigger = OslTrigger(registry, queue) tool_calls = ( { @@ -317,9 +320,8 @@ async def test_osl_with_tool_calls_uses_message_path(self): data=tmo, ) row = SampleRow(sample_uuid="s1") - task = trigger.fire(ev, row, {}) - assert task is not None - await task + trigger.fire(ev, row, {}) + await queue.flush() assert snapshot_series_count(registry, "osl") == 1 @@ -331,13 +333,13 @@ async def test_osl_without_tool_calls_uses_text_path(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import MockTokenizePool, snapshot_series_count + from .conftest import MockBatchTokenizer, snapshot_series_count registry = MetricsRegistry() registry.register_series("osl", hdr_low=1, hdr_high=100_000) loop = asyncio.get_running_loop() - pool = MockTokenizePool(delay=0) - trigger = OslTrigger(registry, pool, loop) + queue = TokenBatchQueue(MockBatchTokenizer(), loop) + trigger = OslTrigger(registry, queue) tmo = TextModelOutput(output="hello world") ev = EventRecord( @@ -347,9 +349,8 @@ async def test_osl_without_tool_calls_uses_text_path(self): data=tmo, ) row = SampleRow(sample_uuid="s1") - task = trigger.fire(ev, row, {}) - assert task is not None - await task + trigger.fire(ev, row, {}) + await queue.flush() assert snapshot_series_count(registry, "osl") == 1 @@ -368,15 +369,15 @@ async def test_tpot_tool_calls_only_response(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import MockTokenizePool, snapshot_series_count + from .conftest import MockBatchTokenizer, snapshot_series_count registry = MetricsRegistry() registry.register_series( "tpot_ns", hdr_low=1, hdr_high=100_000_000_000, dtype=float ) loop = asyncio.get_running_loop() - pool = MockTokenizePool(delay=0) - trigger = TpotTrigger(registry, pool, loop) + queue = TokenBatchQueue(MockBatchTokenizer(), loop) + trigger = TpotTrigger(registry, queue) tool_calls = ( { @@ -395,9 +396,8 @@ async def test_tpot_tool_calls_only_response(self): row = SampleRow(sample_uuid="s1") # RECV_FIRST_NS was set at t=1000 pre_change = {SampleField.RECV_FIRST_NS: 1000} - task = trigger.fire(ev, row, pre_change) - assert task is not None - await task + trigger.fire(ev, row, pre_change) + await queue.flush() assert snapshot_series_count(registry, "tpot_ns") == 1 @@ -409,15 +409,15 @@ async def test_tpot_uses_tool_call_deltas_after_first_chunk(self): ) from inference_endpoint.core.types import TextModelOutput - from .conftest import MockTokenizePool, snapshot_series_total + from .conftest import MockBatchTokenizer, snapshot_series_total registry = MetricsRegistry() registry.register_series( "tpot_ns", hdr_low=1, hdr_high=100_000_000_000, dtype=float ) loop = asyncio.get_running_loop() - pool = MockTokenizePool(delay=0) - trigger = TpotTrigger(registry, pool, loop) + queue = TokenBatchQueue(MockBatchTokenizer(), loop) + trigger = TpotTrigger(registry, queue) tool_call_chunks = ( ( @@ -442,8 +442,7 @@ async def test_tpot_uses_tool_call_deltas_after_first_chunk(self): ) row = SampleRow(sample_uuid="s1") pre_change = {SampleField.RECV_FIRST_NS: 1000} - task = trigger.fire(ev, row, pre_change) - assert task is not None - await task + trigger.fire(ev, row, pre_change) + await queue.flush() assert snapshot_series_total(registry, "tpot_ns") == pytest.approx(2000.0) diff --git a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py index e25bf0022..ba0d2e3e9 100644 --- a/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py +++ b/tests/unit/async_utils/services/metrics_aggregator/test_token_metrics.py @@ -13,27 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for TokenizePool thread-safety and correctness.""" +"""Tests for BatchTokenizer and TokenBatchQueue.""" import asyncio import time -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future +from concurrent.futures.process import BrokenProcessPool from unittest.mock import patch import pytest +from inference_endpoint.async_utils.services.metrics_aggregator import ( + token_metrics as token_metrics_module, +) from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( - TokenizePool, + BatchTokenizer, + TokenBatchQueue, + _encode_batch_lengths, + _even_chunks, + _worker_encode_lengths, ) _MOCK_TARGET = "inference_endpoint.async_utils.services.metrics_aggregator.token_metrics.AutoTokenizer" class _FakeTokenizer: - """Deterministic tokenizer that splits on whitespace.""" + """Deterministic tokenizer that splits on whitespace. + + Has no ``backend_tokenizer``, so BatchTokenizer keeps the batch path + in-process (no subprocess shards) and counts via ``tokenize`` per text. + """ - def __init__(self, load_delay: float = 0.1): - # Simulate the blocking cost of from_pretrained so that - # pre-initialization in __init__ saturates all worker threads. + def __init__(self, load_delay: float = 0.0): time.sleep(load_delay) def tokenize(self, text: str) -> list[str]: @@ -45,65 +55,81 @@ def from_pretrained(cls, name: str, **kwargs: object) -> "_FakeTokenizer": return cls() -@pytest.mark.unit -class TestTokenizePool: - def test_token_count_returns_int(self): - with patch(_MOCK_TARGET, _FakeTokenizer): - with TokenizePool("fake", n_workers=1) as pool: - count = pool.token_count("Hello world") - assert count == 2 +class _FakeProc: + """Stands in for a ProcessPoolExecutor shard; whitespace-counts its chunk.""" - def test_multiple_workers(self): - with patch(_MOCK_TARGET, _FakeTokenizer): - with TokenizePool("fake", n_workers=4) as pool: - results = [] - for i in range(10): - results.append(pool.token_count(f"Sentence number {i}")) - assert all(isinstance(r, int) and r > 0 for r in results) + def submit(self, _fn, chunk): + fut: Future = Future() + fut.set_result([len(t.split()) for t in chunk]) + return fut - def test_concurrent_calls_thread_safe(self): - with patch(_MOCK_TARGET, _FakeTokenizer): - with TokenizePool("fake", n_workers=2) as pool: - texts = [f"word{i} word{i+1}" for i in range(20)] + def shutdown(self, wait=False, cancel_futures=False): + pass - with ThreadPoolExecutor(max_workers=8) as executor: - futures = [executor.submit(pool.token_count, t) for t in texts] - results = [f.result() for f in futures] - assert len(results) == 20 - assert all(r == 2 for r in results) +class _BrokenProc: + """A shard whose work resolves to BrokenProcessPool (worker died).""" - def test_close_is_idempotent(self): + def submit(self, _fn, _chunk): + fut: Future = Future() + fut.set_exception(BrokenProcessPool("worker died")) + return fut + + def shutdown(self, wait=False, cancel_futures=False): + pass + + +@pytest.mark.unit +class TestBatchTokenizer: + @pytest.mark.asyncio + async def test_count_texts_async(self): with patch(_MOCK_TARGET, _FakeTokenizer): - pool = TokenizePool("fake", n_workers=1) - pool.close() - pool.close() # Should not raise + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: + counts = await tok.count_texts_async(["Hello world foo", "a"], loop) + assert counts == [3, 1] - def test_use_after_close_raises(self): + @pytest.mark.asyncio + async def test_count_texts_async_empty(self): with patch(_MOCK_TARGET, _FakeTokenizer): - pool = TokenizePool("fake", n_workers=1) - pool.close() - with pytest.raises(RuntimeError, match="closed"): - pool.token_count("hello") + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: + assert await tok.count_texts_async([], loop) == [] - def test_n_workers_zero_raises(self): - with pytest.raises(ValueError, match="n_workers"): - TokenizePool("fake", n_workers=0) + @pytest.mark.asyncio + async def test_count_texts_async_sharded(self): + """With shards present, chunks are reassembled in original order.""" + with patch(_MOCK_TARGET, _FakeTokenizer): + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: + tok._procs = [_FakeProc(), _FakeProc()] + counts = await tok.count_texts_async(["a", "b b", "c c c", "d"], loop) + assert counts == [1, 2, 3, 1] @pytest.mark.asyncio - async def test_token_count_async(self): + async def test_count_texts_async_shard_failure_propagates(self): + """A dead shard surfaces as an error, not a silent in-process fallback.""" with patch(_MOCK_TARGET, _FakeTokenizer): loop = asyncio.get_running_loop() - with TokenizePool("fake", n_workers=1) as pool: - count = await pool.token_count_async("Hello world foo", loop) - assert count == 3 + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: + tok._procs = [_BrokenProc()] + with pytest.raises(BrokenProcessPool): + await tok.count_texts_async(["a b"], loop) - def test_context_manager(self): + def test_close_is_idempotent(self): + with patch(_MOCK_TARGET, _FakeTokenizer): + tok = BatchTokenizer("fake", n_workers=0, live_workers=2) + tok.close() + tok.close() # must not raise + + @pytest.mark.asyncio + async def test_use_after_close_raises(self): with patch(_MOCK_TARGET, _FakeTokenizer): - with TokenizePool("fake", n_workers=1) as pool: - assert pool.token_count("a b c") == 3 + loop = asyncio.get_running_loop() + tok = BatchTokenizer("fake", n_workers=0, live_workers=2) + tok.close() with pytest.raises(RuntimeError, match="closed"): - pool.token_count("test") + await tok.count_texts_async(["hello"], loop) class _FakeTokenizerWithTemplate(_FakeTokenizer): @@ -131,19 +157,25 @@ def apply_chat_template( @pytest.mark.unit -class TestTokenizePoolMessageTokenization: - def test_token_count_message_subtracts_baseline(self): - """token_count_message returns full_tokens - baseline.""" +class TestBatchTokenizerMessageTokenization: + @pytest.mark.asyncio + async def test_token_count_message_subtracts_baseline(self): + """token_count_message_async returns full_tokens - baseline.""" with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): - with TokenizePool("fake", n_workers=1) as pool: - # "hello world" -> 2 content words + 2 wrapper = 4; baseline = 0 + 2 = 2; net = 2 - count = pool.token_count_message("hello world", None, None) + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: + # "hello world" -> 2 content + 2 wrapper = 4; baseline = 0, prefix = 2 + count = await tok.token_count_message_async( + "hello world", None, None, loop + ) assert count == 2 - def test_token_count_message_includes_tool_calls(self): - """token_count_message includes tool-call JSON tokens.""" + @pytest.mark.asyncio + async def test_token_count_message_includes_tool_calls(self): + """Tool-call JSON tokens are included in the count.""" with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): - with TokenizePool("fake", n_workers=1) as pool: + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: tool_calls = ( { "id": "c1", @@ -151,11 +183,14 @@ def test_token_count_message_includes_tool_calls(self): "function": {"name": "f", "arguments": "{}"}, }, ) - count_without = pool.token_count_message("hello", None, None) - count_with = pool.token_count_message("hello", None, tool_calls) - assert count_with > count_without + without = await tok.token_count_message_async("hello", None, None, loop) + with_calls = await tok.token_count_message_async( + "hello", None, tool_calls, loop + ) + assert with_calls > without - def test_token_count_message_fallback_on_exception(self): + @pytest.mark.asyncio + async def test_token_count_message_fallback_on_exception(self): """Falls back to whitespace split when apply_chat_template raises.""" class _BadTemplateTokenizer(_FakeTokenizer): @@ -163,7 +198,8 @@ def apply_chat_template(self, *args, **kwargs): raise ValueError("template does not support tool_calls") with patch(_MOCK_TARGET, _BadTemplateTokenizer): - with TokenizePool("fake", n_workers=1) as pool: + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=2) as tok: tool_calls = ( { "id": "c1", @@ -171,17 +207,494 @@ def apply_chat_template(self, *args, **kwargs): "function": {"name": "f", "arguments": "{}"}, }, ) - # Should not raise; falls back to whitespace tokenizer - count = pool.token_count_message("hello world", None, tool_calls) + # Must not raise; falls back to whitespace tokenizer. + count = await tok.token_count_message_async( + "hello world", None, tool_calls, loop + ) assert count > 0 + +class _Encoding: + def __init__(self, n: int): + self.ids = list(range(n)) + + +class _FastBackend: + """Raw-tokenizers backend stub with the fast batch entry point.""" + + def encode_batch_fast(self, texts, add_special_tokens=False): + return [_Encoding(len(t.split())) for t in texts] + + +class _SlowBackend: + """Raw-tokenizers backend stub without encode_batch_fast.""" + + def encode_batch(self, texts, add_special_tokens=False): + return [_Encoding(len(t.split())) for t in texts] + + +@pytest.mark.unit +class TestEncodeHelpers: + def test_encode_batch_lengths_prefers_fast(self): + assert _encode_batch_lengths(_FastBackend(), ["a b", "c"]) == [2, 1] + + def test_encode_batch_lengths_falls_back_to_encode_batch(self): + assert _encode_batch_lengths(_SlowBackend(), ["a b c", "d"]) == [3, 1] + + def test_worker_encode_lengths_raises_without_backend(self, monkeypatch): + monkeypatch.setattr(token_metrics_module, "_WORKER_BACKEND", None) + with pytest.raises(RuntimeError, match="backend unavailable"): + _worker_encode_lengths(["a"]) + + def test_worker_encode_lengths_uses_backend(self, monkeypatch): + monkeypatch.setattr(token_metrics_module, "_WORKER_BACKEND", _FastBackend()) + assert _worker_encode_lengths(["a b", "c d e"]) == [2, 3] + + +class _FakeTokenizerWithBackend(_FakeTokenizer): + """Fast-backend fake: lets ``_setup_shards`` proceed past the backend guard.""" + + backend_tokenizer = _FastBackend() + + +class _SpawnlessExecutor: + """Stands in for ProcessPoolExecutor: records ctor args, instant warmup.""" + + def __init__(self, max_workers, mp_context=None, initializer=None, initargs=()): + self.initargs = initargs + + def submit(self, fn, *args): + fut: Future = Future() + fut.set_result(True) + return fut + + def shutdown(self, wait=False, cancel_futures=False): + pass + + +@pytest.mark.unit +class TestSetupShardsDecisions: + """Pins the BatchTokenizer(n_workers=...) shard contract: -1 auto / N + clamped / 0 explicit in-process (auto-sized in production — the CLI's + --tokenizer-workers maps to the live thread lane, not to shards). + + An environment that cannot shard is a startup error — never a silent + in-process fallback. + """ + + def _make(self, monkeypatch, cpus, n_workers, executor=_SpawnlessExecutor): + monkeypatch.setattr(token_metrics_module, "ProcessPoolExecutor", executor) + # Patch the probe + the restore so no real affinity syscalls run. + monkeypatch.setattr( + token_metrics_module, + "expand_to_all_online_cpus", + lambda: set(range(cpus)), + ) + monkeypatch.setattr( + token_metrics_module.os, "sched_getaffinity", lambda pid: {0, 1} + ) + self.restored: list[set] = [] + monkeypatch.setattr( + token_metrics_module.os, + "sched_setaffinity", + lambda pid, mask: self.restored.append(set(mask)), + ) + with patch(_MOCK_TARGET, _FakeTokenizerWithBackend): + return BatchTokenizer("fake", n_workers=n_workers, live_workers=2) + + @pytest.mark.parametrize( + "cpus, n_workers, expected_shards", + [ + (16, -1, 2), # auto: one shard per 8-core block + (10, -1, 1), # auto: always at least one shard + (6, -1, 1), # auto: even below one full block + (48, 3, 3), # explicit count under capacity + (16, 10, 2), # explicit count clamped to capacity + (16, 1, 1), # explicit single shard honored + (16, 0, 0), # 0 = explicit in-process mode + ], + ) + def test_shard_count(self, monkeypatch, cpus, n_workers, expected_shards): + with self._make(monkeypatch, cpus, n_workers) as tok: + assert len(tok._procs) == expected_shards + + def test_blocks_are_disjoint_consecutive_core_sets(self, monkeypatch): + with self._make(monkeypatch, 16, -1) as tok: + blocks = [set(ex.initargs[1]) for ex in tok._procs] + assert blocks == [set(range(0, 8)), set(range(8, 16))] + + def test_probe_restores_the_inherited_mask(self, monkeypatch): + """The aggregator keeps the mask its parent gave it; only the probe + widens, and only the shard children pin elsewhere.""" + with self._make(monkeypatch, 16, -1): + pass + assert self.restored == [{0, 1}] + + def test_no_fast_backend_is_a_startup_error(self, monkeypatch): + monkeypatch.setattr( + token_metrics_module, "ProcessPoolExecutor", _SpawnlessExecutor + ) + with patch(_MOCK_TARGET, _FakeTokenizer): # no backend_tokenizer + with pytest.raises(RuntimeError, match="fast"): + BatchTokenizer("fake", live_workers=2) + + def test_affinity_unavailable_shards_unpinned(self, monkeypatch): + """No affinity API (e.g. macOS): shard from the CPU count, unpinned.""" + monkeypatch.setattr( + token_metrics_module, "ProcessPoolExecutor", _SpawnlessExecutor + ) + + def _unsupported(): + raise RuntimeError("affinity requires Linux") + + monkeypatch.setattr( + token_metrics_module, "expand_to_all_online_cpus", _unsupported + ) + + def _raise(pid): + raise AttributeError("no sched_getaffinity") + + monkeypatch.setattr(token_metrics_module.os, "sched_getaffinity", _raise) + monkeypatch.setattr(token_metrics_module.os, "cpu_count", lambda: 16) + with patch(_MOCK_TARGET, _FakeTokenizerWithBackend): + with BatchTokenizer("fake", live_workers=2) as tok: + assert len(tok._procs) == 2 + + def test_warmup_failure_is_a_startup_error(self, monkeypatch): + class _BrokenWarmup(_SpawnlessExecutor): + def submit(self, fn, *args): + fut: Future = Future() + fut.set_exception(RuntimeError("spawn died")) + return fut + + with pytest.raises(RuntimeError, match="warmup"): + self._make(monkeypatch, 16, -1, executor=_BrokenWarmup) + + +class _RecordingProc(_FakeProc): + """_FakeProc that records the chunks submitted to it.""" + + def __init__(self): + self.chunks = [] + + def submit(self, _fn, chunk): + self.chunks.append(list(chunk)) + return super().submit(_fn, chunk) + + +@pytest.mark.unit +class TestLiveLane: @pytest.mark.asyncio - async def test_token_count_message_async(self): - """token_count_message_async returns count without blocking event loop.""" - with patch(_MOCK_TARGET, _FakeTokenizerWithTemplate): + async def test_live_never_touches_the_shard_pool(self): + """Mid-run flushes run in-process; the shards are drain-only.""" + with patch(_MOCK_TARGET, _FakeTokenizer): loop = asyncio.get_running_loop() - with TokenizePool("fake", n_workers=1) as pool: - count = await pool.token_count_message_async( - "hello world", None, None, loop - ) - assert count == 2 + with BatchTokenizer("fake", n_workers=0, live_workers=1) as tok: + procs = [_RecordingProc(), _RecordingProc(), _RecordingProc()] + tok._procs = procs + counts = await tok.count_texts_async(["a b", "c"], loop, live=True) + assert counts == [2, 1] + assert all(p.chunks == [] for p in procs) + + @pytest.mark.asyncio + async def test_drain_uses_every_shard(self): + with patch(_MOCK_TARGET, _FakeTokenizer): + loop = asyncio.get_running_loop() + with BatchTokenizer("fake", n_workers=0, live_workers=1) as tok: + procs = [_RecordingProc(), _RecordingProc()] + tok._procs = procs + await tok.count_texts_async(["a", "b", "c", "d"], loop) + assert all(p.chunks for p in procs) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestQueueLiveLoop: + async def test_start_live_flushes_periodically(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b c", recorded.append) + queue.start_live(0.01) + queue.start_live(0.01) # idempotent + await asyncio.sleep(0.05) + assert recorded == [3] + assert queue.pending == 0 + await queue.flush_remaining(timeout=1.0) + + async def test_live_loop_survives_tokenizer_failure(self): + class _FailingLive(_CapturingTokenizer): + async def count_texts_async(self, texts, _loop, live=False): + if live: + raise RuntimeError("live lane boom") + return await super().count_texts_async(texts, _loop) + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_FailingLive(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b", recorded.append) + queue.start_live(0.01) + await asyncio.sleep(0.05) + assert recorded == [] + assert queue.pending == 1, "failed live flush must keep items pending" + assert queue._live_task is not None and not queue._live_task.done() + # The end-of-run drain (full pool) still recovers the items. + assert await queue.flush_remaining(timeout=1.0) == 0 + assert recorded == [2] + + async def test_flush_remaining_stops_live_loop(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + queue.start_live(0.01) + task = queue._live_task + await queue.flush_remaining(timeout=1.0) + assert queue._live_task is None + assert task is not None and task.cancelled() + + +@pytest.mark.unit +class TestRayonCaps: + def test_ctor_caps_rayon_to_live_workers(self, monkeypatch): + monkeypatch.delenv("RAYON_NUM_THREADS", raising=False) + with patch(_MOCK_TARGET, _FakeTokenizer): + with BatchTokenizer("fake", n_workers=0, live_workers=3): + assert token_metrics_module.os.environ["RAYON_NUM_THREADS"] == "3" + + def test_ctor_respects_operator_exported_cap(self, monkeypatch): + monkeypatch.setenv("RAYON_NUM_THREADS", "7") + with patch(_MOCK_TARGET, _FakeTokenizer): + with BatchTokenizer("fake", n_workers=0, live_workers=3): + assert token_metrics_module.os.environ["RAYON_NUM_THREADS"] == "7" + + def test_init_worker_overrides_inherited_cap_with_block_size(self, monkeypatch): + """Spawn children inherit the parent's live cap; each shard must + re-size its rayon pool to its own core block.""" + monkeypatch.setenv("RAYON_NUM_THREADS", "2") + + def _no_affinity(pid, mask): + raise AttributeError("no sched_setaffinity") + + monkeypatch.setattr(token_metrics_module.os, "sched_setaffinity", _no_affinity) + with patch(_MOCK_TARGET, _FakeTokenizer): + token_metrics_module._init_worker("fake", [0, 1, 2, 3, 4, 5, 6, 7]) + assert token_metrics_module.os.environ["RAYON_NUM_THREADS"] == "8" + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestLiveFlushBounds: + async def test_live_flush_takes_at_most_the_cap(self, monkeypatch): + monkeypatch.setattr(token_metrics_module, "_LIVE_FLUSH_MAX_ITEMS", 3) + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + for i in range(5): + queue.enqueue_text(f"t{i}", recorded.append) + await queue.flush(live=True) + assert len(recorded) == 3 + assert queue.pending == 2 + # The drain takes everything that remains. + assert await queue.flush_remaining(timeout=1.0) == 0 + assert len(recorded) == 5 + + async def test_live_cancellation_requeues_texts(self): + class _Hanging(_CapturingTokenizer): + async def count_texts_async(self, texts, _loop, live=False): + if live: + await asyncio.sleep(30) + return await super().count_texts_async(texts, _loop) + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_Hanging(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b", recorded.append) + task = loop.create_task(queue.flush(live=True)) + await asyncio.sleep(0.01) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(task, timeout=1.0) + assert queue.pending == 1 + assert len(queue._text) == 1, "cancelled live flush must give items back" + assert await queue.flush_remaining(timeout=1.0) == 0 + assert recorded == [2] + + async def test_live_cancellation_requeues_messages_too(self): + """A cancel landing in the text encode must give back BOTH kinds.""" + + class _Hanging(_CapturingTokenizer): + async def count_texts_async(self, texts, _loop, live=False): + if live: + await asyncio.sleep(30) + return await super().count_texts_async(texts, _loop) + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_Hanging(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b", recorded.append) + queue.enqueue_message(("hello world", None, None), recorded.append) + task = loop.create_task(queue.flush(live=True)) + await asyncio.sleep(0.01) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await asyncio.wait_for(task, timeout=1.0) + assert queue.pending == 2 + assert len(queue._text) == 1 + assert len(queue._msg) == 1, "detached messages must be re-queued" + assert await queue.flush_remaining(timeout=1.0) == 0 + assert sorted(recorded) == [2, 2] + + async def test_live_message_failure_requeues_message(self): + class _MsgFailing(_CapturingTokenizer): + async def token_count_message_async(self, *args): + raise RuntimeError("template boom") + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_MsgFailing(), loop) + recorded: list[int] = [] + queue.enqueue_message(("hello world", None, None), recorded.append) + with pytest.raises(RuntimeError, match="template boom"): + await queue.flush(live=True) + assert queue.pending == 1 + assert len(queue._msg) == 1, "failed live message must be re-queued" + + +@pytest.mark.unit +class TestEvenChunks: + def test_splits_into_near_equal_chunks(self): + assert _even_chunks(["a", "b", "c", "d", "e"], 2) == [ + ["a", "b", "c"], + ["d", "e"], + ] + + def test_single_chunk_when_n_le_one(self): + assert _even_chunks(["a", "b"], 1) == [["a", "b"]] + + def test_single_item_input(self): + assert _even_chunks(["only"], 4) == [["only"]] + + def test_preserves_order_and_bounds_chunk_count(self): + items = [str(i) for i in range(10)] + chunks = _even_chunks(items, 3) + assert [x for c in chunks for x in c] == items + assert len(chunks) <= 3 + + +class _CapturingTokenizer: + """Minimal tokenizer stub for queue tests: whitespace counts, no procs.""" + + async def count_texts_async(self, texts, _loop, live=False): + return [len(t.split()) for t in texts] + + async def token_count_message_async(self, content, reasoning, tool_calls, _loop): + parts = [p for p in (content, reasoning) if p] + return len(" ".join(parts).split()) + (len(tool_calls) if tool_calls else 0) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTokenBatchQueue: + async def test_flush_records_text_via_callback(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b c", recorded.append) + queue.enqueue_text("d e", recorded.append) + assert queue.pending == 2 + await queue.flush() + assert sorted(recorded) == [2, 3] + assert queue.pending == 0 + + async def test_flush_records_message_via_callback(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_message(("hello world", None, None), recorded.append) + await queue.flush() + assert recorded == [2] + + async def test_flush_empty_is_noop(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + await queue.flush() + assert queue.pending == 0 + + async def test_flush_remaining_clean_returns_zero(self): + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("a b", recorded.append) + assert await queue.flush_remaining(timeout=5.0) == 0 + assert recorded == [2] + + async def test_flush_remaining_timeout_reports_pending(self): + """A tokenizer slower than the budget leaves items pending.""" + + class _BlockingTokenizer: + async def count_texts_async(self, texts, _loop, live=False): + await asyncio.sleep(10.0) + return [0] * len(texts) + + async def token_count_message_async(self, *args): + return 0 + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_BlockingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("never counted", recorded.append) + n_pending = await queue.flush_remaining(timeout=0.05) + assert n_pending == 1 + assert recorded == [] + + async def test_flush_remaining_failure_reports_pending(self): + """A tokenizer error leaves items pending and never raises.""" + + class _FailingTokenizer: + async def count_texts_async(self, texts, _loop, live=False): + raise RuntimeError("tokenizer boom") + + async def token_count_message_async(self, *args): + raise RuntimeError("tokenizer boom") + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_FailingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("x y", recorded.append) + assert await queue.flush_remaining(timeout=5.0) == 1 + assert recorded == [] + + async def test_flush_text_failure_does_not_drop_message_items(self): + """The message phase runs (and records) even when the text batch fails.""" + + class _TextFailingTokenizer: + async def count_texts_async(self, texts, _loop, live=False): + raise RuntimeError("text shard died") + + async def token_count_message_async( + self, content, reasoning, tool_calls, _loop + ): + return len(content.split()) + + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_TextFailingTokenizer(), loop) + recorded: list[int] = [] + queue.enqueue_text("never counted", recorded.append) + queue.enqueue_message(("hello world", None, None), recorded.append) + with pytest.raises(RuntimeError, match="text shard died"): + await queue.flush() + assert recorded == [2], "message item must survive the text failure" + assert queue.pending == 1, "only the text item remains pending" + + async def test_flush_recorder_failure_does_not_poison_batch(self): + """One raising on_count is logged; the rest of the batch still records.""" + loop = asyncio.get_running_loop() + queue = TokenBatchQueue(_CapturingTokenizer(), loop) + recorded: list[int] = [] + + def bad_recorder(count: int) -> None: + raise ValueError("recorder bug") + + queue.enqueue_text("a b", bad_recorder) + queue.enqueue_text("c d e", recorded.append) + await queue.flush() + assert recorded == [3] + assert queue.pending == 0, "a raising recorder still counts as recorded" diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 1c90554fb..e47def8f0 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -489,8 +489,7 @@ def test_defaults(self): assert cfg.warmup_timeout_s == 240.0 assert cfg.performance_timeout_s == 240.0 assert cfg.accuracy_timeout_s is None - assert cfg.metrics_drain_timeout_s == 60.0 - assert cfg.metrics_tokenizer_workers == 2 + assert cfg.metrics_drain_timeout_s == 300.0 @pytest.mark.unit @pytest.mark.parametrize( @@ -512,11 +511,6 @@ def test_metrics_drain_timeout_negative_rejected(self): with pytest.raises(ValidationError): DrainConfig(metrics_drain_timeout_s=-1.0) - @pytest.mark.unit - def test_metrics_tokenizer_workers_must_be_at_least_one(self): - with pytest.raises(ValidationError): - DrainConfig(metrics_tokenizer_workers=0) - @pytest.mark.unit def test_extra_fields_rejected(self): with pytest.raises(ValidationError): @@ -538,7 +532,6 @@ def test_yaml_roundtrip(self, tmp_path): performance_timeout_s: 30.0 accuracy_timeout_s: null metrics_drain_timeout_s: 300.0 - metrics_tokenizer_workers: 8 """ config_file = tmp_path / "drain.yaml" config_file.write_text(yaml_content) @@ -548,7 +541,6 @@ def test_yaml_roundtrip(self, tmp_path): assert drain.performance_timeout_s == 30.0 assert drain.accuracy_timeout_s is None assert drain.metrics_drain_timeout_s == 300.0 - assert drain.metrics_tokenizer_workers == 8 class TestAggregatorArgs: @@ -641,17 +633,13 @@ async def _capture_launch(service_configs, *, timeout): @pytest.mark.unit @pytest.mark.asyncio - @pytest.mark.parametrize("workers, expected_flag", [(4, "4"), (8, "8"), (2, "2")]) - async def test_tokenizer_workers_forwarded_to_aggregator_args( - self, tmp_path, workers, expected_flag - ): - config = OfflineConfig( - **_OFFLINE_KWARGS, - settings=OfflineSettings( - drain=DrainConfig(metrics_tokenizer_workers=workers) - ), - ) + async def test_tokenizer_and_workers_forwarded_from_schema(self, tmp_path): + """The benchmark forwards --tokenizer and --tokenizer-workers; the + workers value comes from the schema default + (drain.metrics_tokenizer_workers), the single source of truth.""" + config = OfflineConfig(**_OFFLINE_KWARGS, settings=OfflineSettings()) ctx = self._make_ctx(config, tmp_path) + ctx.tokenizer_name = "gpt2" captured: list = [] @@ -689,9 +677,11 @@ async def _capture_launch(service_configs, *, timeout): aggregator_cfg = next(c for c in captured if "metrics_aggregator" in c.module) args = aggregator_cfg.args - assert "--tokenizer-workers" in args + idx = args.index("--tokenizer") + assert args[idx + 1] == "gpt2" idx = args.index("--tokenizer-workers") - assert args[idx + 1] == expected_flag + expected = str(config.settings.drain.metrics_tokenizer_workers) + assert args[idx + 1] == expected class TestBuildPhases: diff --git a/tests/unit/endpoint_client/test_cpu_affinity.py b/tests/unit/endpoint_client/test_cpu_affinity.py index 52ef724e2..7d100be9d 100644 --- a/tests/unit/endpoint_client/test_cpu_affinity.py +++ b/tests/unit/endpoint_client/test_cpu_affinity.py @@ -6,6 +6,7 @@ from inference_endpoint.endpoint_client.cpu_affinity import ( AffinityPlan, compute_affinity_plan, + expand_to_all_online_cpus, get_all_online_cpus, pin_loadgen, set_cpu_affinity, @@ -146,3 +147,30 @@ def test_all_methods_fail_returns_empty( """Test that empty set is returned when all methods fail.""" cpus = get_all_online_cpus() assert cpus == set() + + +class TestExpandToAllOnlineCpus: + @patch("os.sched_getaffinity") + @patch("os.sched_setaffinity") + @patch("pathlib.Path.read_text") + def test_expands_inherited_mask_to_online(self, mock_read, mock_set, mock_get): + """The full sysfs online set is requested; the effective mask is returned.""" + mock_read.return_value = "0-7\n" + mock_get.return_value = {0, 1, 2, 3, 4, 5, 6, 7} + + cpus = expand_to_all_online_cpus() + + mock_set.assert_called_once_with(0, {0, 1, 2, 3, 4, 5, 6, 7}) + assert cpus == {0, 1, 2, 3, 4, 5, 6, 7} + + @patch("os.sched_getaffinity") + @patch("os.sched_setaffinity", side_effect=OSError("cpuset denies")) + @patch("pathlib.Path.read_text") + def test_setaffinity_failure_returns_current_mask( + self, mock_read, mock_set, mock_get + ): + """A denied expansion is non-fatal: the current mask is reported.""" + mock_read.return_value = "0-7\n" + mock_get.return_value = {0, 1} + + assert expand_to_all_online_cpus() == {0, 1} diff --git a/uv.lock b/uv.lock index bfdb3b236..984581b6b 100644 --- a/uv.lock +++ b/uv.lock @@ -29,7 +29,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.14.0" +version = "3.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -41,42 +41,42 @@ dependencies = [ { name = "typing-extensions", marker = "(python_full_version < '3.13' and platform_machine == 'arm64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "yarl", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'x86_64' and sys_platform == 'darwin') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ee/ab/93ce242f899b68c51b0578c027aafa791ab3614cb9345fa5d37b5f5c8e3e/aiohttp-3.14.0.tar.gz", hash = "sha256:2882de819734c715fd1b9c11c97e09fa020d14438203d1d354d8ed1702791c9b", size = 7940674, upload-time = "2026-06-01T19:41:02.763Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/89/97/2b6889bfb6b6847520d50d95eb8c4307a45e28aaca39faf4a9454b3d1b2f/aiohttp-3.14.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b29518c9c2ec7e373e68259206a137c7f4f5439c58baaec4b5ab3ab799850a4e", size = 750194, upload-time = "2026-06-01T19:37:48.164Z" }, - { url = "https://files.pythonhosted.org/packages/21/e2/62634b7fff918ed98c3c6b2f0e70d520f7f28846cb412d451b04354c6459/aiohttp-3.14.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:dbec68ce61b64cb73cab4d33df9433427b1713c8bcccb181dce695c1b6f8e87c", size = 506966, upload-time = "2026-06-01T19:37:50.014Z" }, - { url = "https://files.pythonhosted.org/packages/dd/fb/5ce075150828c797a5106f1c2fb26034e709d4289b9d2bf8b07f1e59fac6/aiohttp-3.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3cdf534aa455593e589302990c5097aa5c92c06c4262a20da22934f9186a5fff", size = 507527, upload-time = "2026-06-01T19:37:51.96Z" }, - { url = "https://files.pythonhosted.org/packages/01/d5/405a0ae4e6b081754a3609c1c97c63a950e000a2def16046f1e736933a0e/aiohttp-3.14.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cb6c657104393b5fbff01a5f59b2023db74058a8077d94475d6c25d03882a108", size = 1762420, upload-time = "2026-06-01T19:37:53.839Z" }, - { url = "https://files.pythonhosted.org/packages/19/d8/51de5c6b971c27bb1ef620293b8d1ca611ec78736b34b3f6ccf68e4c8785/aiohttp-3.14.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:78d6f9286a629ce52728430afe18f8ed2b6c39a1fddb3802d7244b9983910ad2", size = 1783112, upload-time = "2026-06-01T19:38:02.641Z" }, - { url = "https://files.pythonhosted.org/packages/bc/05/750a3265ca4dc54a460bd0cb1121a8f2ce9171fce4a135fb47ea7fd594d2/aiohttp-3.14.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4d6a998191f5ebe3b8c28463ff72bc030250008b3193c402464efadd08b5ca02", size = 1723119, upload-time = "2026-06-01T19:38:06.713Z" }, - { url = "https://files.pythonhosted.org/packages/a8/fb/05d9214c975f23225a8cd5c439325e338c7c377b315480ef3871db51f54e/aiohttp-3.14.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ba10966d4f03dd96a14365be4b8e37c327c76f11c3ca867116966cdd9f98066", size = 1760193, upload-time = "2026-06-01T19:38:17.624Z" }, - { url = "https://files.pythonhosted.org/packages/11/41/cc2d2cfbfbdc3126ba258f3cd27d1ac8a33492ae3c35a4583ee21f0ba7f1/aiohttp-3.14.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:3366751d68d237c621264233a32f3078bbc21b7904ab90a77e03d21390c742c6", size = 481670, upload-time = "2026-06-01T19:38:29.836Z" }, - { url = "https://files.pythonhosted.org/packages/3c/07/381f4023c3b08cb616e520f566d8c58957abad54e56441d41fe67cfb0195/aiohttp-3.14.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:57ea07d28695a7a40304d42251892a8df765e5588c10ee32afeddcd5df33c0a2", size = 487591, upload-time = "2026-06-01T19:38:31.704Z" }, - { url = "https://files.pythonhosted.org/packages/fb/4d/4506fdb7a022bdf70011a3bbb4ca00c5c570026ef6a3c5bd7bc70c39089c/aiohttp-3.14.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:076cb014191ae2e65d949e1ad01f1dcfe33e32789b5172510f3e79c79fc04d50", size = 496503, upload-time = "2026-06-01T19:38:33.6Z" }, - { url = "https://files.pythonhosted.org/packages/ef/7d/c814111e04894a45d9e2defc94443879a6f118d9633d5fedfe6e2e8af5f0/aiohttp-3.14.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2f3fc37054564dee64a855b5b092d87ec35dcddfaabf7dacb1c8a2b1f83dc0a9", size = 745870, upload-time = "2026-06-01T19:38:36.013Z" }, - { url = "https://files.pythonhosted.org/packages/c6/ee/80eee0efddfe187e7cd05027086b7ce1c0e492e82a4eda58f5c5543a44a0/aiohttp-3.14.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8fcaef74d2ab0f607d7ff85a0d15e21bb5a258c4a58df1908396eb50d7f4ed3c", size = 505588, upload-time = "2026-06-01T19:38:38.282Z" }, - { url = "https://files.pythonhosted.org/packages/d6/f8/0f28f04eef75d52fc9c715dde7ce9c0abb810fd20cfeb0fea7afd2ab1e98/aiohttp-3.14.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e4c01b0bfc6209590960e68eac083cd22d5d87c21f974dd6208cafa5d3542bc8", size = 504492, upload-time = "2026-06-01T19:38:40.611Z" }, - { url = "https://files.pythonhosted.org/packages/ff/db/44c755232085545065c94378dfce38641b1aee647f4939fcd32f5b32e719/aiohttp-3.14.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f12eb7896e81caf403a2b18c9406426f1207361e7239c057ab29c076d4257e83", size = 1752111, upload-time = "2026-06-01T19:38:42.682Z" }, - { url = "https://files.pythonhosted.org/packages/c5/a3/3800dbd095cb2bb165a7ea5d94d790914677e27f45638c7d80e3f34c8945/aiohttp-3.14.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:26d9224c6dd7f5c749aba4f61315a894601448b28d94d12f4dea0903e26d2096", size = 1777241, upload-time = "2026-06-01T19:38:52.04Z" }, - { url = "https://files.pythonhosted.org/packages/b4/3d/dc94df99ed1511fdf28314f722643ed334112643cab00223577085e788c4/aiohttp-3.14.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:23e8314e7aed8576fbe33314d218bd81447a3adbc91dc36f1163bf583cd3084c", size = 1714864, upload-time = "2026-06-01T19:38:56.788Z" }, - { url = "https://files.pythonhosted.org/packages/fa/10/ab28818262f4d26bdb47ed5f1fc7999b69e2fc6e0370b02d0f49011f45ea/aiohttp-3.14.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:666c7c5036df57b693026398b69b41874a1931ac5b3485fd910e57bfac253869", size = 1754516, upload-time = "2026-06-01T19:39:08.788Z" }, - { url = "https://files.pythonhosted.org/packages/1a/fe/6edbf5d39bf29322b6816365b17ed8ede4dace164a3aea1abcd30110eb78/aiohttp-3.14.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:70ea956f6cc4a37620966b56c2e205d88ca3e6d85ec063277e414b1035cddad3", size = 483329, upload-time = "2026-06-01T19:39:22.607Z" }, - { url = "https://files.pythonhosted.org/packages/1b/5a/fae531bdbc6456fb6241f46b7b81e4d8a0dd3fc09118a0055dc7141ac1ec/aiohttp-3.14.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:ea3b9806c89f61da22fddf1f12dd524fb368e5e28f1261fbdafe5c3cd8ce893b", size = 489502, upload-time = "2026-06-01T19:39:24.881Z" }, - { url = "https://files.pythonhosted.org/packages/36/f4/48a7b0414db7fed77a03d5dde34508c026afd83510ab6bca08c313855776/aiohttp-3.14.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:a071be341c2bd9b0188e62d173509f024e0a35b1c342c53c50f8daaeda8c3bd8", size = 497357, upload-time = "2026-06-01T19:39:27.197Z" }, - { url = "https://files.pythonhosted.org/packages/75/75/e85a13a370acc007fca5feb1fd1b88ac2d8426e6dadd625479b7cadd55a3/aiohttp-3.14.0-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:198cfe61bf253b19da1fb3e0fa122249dc4f14c12709493fed8054aa0411cc76", size = 750898, upload-time = "2026-06-01T19:39:29.563Z" }, - { url = "https://files.pythonhosted.org/packages/9e/e4/3d637f800c724eff0e2bed64df72557444482366fd0a35b0cec0e6968f6c/aiohttp-3.14.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9dc203d6ce6b9106d54e2a93f41dfdfebfbca2d99962ba503bfd3e5921a6549e", size = 506986, upload-time = "2026-06-01T19:39:31.872Z" }, - { url = "https://files.pythonhosted.org/packages/1d/df/35161f3598bf7501d2b2a805b41ab4f45a2e34150c421bcb4ef8c0d281a7/aiohttp-3.14.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:9e19d17ab02bf16832a2c8c0d55a486792c5b1645665652ee9531aebcc30cb72", size = 508033, upload-time = "2026-06-01T19:39:34.137Z" }, - { url = "https://files.pythonhosted.org/packages/e5/39/b36e5d3d31e850fb4691dd3e941684ac490a2559249f6fa634b6b0fdf020/aiohttp-3.14.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d925fba0c14d5b498a8028b0107beebdfd16c5d48d702ff54f879cb017aaaca3", size = 1746213, upload-time = "2026-06-01T19:39:36.654Z" }, - { url = "https://files.pythonhosted.org/packages/3a/05/27df32c844b2156e1675a8d8ec22d963e3c8ba469ed7ceb1863320c7b521/aiohttp-3.14.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ff82be7f1ef73634cb77890a770743239bc3d487b848669be1c599889336dc0a", size = 1751659, upload-time = "2026-06-01T19:39:46.398Z" }, - { url = "https://files.pythonhosted.org/packages/66/e3/53c67097e8a5ce98625e91e3fa7f43c9c6940de680345d03b3509a72a078/aiohttp-3.14.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:edc01ea4e1ec5a1649a28866262bf24195889ff7b27bdd947029a6086741de9b", size = 1710090, upload-time = "2026-06-01T19:39:51.392Z" }, - { url = "https://files.pythonhosted.org/packages/b8/69/155c4ef3aec96417d47024800472b33b16c5d8a665371dcd044c2afdf25d/aiohttp-3.14.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:26b6d79aa54cb4ed50cc7d41ed14e99e0f1fc8e7c2d42f2e05b37aea897b2b52", size = 1733716, upload-time = "2026-06-01T19:40:03.631Z" }, - { url = "https://files.pythonhosted.org/packages/12/34/6180103ce9aabc8ebff3f7bb55a1228ffe60f61042823031d9692cb7b101/aiohttp-3.14.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:6aa1a40f9cbb3da9f80714c5966b8946c21e6a2530d809b9498b33161e3c8733", size = 787878, upload-time = "2026-06-01T19:40:13.401Z" }, - { url = "https://files.pythonhosted.org/packages/92/e9/08954a40e8b7baa3d8beadd2b074b186e9b1e9c8ddabc288678a6265de50/aiohttp-3.14.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b62af5a8cc96a194eaa01a9ed7b34a3ffa58d3d8daaa1a0d7a749353ad12d228", size = 524400, upload-time = "2026-06-01T19:40:15.972Z" }, - { url = "https://files.pythonhosted.org/packages/08/6a/b5965a634ac4d5ba99a463314cf4ab214ca073fcdc38a15e0294273701fc/aiohttp-3.14.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6eb63b1417efaf7d1002a6ad034a40d44376afcc16508a57f8e74b49ad26a095", size = 527904, upload-time = "2026-06-01T19:40:18.28Z" }, - { url = "https://files.pythonhosted.org/packages/06/b4/932bcdd850c354d9bcca30f360e475d7852e30413fbbd44b182782ed5432/aiohttp-3.14.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c20b9ad156a79eb97be5cf9e069eec01d2f0dc8472ffbd75299a8b2d4c2cbbde", size = 1912162, upload-time = "2026-06-01T19:40:20.825Z" }, - { url = "https://files.pythonhosted.org/packages/d0/1c/a57de71a4508c93a830b77c28af3d08cd97f606dedfc6b94275347744508/aiohttp-3.14.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:145262119b07d7f95abc1839add35ba2bfc84551d4b4660ca11542c0b215455b", size = 1868606, upload-time = "2026-06-01T19:40:31.843Z" }, - { url = "https://files.pythonhosted.org/packages/35/1e/c237923232c7da7f0392ea25d89fc5e60c0e93f685f4ebca8e7bcdd5271c/aiohttp-3.14.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:2cc736a9c9fc2bc4dd71fd404815741b6573df27c3f985948ec4076989ac57de", size = 1834090, upload-time = "2026-06-01T19:40:37.733Z" }, - { url = "https://files.pythonhosted.org/packages/cc/bc/2aaab2f85cadb26ea59c091fa2b8e370d625154b5c14b478f1b489d07551/aiohttp-3.14.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6199707cc40e0e9cd39c36fbc97bec416c704e1d0ddce03412bb3b3e6a90ccd0", size = 1832281, upload-time = "2026-06-01T19:40:52.303Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/82/78/8ea7308cac6934de8c74a14f3d5f65d1c89287426688be79538d0e5c013d/aiohttp-3.14.1.tar.gz", hash = "sha256:307f2cff90a764d329e77040603fa032db89c5c24fdad50c4c15334cba744035", size = 7955794, upload-time = "2026-06-07T21:09:35.529Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/21/151624b51cd92553d95424daf4bf19f19ce9be9002d19253e7e7ce67197b/aiohttp-3.14.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:d35143e27778b4bb0fb189562d7f275bff79c62ab8e98459717c0ea617ff2480", size = 757402, upload-time = "2026-06-07T21:06:40.311Z" }, + { url = "https://files.pythonhosted.org/packages/c2/82/280619e0bd7bf2454987e19282616e84762255dd9c8468f62382e8c191f1/aiohttp-3.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:bcfb80a2cc36fba2534e5e5b5264dc7ae6fcd9bf15256da3e53d2f499e6fa29d", size = 512310, upload-time = "2026-06-07T21:06:42.207Z" }, + { url = "https://files.pythonhosted.org/packages/55/b2/2aac325583aaa1353045f96dffa586d8a34e8322e14a7ba49cffeb103ab4/aiohttp-3.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27fd7c91e51729b4f7e1577865fa6d34c9adccbc39aabe9000285b48af9f0ec2", size = 512448, upload-time = "2026-06-07T21:06:43.813Z" }, + { url = "https://files.pythonhosted.org/packages/8a/72/a60607cb849faa8af8a356c9329ea2eb6f395d49e82cc82ccba1fd8deb8f/aiohttp-3.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:64c567bf9eaf664280116a8688f63016e6b32db2505908e2bdaca1b6438142f2", size = 1766854, upload-time = "2026-06-07T21:06:45.391Z" }, + { url = "https://files.pythonhosted.org/packages/20/9c/d445818389df371f56d141d881153ba23183c4735a03f7356ffb43f7757d/aiohttp-3.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3e6fc1a85fa7194a1a7d19f44e8609180f4a8eb5fa4c7ed8b4355f080fad235c", size = 1790278, upload-time = "2026-06-07T21:06:54.049Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b4/4dac0038960427ba832f6609dfb4ea5437d7fd80c72001b9e48f834f428b/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c6fa4dc7ad6f8109c70bb1499e589f76b0b792baf39f9b017eb92c8a81d0a199", size = 1728397, upload-time = "2026-06-07T21:06:57.777Z" }, + { url = "https://files.pythonhosted.org/packages/70/0a/e0075ce9ca0279ee1d4f0c0b85f54fea02ebc83c3007651a72bece658fec/aiohttp-3.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6f71173be42d3241d428f760122febb748de0623f44308a6f120d0dd9ec572e3", size = 1767580, upload-time = "2026-06-07T21:07:07.873Z" }, + { url = "https://files.pythonhosted.org/packages/fe/22/a73ccbf9dbd6e26dda0b24d5fd5db7da92ee3383a79f47677ffb834c5c5b/aiohttp-3.14.1-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:915fbb7b41b115192259f8c9ae58f3ddc444d2b5579917270211858e606a4afd", size = 485841, upload-time = "2026-06-07T21:07:19.555Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b9/57ed8eaf596321c2ad747bd480fb1700dbd7177c60dfc9e4c187f629662e/aiohttp-3.14.1-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:7fb4bdf95b0561a79f259f9d28fbc109728c5ee7f27aff6391f0ca703a329abe", size = 492088, upload-time = "2026-06-07T21:07:21.581Z" }, + { url = "https://files.pythonhosted.org/packages/78/c0/5ebe5270a7c140d7c6f79dcb018640225f14d406c149e4eec04a7d82fe71/aiohttp-3.14.1-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:1b9748363260121d2927704f5d4fc498150669ca3ae93625986ee89c8f80dcd4", size = 501564, upload-time = "2026-06-07T21:07:23.388Z" }, + { url = "https://files.pythonhosted.org/packages/75/7f/8cdaa24fc7983865e0915153b96a9ac5bcdd3548d64c5a27d17cecccad2d/aiohttp-3.14.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:86a6dab78b0e43e2897a3bbe15745aa60dc5423ca437b7b0b164c069bf91b876", size = 751998, upload-time = "2026-06-07T21:07:25.046Z" }, + { url = "https://files.pythonhosted.org/packages/b2/f4/c4227aacfacc5cb0cc2d119b65301d177912a6842cd64e120c47af76064f/aiohttp-3.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4dfd6e47d3c44c2279907607f73a4240b88c69eb8b90da7e2441a8045dfd21da", size = 510918, upload-time = "2026-06-07T21:07:27.28Z" }, + { url = "https://files.pythonhosted.org/packages/ab/01/a2d5f96cd4e74424864d30bc0a7e44d0a12dacdcfa91b5b2d1bd3dca6bf3/aiohttp-3.14.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:317acd9f8602858dc7d59679812c376c7f0b97bcbbf16e0d6237f54141d8a8a6", size = 508657, upload-time = "2026-06-07T21:07:29.252Z" }, + { url = "https://files.pythonhosted.org/packages/e8/ed/3c0fb5c500fdd8e7ebc10d1889c04384fffa1a9163eac1356088ca9da1b1/aiohttp-3.14.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bd869c427324e5cb15195793de951295710db28be7d818247f3097b4ab5d4b96", size = 1757907, upload-time = "2026-06-07T21:07:31.03Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6e/dbf1d0625dc711fb2851f4f3c3055c39ed58bae92082d8c627dbe6013736/aiohttp-3.14.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:faccab372e66bc76d5731525e7f1143c922271725b9d38c9f97edcc66266b451", size = 1783881, upload-time = "2026-06-07T21:07:39.063Z" }, + { url = "https://files.pythonhosted.org/packages/2a/bd/cf9cee17e140f942a3de73e658a543aa8fbf35a5fc67a9d2538d52d77f0b/aiohttp-3.14.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:97e704dcd26271f5bda3fa07c3ce0fb76d6d3f8659f4baa1a24442cc9ba177ca", size = 1722137, upload-time = "2026-06-07T21:07:43.014Z" }, + { url = "https://files.pythonhosted.org/packages/ba/45/4de841f005cfe1fd63e2a2fe011262c515e2a62aa6994b15947e7d717ac9/aiohttp-3.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cb21957bb8aca671c1765e32f58164cf0c50e6bf41c0bbbd16da20732ecaf588", size = 1761094, upload-time = "2026-06-07T21:07:54.113Z" }, + { url = "https://files.pythonhosted.org/packages/85/a5/9594ad6289eebbc97d167c44213d557807f90e59115caad24de21ad2c3b1/aiohttp-3.14.1-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:62a759436b29e677181a9e76bab8b8f689a29cb9c535f45f7c48c9c830d3f8c3", size = 487918, upload-time = "2026-06-07T21:08:06.377Z" }, + { url = "https://files.pythonhosted.org/packages/b4/61/16a32c36c3c49edec122a3dc811f2057df2f94d3b14aa107c8017d981618/aiohttp-3.14.1-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:2964cbf553df4d7a57348da44d961d871895fc1ee4e8c322b2a95612c7b17fba", size = 494014, upload-time = "2026-06-07T21:08:08.263Z" }, + { url = "https://files.pythonhosted.org/packages/9b/89/3ebcf96ed99c05bec9c434aaac6963fd3cbab4a786ae739908a144d9ce44/aiohttp-3.14.1-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:237651caadc3a59badd39319c54642b5299e9cc98a3a194310e55d5bb9f5e397", size = 502398, upload-time = "2026-06-07T21:08:10.244Z" }, + { url = "https://files.pythonhosted.org/packages/fd/3d/b74870a0c2d40c355928cd5b96c7a11fa821b8a40fc41365e64479b151fb/aiohttp-3.14.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:896e12dfdbbab9d8f7e16d2b28c6769a60126fa92095d1ebf9473d02593a2448", size = 758018, upload-time = "2026-06-07T21:08:12.447Z" }, + { url = "https://files.pythonhosted.org/packages/d3/66/f42f5c984d99e49c6cff5f26f590750f2e2f7ef1fcfb99966ab5be1b632e/aiohttp-3.14.1-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:d03f281ed22579314ba00821ce20115a7c0ac430660b4cc05704a3f818b3e004", size = 512462, upload-time = "2026-06-07T21:08:14.624Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a7/248e1aebe0c7810b0271e021a0f2a5eb6e78a051885b3c9df49f42a5802d/aiohttp-3.14.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:07eabb979d236335fed927e137a928c9adfb7df3b9ec7aa31726f133a62be983", size = 512824, upload-time = "2026-06-07T21:08:16.572Z" }, + { url = "https://files.pythonhosted.org/packages/26/97/2aa0e5ba0727dc3bd5aaebb7ccbc510f7dfb7fb961ec87497cd496635ab1/aiohttp-3.14.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4fe1f1087cbadb280b5e1bb054a4f00d1423c74d6626c5e48400d871d34ecefe", size = 1749898, upload-time = "2026-06-07T21:08:18.635Z" }, + { url = "https://files.pythonhosted.org/packages/a0/18/938441025db6769a3464596b2410af3afde0b21eb2f204c6f766f68af4bd/aiohttp-3.14.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:634e385930fb6d2d479cf3aa66515955863b77a5e3c2b5894ca259a25b308602", size = 1760329, upload-time = "2026-06-07T21:08:27.363Z" }, + { url = "https://files.pythonhosted.org/packages/49/a2/2136674d52123b1354bd05dd5753c318db47dc0c927cc70b27bab3755456/aiohttp-3.14.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:335c0cc3e3545ce98dcb9cfcb836f40c3411f43fa03dab757597d80c89af8a35", size = 1714756, upload-time = "2026-06-07T21:08:32.094Z" }, + { url = "https://files.pythonhosted.org/packages/c1/af/14bb5843eccbe234f4dfb78ab73e549d99727247e62ae5d62cbd22eaf5b0/aiohttp-3.14.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:6ffbb2f4ec1ceaff7e07d43922954da26b223d188bf30658e561b98e23089444", size = 1742574, upload-time = "2026-06-07T21:08:43.795Z" }, + { url = "https://files.pythonhosted.org/packages/34/e3/19dbe1a1f4cc6230eb9e314de7fe68053b0992f9302b27d12141a0b5db53/aiohttp-3.14.1-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:819c054312f1af92947e6a55883d1b66feefab11531a7fc45e0fb9b63880b5c2", size = 793320, upload-time = "2026-06-07T21:08:52.775Z" }, + { url = "https://files.pythonhosted.org/packages/7f/20/1b7182219ba1b108430d6e4dc53d25ae02dcfcf5a045b33af4e8c5167527/aiohttp-3.14.1-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:10ee9c1753a8f706345b22496c79fbddb5be0599e0823f3738b1534058e25340", size = 529077, upload-time = "2026-06-07T21:08:55Z" }, + { url = "https://files.pythonhosted.org/packages/b9/c8/14ce60ec31a2e5f5274bb17d383a6f7a3aabca31ac04eee05585bbadab16/aiohttp-3.14.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:1601cc37baf5750ccacae618ec2daf020769581695550e3b654a911f859c563d", size = 532476, upload-time = "2026-06-07T21:08:57.176Z" }, + { url = "https://files.pythonhosted.org/packages/7e/02/9ac85e081e53da2e061b02fa7758fe0a12d17b8ce2d1f5e6c7cb76730328/aiohttp-3.14.1-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4d6e0ac9da31c9c04c84e1c0182ad8d6df35965a85cae29cd71d089621b3ae94", size = 1922347, upload-time = "2026-06-07T21:08:59.563Z" }, + { url = "https://files.pythonhosted.org/packages/66/4e/560c7472d3d198a23aa5c8b19a5115bf6a9b77b7d3e4bb363da320430ad2/aiohttp-3.14.1-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fc0cacab7ba4e56f0f81c82a98c09bed2f39c940107b03a34b168bdf7597edd3", size = 1877095, upload-time = "2026-06-07T21:09:09.011Z" }, + { url = "https://files.pythonhosted.org/packages/6a/c9/48255813cca749a229ef0ab476004ec623728ad79a9c0840616f6c076325/aiohttp-3.14.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:38e1e7daaea81df51c952e18483f323d878499a1e2bfe564790e0f9701d6f203", size = 1842922, upload-time = "2026-06-07T21:09:14.118Z" }, + { url = "https://files.pythonhosted.org/packages/44/be/0474c5a8b5640e1e4aa1923430a91f4151be82e511373fe764189b89aef5/aiohttp-3.14.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:99abd37084b82f5830c635fddd0b4993b9742a66eb746dacf433c8590e8f9e3c", size = 1841409, upload-time = "2026-06-07T21:09:26.207Z" }, ] [[package]] @@ -858,7 +858,7 @@ test = [ [package.metadata] requires-dist = [ - { name = "aiohttp", marker = "extra == 'test'", specifier = "==3.14.0" }, + { name = "aiohttp", marker = "extra == 'test'", specifier = "==3.14.1" }, { name = "colorama", specifier = "==0.4.6" }, { name = "coverage", marker = "extra == 'test'", specifier = "==7.13.4" }, { name = "cyclopts", specifier = "==4.10.0" },