diff --git a/scripts/agentic_inference_isl_precompute.py b/scripts/agentic_inference_isl_precompute.py new file mode 100644 index 000000000..09cdc8e91 --- /dev/null +++ b/scripts/agentic_inference_isl_precompute.py @@ -0,0 +1,174 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Offline ISL (Input Sequence Length) computation for multi-turn datasets. 
+ +Run from the repo root to print the ISL distribution for a dataset:: + + python scripts/agentic_inference_isl_precompute.py \\ + --dataset path/to/dataset.jsonl \\ + --tokenizer HF_REPO_ID_OR_LOCAL_PATH +""" + +from __future__ import annotations + +import argparse +import logging +import os +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pandas as pd +from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( + _normalize_tool_calls_for_template, +) +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset +from tqdm import tqdm +from transformers import AutoTokenizer + +logger = logging.getLogger(__name__) + + +def _precompute_isl(dataloader: MultiTurnDataset, tokenizer_name: str) -> None: + samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")] + if not samples_with_messages: + return + + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + except Exception: + logger.exception("Failed to load tokenizer %s", tokenizer_name) + return + + first_failure_logged = False + first_failure_lock = threading.Lock() + + def _tokenize_sample(sample: dict) -> list[int] | None: + try: + normalized_messages = [] + for msg in sample["messages"]: + if msg.get("tool_calls"): + msg = { + **msg, + "tool_calls": _normalize_tool_calls_for_template( + msg["tool_calls"] + ), + } + normalized_messages.append(msg) + tools = sample.get("tools") + raw = tokenizer.apply_chat_template( + normalized_messages, + tools=tools if tools else None, + tokenize=True, + add_generation_prompt=True, + ) + # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding + # instead of a plain list; extract .input_ids in that case. 
+ token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw + return token_ids + except Exception: + nonlocal first_failure_logged + with first_failure_lock: + if not first_failure_logged: + logger.exception("apply_chat_template failed (first failure shown)") + first_failure_logged = True + return None + + n_workers = min(os.cpu_count() or 32, 32) + skipped = 0 + with ThreadPoolExecutor( + max_workers=n_workers, thread_name_prefix="ISLPrecompute" + ) as pool: + futures = { + pool.submit(_tokenize_sample, sample): sample + for sample in samples_with_messages + } + for future in tqdm( + as_completed(futures), + total=len(futures), + desc="Pre-computing ISL", + unit="turn", + ): + sample = futures[future] + token_ids = future.result() + if token_ids is not None: + sample["input_tokens"] = token_ids + else: + skipped += 1 + + if skipped: + logger.warning("%d turn(s) skipped (apply_chat_template failed)", skipped) + if skipped == len(samples_with_messages): + logger.warning( + "All %d turn(s) failed apply_chat_template. " + "Check tokenizer/template compatibility.", + len(samples_with_messages), + ) + + +def _isl_distribution(dataloader: MultiTurnDataset) -> dict[str, float]: + values = sorted( + len(s["input_tokens"]) + for s in (dataloader.data or []) + if s.get("input_tokens") is not None + ) + if not values: + raise ValueError( + "No input_tokens found — tokenization may have failed entirely." + ) + n = len(values) + + def percentile(p: float) -> float: + idx = (p / 100) * (n - 1) + lo, frac = int(idx), idx % 1 + return values[lo] + frac * (values[lo + 1] - values[lo] if lo + 1 < n else 0) + + return { + "min": values[0], + "max": values[-1], + "mean": sum(values) / n, + "p50": percentile(50), + "p99": percentile(99), + } + + +def main() -> None: + logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s") + + parser = argparse.ArgumentParser( + description="Compute ISL distribution for a multi-turn dataset." 
+ ) + parser.add_argument("--dataset", required=True, help="Path to JSONL dataset file.") + parser.add_argument( + "--tokenizer", required=True, help="HuggingFace repo ID or local path." + ) + args = parser.parse_args() + + ds = MultiTurnDataset(pd.read_json(args.dataset, lines=True)) + ds.load() + _precompute_isl(ds, args.tokenizer) + + stats = _isl_distribution(ds) + n = sum(1 for s in (ds.data or []) if s.get("input_tokens") is not None) + print(f"ISL distribution ({n} turns)") + print(f" min : {stats['min']:.0f}") + print(f" mean : {stats['mean']:.1f}") + print(f" p50 : {stats['p50']:.0f}") + print(f" p99 : {stats['p99']:.0f}") + print(f" max : {stats['max']:.0f}") + + +if __name__ == "__main__": + main() diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index e3c5505b9..7019ef444 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -43,7 +43,6 @@ import msgspec.json from huggingface_hub import model_info from tqdm import tqdm -from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging from inference_endpoint.async_utils.event_publisher import EventPublisherService @@ -58,9 +57,6 @@ from inference_endpoint.async_utils.services.metrics_aggregator.subscriber import ( MetricsSnapshotSubscriber, ) -from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( - _normalize_tool_calls_for_template, -) from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext from inference_endpoint.config.runtime_settings import RuntimeSettings from inference_endpoint.config.schema import ( @@ -314,75 +310,6 @@ def _load_datasets( return dataloader, accuracy_datasets, eval_configs -def _precompute_isl_for_multi_turn( - dataloader: MultiTurnDataset, tokenizer_name: str -) -> None: - """Tokenize pre-built message lists and store token counts in each 
sample. - - Runs apply_chat_template once per client turn so the hot-path IslTrigger - sync path (len(token_ids)) is used instead of on-the-fly text tokenization. - Only affects dataset-history turns; live-history turns override 'messages' - at runtime so the stored input_tokens are stale (acceptable approximation). - """ - try: - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - except Exception: - logger.exception( - "ISL pre-computation: failed to load tokenizer %s; " - "falling back to text-tokenization at runtime", - tokenizer_name, - ) - return - skipped = 0 - first_failure_logged = False - for sample in dataloader.data or []: - messages = sample.get("messages") - if not messages: - continue - try: - normalized_messages = [] - for msg in messages: - if msg.get("tool_calls"): - msg = { - **msg, - "tool_calls": _normalize_tool_calls_for_template( - msg["tool_calls"] - ), - } - normalized_messages.append(msg) - tools = sample.get("tools") - raw = tokenizer.apply_chat_template( - normalized_messages, - tools=tools if tools else None, - tokenize=True, - add_generation_prompt=True, - ) - # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding - # instead of a plain list; extract .input_ids in that case. - token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw - sample["input_tokens"] = token_ids - except Exception: - if not first_failure_logged: - logger.exception( - "ISL pre-computation: apply_chat_template failed (first failure shown)" - ) - first_failure_logged = True - skipped += 1 - if skipped: - logger.warning( - "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)", - skipped, - ) - total_with_messages = len([s for s in (dataloader.data or []) if s.get("messages")]) - if total_with_messages > 0 and skipped == total_with_messages: - logger.warning( - "ISL precomputation: all %d turn(s) failed apply_chat_template; " - "ISL metrics will use text-tokenization fallback. 
" - "Check tokenizer/template compatibility.", - total_with_messages, - ) - - def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext: """Load tokenizer, dataset, create scheduler, setup report dir.""" # CPU affinity @@ -423,10 +350,6 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo # Datasets dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir) - if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None: - logger.info("Pre-computing ISL token counts for multi-turn dataset…") - _precompute_isl_for_multi_turn(dataloader, tokenizer_name) - # Setup runtime settings using factory method rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples()) diff --git a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py index 4e200d06a..bb20153a5 100644 --- a/src/inference_endpoint/dataset_manager/multi_turn_dataset.py +++ b/src/inference_endpoint/dataset_manager/multi_turn_dataset.py @@ -70,7 +70,7 @@ class ConversationMetadata: delay_seconds_by_key: dict[tuple[str, int], float] = field(default_factory=dict) -def _expand_tool_results(row: dict) -> list[dict]: +def _expand_tool_results(row: dict | pd.Series) -> list[dict]: """Expand a tool row into one OpenAI tool message per result. All ``role: "tool"`` rows carry a ``tool_results`` array. Each entry expands to @@ -113,6 +113,127 @@ def _expand_tool_results(row: dict) -> list[dict]: return messages +def _build_conversation_metadata( + conv_id: Any, + group: Any, + enable_salt: bool, +) -> tuple[ + str, + dict[tuple[str, int], list[dict]], + dict[tuple[str, int], list[dict]], + str | None, + dict[tuple[str, int], float], + list[ConversationSampleEntry], + int, +]: + """Build message history for all client turns in a single conversation. 
+ + Returns a tuple of (str_conv_id, pre_built_messages, current_turn_messages, + system_prompt, delay_seconds, samples, client_turns_count). + """ + str_conv_id = str(conv_id) + sorted_group = group.sort_values("turn") + client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] + + system_content: str | None = None + for _, srow in sorted_group.iterrows(): + val = srow.get("system") + if val and isinstance(val, str): + system_content = val + break + if enable_salt and system_content: + salt_hex = hashlib.blake2b( + str_conv_id.encode("utf-8"), digest_size=8 + ).hexdigest() + system_content = f"{system_content}\n\n[cache_salt: {salt_hex}]" + elif enable_salt: + logger.warning( + "multi_turn.enable_salt requested but conversation %s has no " + "system prompt; cache salt not applied", + conv_id, + ) + + pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = {} + current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = {} + delay_seconds_by_key: dict[tuple[str, int], float] = {} + samples: list[ConversationSampleEntry] = [] + + # Single pass over all rows in turn order, carrying a running history list. + # Each row is formatted once; client turns snapshot (history + current_msgs) + # before extending history. + history: list[dict] = [] + if system_content: + history.append({"role": "system", "content": system_content}) + + for _, row in sorted_group.iterrows(): + role = row.get("role") + + # Format this row into message(s) using the same field extraction as before. 
+ expanded = _expand_tool_results(row) + if expanded: + row_msgs: list[dict] = expanded + else: + msg: dict[str, Any] = {} + for key in ( + "role", + "content", + "name", + "tool_calls", + "tool_results", + "reasoning_content", + ): + val = row.get(key) + if val is not None and not (isinstance(val, float) and pd.isna(val)): + msg[key] = val + if ( + msg.get("role") == "assistant" + and "tool_calls" in msg + and "content" not in msg + ): + msg["content"] = None + row_msgs = [msg] if msg.get("role") else [] + + if role in ("user", "tool"): + t_n = int(row["turn"]) + current_turn_msgs: list[dict] = row_msgs + # Snapshot: history holds everything before this turn; create new lists + # so stored snapshots are not mutated by later history extensions. + pre_built_messages_by_key[(str_conv_id, t_n)] = ( + list(history) + current_turn_msgs + ) + current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs + history.extend(current_turn_msgs) + + delay_val = row.get("delay_seconds") + if delay_val is not None and not ( + isinstance(delay_val, float) and pd.isna(delay_val) + ): + try: + delay_f = float(delay_val) + except (TypeError, ValueError): + delay_f = 0.0 + if delay_f > 0.0: + delay_seconds_by_key[(str_conv_id, t_n)] = delay_f + + samples.append( + ConversationSampleEntry(conversation_id=str_conv_id, turn=t_n) + ) + else: + # Non-client row (assistant, etc.): extend history for future client turns. + history.extend(row_msgs) + + client_turns_count = int(client_rows.shape[0]) + return ( + str_conv_id, + pre_built_messages_by_key, + current_turn_messages_by_key, + system_content, + delay_seconds_by_key, + samples, + client_turns_count, + ) + + class MultiTurnDataset(Dataset, dataset_id="multi_turn_conversations"): """Dataset for multi-turn conversations. @@ -384,129 +505,31 @@ def _build_metadata(self) -> ConversationMetadata: Returns: ConversationMetadata with samples, counts, and pre-built message maps. 
""" - samples: list[ConversationSampleEntry] = [] - - # Count client turns (user + tool) per conversation for completion tracking - client_turns_per_conv = { - str(conv_id): int(group["role"].isin(["user", "tool"]).sum()) - for conv_id, group in self._conv_groups.items() - } + assert self.dataframe is not None, "Dataframe must be initialized" - # Map (conversation_id, turn) → complete message list ready to send to endpoint. - # Each entry is: [system (optional)] + all prior rows formatted as messages - # + the current client turn message. - # This includes assistant rows (tool dispatches or terminal responses) - # so no runtime injection is required. - pre_built_messages_by_key: dict[tuple, list[dict]] = {} - current_turn_messages_by_key: dict[tuple, list[dict]] = {} + samples: list[ConversationSampleEntry] = [] + client_turns_per_conv: dict[str, int] = {} + pre_built_messages_by_key: dict[tuple[str, int], list[dict]] = {} + current_turn_messages_by_key: dict[tuple[str, int], list[dict]] = {} system_prompts_by_conv: dict[str, str | None] = {} - delay_seconds_by_key: dict[tuple, float] = {} + delay_seconds_by_key: dict[tuple[str, int], float] = {} - assert self.dataframe is not None, "Dataframe must be initialized" for conv_id, group in self._conv_groups.items(): - sorted_group = group.sort_values("turn") - client_rows = sorted_group[sorted_group["role"].isin(["user", "tool"])] - - # Extract system prompt from the first row that has it (typically turn 1) - system_content: str | None = None - for _, srow in sorted_group.iterrows(): - val = srow.get("system") - if val and isinstance(val, str): - system_content = val - break - if self._enable_salt and system_content: - salt_hex = hashlib.blake2b( - str(conv_id).encode("utf-8"), digest_size=8 - ).hexdigest() - system_content = f"{system_content}\n\n[cache_salt: {salt_hex}]" - elif self._enable_salt: - logger.warning( - "multi_turn.enable_salt requested but conversation %s has no " - "system prompt; cache salt not 
applied", - conv_id, - ) - system_prompts_by_conv[str(conv_id)] = system_content - - for _, row in client_rows.iterrows(): - t_n = int(row["turn"]) - - messages: list[dict] = [] - if system_content: - messages.append({"role": "system", "content": system_content}) - - # All dataset rows strictly before this client turn (includes - # assistant rows and prior tool results). - prior_rows = sorted_group[sorted_group["turn"] < t_n] - for _, prior_row in prior_rows.iterrows(): - msg: dict[str, Any] = {} - for key in ( - "role", - "content", - "name", - "tool_calls", - "tool_results", - "reasoning_content", - ): - val = prior_row.get(key) - if val is not None and not ( - isinstance(val, float) and pd.isna(val) - ): - msg[key] = val - if ( - msg.get("role") == "assistant" - and "tool_calls" in msg - and "content" not in msg - ): - msg["content"] = None - if msg.get("role"): - # Expand merged parallel tool results: a single row with - # tool_results: [{tool_call_id, content}, ...] expands into - # one OpenAI tool message per result entry. - expanded = _expand_tool_results(msg) - if expanded: - messages.extend(expanded) - else: - messages.append(msg) - - # Append the current client turn message. - # A merged parallel-tool row carries tool_results instead of a - # single tool_call_id/content pair; expand to one message per result. 
- current_turn_msgs: list[dict] = [] - expanded = _expand_tool_results(row) - if expanded: - current_turn_msgs = expanded - else: - cur: dict[str, Any] = {} - for key in ("role", "content", "name"): - val = row.get(key) - if val is not None and not ( - isinstance(val, float) and pd.isna(val) - ): - cur[key] = val - current_turn_msgs = [cur] - messages.extend(current_turn_msgs) - - str_conv_id = str(conv_id) - pre_built_messages_by_key[(str_conv_id, t_n)] = messages - current_turn_messages_by_key[(str_conv_id, t_n)] = current_turn_msgs - - delay_val = row.get("delay_seconds") - if delay_val is not None and not ( - isinstance(delay_val, float) and pd.isna(delay_val) - ): - try: - delay_f = float(delay_val) - except (TypeError, ValueError): - delay_f = 0.0 - if delay_f > 0.0: - delay_seconds_by_key[(str_conv_id, t_n)] = delay_f - - samples.append( - ConversationSampleEntry( - conversation_id=str_conv_id, - turn=t_n, - ) - ) + ( + str_conv_id, + partial_pre_built, + partial_current_turn, + system_prompt, + partial_delay, + conv_samples, + client_turns_count, + ) = _build_conversation_metadata(conv_id, group, self._enable_salt) + pre_built_messages_by_key.update(partial_pre_built) + current_turn_messages_by_key.update(partial_current_turn) + system_prompts_by_conv[str_conv_id] = system_prompt + delay_seconds_by_key.update(partial_delay) + samples.extend(conv_samples) + client_turns_per_conv[str_conv_id] = client_turns_count return ConversationMetadata( samples=samples, diff --git a/tests/integration/test_multi_turn_metrics.py b/tests/integration/test_multi_turn_metrics.py deleted file mode 100644 index 22e3a59a4..000000000 --- a/tests/integration/test_multi_turn_metrics.py +++ /dev/null @@ -1,272 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Integration tests: ISL precompute and ISL/OSL/TPOT aggregator metrics for multi-turn runs.""" - -from __future__ import annotations - -import asyncio -from pathlib import Path -from unittest.mock import AsyncMock, MagicMock - -import msgspec -import pandas as pd -import pytest -from inference_endpoint.async_utils.services.metrics_aggregator.aggregator import ( - MetricsAggregatorService, -) -from inference_endpoint.async_utils.services.metrics_aggregator.metrics_table import ( - MetricSeriesKey, -) -from inference_endpoint.async_utils.services.metrics_aggregator.registry import ( - MetricsRegistry, -) -from inference_endpoint.async_utils.services.metrics_aggregator.snapshot import ( - SeriesStat, - SessionState, -) -from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext -from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn -from inference_endpoint.core.record import ( - EventRecord, - SampleEventType, - SessionEventType, -) -from inference_endpoint.core.types import PromptData, TextModelOutput -from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset - - -class _MockTokenizePool: - async def token_count_async( - self, text: str, _loop: asyncio.AbstractEventLoop - ) -> int: - return len(text.split()) - - async def token_count_message_async( - self, - content: str, - reasoning: str | None, - 
tool_calls, - _loop: asyncio.AbstractEventLoop, - ) -> int: - tool_calls_str = ( - msgspec.json.encode(list(tool_calls)).decode() if tool_calls else "" - ) - combined = (content or "") + " " + (reasoning or "") + " " + tool_calls_str - return len(combined.split()) - - def close(self) -> None: - pass - - -def _make_aggregator_with_mock_publisher( - zmq_ctx: ManagedZMQContext, - loop: asyncio.AbstractEventLoop, - socket_name: str, - shutdown_event: asyncio.Event, - tokenize_pool=None, -) -> tuple[MetricsAggregatorService, MetricsRegistry]: - """Build an aggregator with a mocked publisher (no ZMQ / disk I/O).""" - registry = MetricsRegistry() - publisher = MagicMock() - publisher.publish_final = AsyncMock() - publisher.aclose = AsyncMock() - agg = MetricsAggregatorService( - socket_name, - zmq_ctx, - loop, - registry=registry, - publisher=publisher, - publish_interval_s=1.0, - sig_figs=3, - n_histogram_buckets=10, - tokenize_pool=tokenize_pool, - streaming=True, - shutdown_event=shutdown_event, - ) - return agg, registry - - -def _session_event(ev_type: SessionEventType, ts: int = 0) -> EventRecord: - return EventRecord(event_type=ev_type, timestamp_ns=ts) - - -def _sample_event( - ev_type: SampleEventType, - uuid: str, - ts: int = 0, - data=None, - conversation_id: str = "", - turn: int | None = None, -) -> EventRecord: - return EventRecord( - event_type=ev_type, - timestamp_ns=ts, - sample_uuid=uuid, - data=data, - conversation_id=conversation_id, - turn=turn, - ) - - -def _snapshot_series_count(registry: MetricsRegistry, name: str) -> int: - snap = registry.build_snapshot(state=SessionState.LIVE, n_pending_tasks=0) - for m in snap.metrics: - if isinstance(m, SeriesStat) and m.name == name: - return m.count - return 0 - - -@pytest.mark.integration -def test_multi_turn_isl_uses_precomputed_token_count(): - """_precompute_isl_for_multi_turn stores input_tokens on each sample dict.""" - transformers = pytest.importorskip("transformers") - - rows = [ - 
{"conversation_id": "c1", "turn": 1, "role": "user", "content": "Hello there"}, - {"conversation_id": "c1", "turn": 2, "role": "assistant", "content": "Hi"}, - {"conversation_id": "c1", "turn": 3, "role": "user", "content": "How are you?"}, - ] - ds = MultiTurnDataset(pd.DataFrame(rows)) - ds.load() - - tokenizer_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - try: - transformers.AutoTokenizer.from_pretrained(tokenizer_name) - except Exception: - pytest.skip(f"Tokenizer '{tokenizer_name}' not available in this environment") - - _precompute_isl_for_multi_turn(ds, tokenizer_name) - - for i, sample in enumerate(ds.data or []): - assert ( - "input_tokens" in sample - ), f"Sample {i} missing input_tokens after precompute" - assert isinstance( - sample["input_tokens"], list - ), f"Sample {i} input_tokens not a list" - assert len(sample["input_tokens"]) > 0, f"Sample {i} input_tokens is empty" - - -_TOOL_CALLS = ( - { - "id": "call_1", - "type": "function", - "function": {"name": "search", "arguments": '{"q":"hello"}'}, - }, -) - - -@pytest.mark.integration -@pytest.mark.asyncio -@pytest.mark.parametrize( - ("first_chunk", "later_chunk", "complete_data"), - [ - ( - TextModelOutput(output=("chunk1",)), - TextModelOutput(output=("chunk2",)), - TextModelOutput(output=("chunk1", "chunk2", "chunk3")), - ), - ( - TextModelOutput(output=(), tool_calls=_TOOL_CALLS), - TextModelOutput(output=(), tool_calls=_TOOL_CALLS), - TextModelOutput(output=(), tool_calls=_TOOL_CALLS), - ), - ], - ids=["text", "tool_calls_only"], -) -async def test_multi_turn_aggregator_records_metrics_streaming( - tmp_path: Path, - first_chunk: TextModelOutput, - later_chunk: TextModelOutput, - complete_data: TextModelOutput, -): - """TTFT/ISL/OSL/TPOT fire for streamed multi-turn turns (text and tool-call payloads).""" - loop = asyncio.get_event_loop() - shutdown_event = asyncio.Event() - with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx: - agg, registry = _make_aggregator_with_mock_publisher( - 
ctx, - loop, - "test_mt_streaming_metrics", - shutdown_event, - tokenize_pool=_MockTokenizePool(), - ) - try: - t = 0 - - def ts() -> int: - nonlocal t - t += 1_000_000 - return t - - uuid = "mt-turn-1" - events = [ - _session_event(SessionEventType.STARTED, ts=ts()), - _session_event(SessionEventType.START_PERFORMANCE_TRACKING, ts=ts()), - _sample_event( - SampleEventType.ISSUED, - uuid, - ts=ts(), - data=PromptData(token_ids=(1, 2, 3, 4, 5)), - conversation_id="c1", - turn=1, - ), - _sample_event( - SampleEventType.RECV_FIRST, - uuid, - ts=ts(), - data=first_chunk, - conversation_id="c1", - turn=1, - ), - _sample_event( - SampleEventType.RECV_NON_FIRST, - uuid, - ts=ts(), - data=later_chunk, - conversation_id="c1", - turn=1, - ), - _sample_event( - SampleEventType.COMPLETE, - uuid, - ts=ts(), - data=complete_data, - conversation_id="c1", - turn=1, - ), - _session_event(SessionEventType.STOP_PERFORMANCE_TRACKING, ts=ts()), - _session_event(SessionEventType.ENDED, ts=ts()), - ] - await agg.process(events) - - # OSL/TPOT fire via async tokenization tasks; poll until they land. - for _ in range(30): - if _snapshot_series_count(registry, MetricSeriesKey.OSL.value) > 0: - break - await asyncio.sleep(0.05) - - for key in ( - MetricSeriesKey.ISL, - MetricSeriesKey.TTFT_NS, - MetricSeriesKey.OSL, - MetricSeriesKey.TPOT_NS, - ): - assert ( - _snapshot_series_count(registry, key.value) > 0 - ), f"{key.value} must be recorded" - finally: - agg.close() diff --git a/tests/unit/commands/test_precompute_isl.py b/tests/unit/commands/test_precompute_isl.py deleted file mode 100644 index d4ccf8521..000000000 --- a/tests/unit/commands/test_precompute_isl.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for _precompute_isl_for_multi_turn.""" - -from unittest.mock import MagicMock, patch - -import pytest -from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn - - -def _make_dataloader(samples: list[dict]) -> MagicMock: - dl = MagicMock() - dl.data = samples - return dl - - -class TestPrecomputeIslForMultiTurn: - @pytest.mark.unit - def test_sets_input_tokens_for_samples_with_messages(self): - samples = [ - {"messages": [{"role": "user", "content": "hello"}]}, - {"messages": [{"role": "user", "content": "world"}]}, - ] - dataloader = _make_dataloader(samples) - mock_tokenizer = MagicMock() - mock_tokenizer.apply_chat_template.side_effect = lambda msgs, **_: list( - range(len(msgs) * 3) - ) - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - mock_cls.from_pretrained.return_value = mock_tokenizer - _precompute_isl_for_multi_turn(dataloader, "test-model") - - for sample in samples: - assert "input_tokens" in sample - assert isinstance(sample["input_tokens"], list) - - @pytest.mark.unit - def test_leaves_samples_without_messages_untouched(self): - samples = [ - {"prompt": "no messages here"}, - {"input_tokens": [1, 2, 3]}, - ] - dataloader = _make_dataloader(samples) - mock_tokenizer = MagicMock() - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - mock_cls.from_pretrained.return_value = mock_tokenizer - _precompute_isl_for_multi_turn(dataloader, "test-model") - - mock_tokenizer.apply_chat_template.assert_not_called() - 
assert "input_tokens" not in samples[0] - assert samples[1]["input_tokens"] == [1, 2, 3] - - @pytest.mark.unit - def test_skips_failed_template_calls_with_warning(self, caplog): - samples = [ - {"messages": [{"role": "user", "content": "good"}]}, - {"messages": [{"role": "user", "content": "bad"}]}, - ] - dataloader = _make_dataloader(samples) - - def side_effect(msgs, **_): - if msgs[0]["content"] == "bad": - raise ValueError("template error") - return [10, 20, 30] - - mock_tokenizer = MagicMock() - mock_tokenizer.apply_chat_template.side_effect = side_effect - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - mock_cls.from_pretrained.return_value = mock_tokenizer - with caplog.at_level("WARNING"): - _precompute_isl_for_multi_turn(dataloader, "test-model") - - assert "input_tokens" in samples[0] - assert "input_tokens" not in samples[1] - assert "1 turn(s) skipped" in caplog.text - - @pytest.mark.unit - def test_batch_encoding_return_value_is_unwrapped(self): - """Tokenizers like Qwen3 return BatchEncoding instead of list[int].""" - samples = [{"messages": [{"role": "user", "content": "hi"}]}] - dataloader = _make_dataloader(samples) - - batch_encoding = MagicMock() - batch_encoding.input_ids = [1, 2, 3] - - mock_tokenizer = MagicMock() - mock_tokenizer.apply_chat_template.return_value = batch_encoding - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - mock_cls.from_pretrained.return_value = mock_tokenizer - _precompute_isl_for_multi_turn(dataloader, "test-model") - - assert samples[0]["input_tokens"] == [1, 2, 3] - - @pytest.mark.unit - def test_add_generation_prompt_true(self): - samples = [{"messages": [{"role": "user", "content": "hi"}]}] - dataloader = _make_dataloader(samples) - mock_tokenizer = MagicMock() - mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - 
mock_cls.from_pretrained.return_value = mock_tokenizer - _precompute_isl_for_multi_turn(dataloader, "test-model") - - _, kwargs = mock_tokenizer.apply_chat_template.call_args - assert kwargs.get("add_generation_prompt") is True - assert kwargs.get("tokenize") is True - - @pytest.mark.unit - def test_normalizes_tool_call_arguments_before_apply_chat_template(self): - """_normalize_tool_calls_for_template converts arguments strings to dicts.""" - samples = [ - { - "messages": [ - {"role": "user", "content": "use a tool"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "c1", - "type": "function", - "function": { - "name": "bash", - "arguments": '{"cmd": "ls"}', - }, - } - ], - }, - ] - }, - ] - dataloader = _make_dataloader(samples) - mock_tokenizer = MagicMock() - mock_tokenizer.apply_chat_template.return_value = [1, 2, 3] - - with patch( - "inference_endpoint.commands.benchmark.execute.AutoTokenizer" - ) as mock_cls: - mock_cls.from_pretrained.return_value = mock_tokenizer - _precompute_isl_for_multi_turn(dataloader, "test-model") - - # Production code builds new dicts, so call_args captures the normalized value. - passed_msgs = mock_tokenizer.apply_chat_template.call_args[0][0] - asst_msg = next(m for m in passed_msgs if m.get("role") == "assistant") - assert asst_msg["tool_calls"][0]["function"]["arguments"] == {"cmd": "ls"}