Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/inference_endpoint/commands/benchmark/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import msgspec
import msgspec.json
from huggingface_hub import model_info
from pydantic import ValidationError
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging as transformers_logging
Expand Down Expand Up @@ -288,26 +289,35 @@ def _load_datasets(
acc_cfg.accuracy_config.extras or {},
)
)
ds.load(
api_type=config.endpoint_config.api_type, model_params=config.model_params
)
try:
ds_model_params = acc_cfg.effective_generation_config(config.model_params)
except (ValidationError, ValueError) as e:
raise InputValidationError(
f"Dataset '{acc_cfg.name}': invalid generation_config_override: {e}"
) from e
ds.load(api_type=config.endpoint_config.api_type, model_params=ds_model_params)
logger.info(f"Loaded {ds} - {ds.num_samples()} samples")

if not accuracy_cfgs:
logger.info("No accuracy datasets provided")
if len(performance_cfgs) > 1:
raise InputValidationError("Multiple performance datasets not supported")

perf_cfg = performance_cfgs[0]
try:
perf_model_params = perf_cfg.effective_generation_config(config.model_params)
except (ValidationError, ValueError) as e:
raise InputValidationError(
f"Dataset '{perf_cfg.name}': invalid generation_config_override: {e}"
) from e
try:
dataloader = DataLoaderFactory.create_loader(performance_cfgs[0])
dataloader = DataLoaderFactory.create_loader(perf_cfg)
dataloader.load(
api_type=config.endpoint_config.api_type, model_params=config.model_params
api_type=config.endpoint_config.api_type, model_params=perf_model_params
)
logger.info(f"Loaded {dataloader.num_samples()} samples")
except FileNotFoundError as e:
raise InputValidationError(
f"Dataset file not found: {performance_cfgs[0].path}"
) from e
raise InputValidationError(f"Dataset file not found: {perf_cfg.path}") from e
except Exception as e:
raise SetupError(f"Failed to load dataset: {e}") from e

Expand Down
83 changes: 83 additions & 0 deletions src/inference_endpoint/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,23 @@ class SystemDefaults(BaseModel):
DEFAULT_METRIC: ClassVar[metrics.Metric] = metrics.Throughput(0.0)


def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
"""Recursively merge ``override`` into ``base`` and return the result.

For overlapping keys whose values are both dicts, recurse; otherwise the
override value wins. Mutates a *copy* — callers can safely pass model_dump()
output. Used by ``Dataset.effective_generation_config`` so a sparse nested
override (e.g. ``{osl_distribution: {max: 512}}``) preserves siblings.
"""
out = dict(base)
for k, v in override.items():
if isinstance(v, dict) and isinstance(out.get(k), dict):
out[k] = _deep_merge(out[k], v)
else:
out[k] = v
return out


class LoadPatternType(str, Enum):
"""Load pattern types."""

Expand Down Expand Up @@ -313,6 +330,38 @@ class Dataset(BaseModel):
multi_turn: MultiTurnConfig | None = Field(
None, description="Multi-turn conversation configuration"
)
# TODO(post-mortem): generation config is per-phase (perf vs. accuracy),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

# not per-dataset — phases are derived from datasets and the override is
# keyed to dataset identity. This lives on Dataset as a short-term workaround
# so MLPerf-style accuracy + perf can share one fleet. The proper fix is
# a first-class GenerationConfig carried on PhaseConfig, decoupled from
# the dataset entry. Field/method names use "generation_config" to keep
# the eventual migration mechanical.
#
# Caveats on per-dataset overrides today:
# - `name` flows into the request `model` field but the tokenizer and
# aggregator are launched from the global `model_params.name`, so a
# per-dataset rename mismatches ISL/OSL accounting.
# - `streaming` flows into the request but the single MetricsAggregator
# is launched with the global `model_params.streaming` flag, so a
# per-dataset streaming flip will not produce TTFT/TPOT for that
# phase. Keep streaming on `model_params` (per-run) for now.
# - Nested dicts (`osl_distribution`, `chat_template_kwargs`) are
# deep-merged so sparse overrides preserve sibling defaults.
generation_config_override: dict[str, Any] | None = Field(
None,
description=(
"Per-dataset overrides for the top-level model_params (sparse — "
"only the fields you want to override). Merged on top of "
"BenchmarkConfig.model_params at dataset-load time. Useful for "
"MLPerf-style runs where accuracy and performance use different "
"output budgets in the same fleet, e.g. "
"generation_config_override: {max_new_tokens: 32768, "
"temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are "
"accepted (kwargs-style) but not honored by the single-aggregator "
"metrics path — set those on top-level model_params."
),
)

@model_validator(mode="after")
def _auto_derive_name(self) -> Self:
Expand All @@ -321,6 +370,40 @@ def _auto_derive_name(self) -> Self:
object.__setattr__(self, "name", Path(self.path).stem)
return self

@model_validator(mode="after")
def _validate_generation_config_override(self) -> Self:
    """Reject override keys that are not ``ModelParams`` fields.

    Only key names are checked here. Override values are validated later,
    at merge time in ``effective_generation_config``, because cross-field
    validation needs the base ``ModelParams`` from ``BenchmarkConfig``.

    Returns:
        This instance, unchanged.

    Raises:
        ValueError: If the override contains keys that are not declared
            on ``ModelParams``; the message lists the offending keys and
            the valid ones.
    """
    # No override (None or empty dict): nothing to check.
    if not self.generation_config_override:
        return self

    valid = set(ModelParams.model_fields)
    bad = sorted(key for key in self.generation_config_override if key not in valid)
    if not bad:
        return self

    raise ValueError(
        f"Dataset '{self.name}': unknown keys in "
        f"generation_config_override: {bad}. "
        f"Valid keys: {sorted(valid)}"
    )

def effective_generation_config(self, base: ModelParams) -> ModelParams:
    """Combine ``base`` with this dataset's generation-config overrides.

    When no override is configured, ``base`` itself is returned. Otherwise
    the override is layered on top of ``base.model_dump()`` with
    ``_deep_merge``, so a sparse nested override (e.g.
    ``{osl_distribution: {max: 512}}``) keeps the sibling values from
    ``base``. The merged dict is then passed through
    ``ModelParams.model_validate`` so that type-invalid scalar overrides
    (e.g. ``temperature: 'hot'``) are rejected.

    Only that re-validation guards the values: a sparse nested override
    whose merged result still passes ``ModelParams`` validation will not
    raise.

    Args:
        base: The run-wide model parameters to start from.

    Returns:
        ``base`` when there is nothing to override, otherwise a newly
        validated ``ModelParams`` built from the merged values.
    """
    overrides = self.generation_config_override
    if overrides:
        combined = _deep_merge(base.model_dump(), overrides)
        return ModelParams.model_validate(combined)
    return base


class AccuracyConfig(BaseModel):
"""Accuracy configuration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ datasets: # Dataset configs
prompt: text_input
accuracy_config: null # Accuracy evaluation settings
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
- name: accuracy
type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
Expand All @@ -42,6 +43,7 @@ datasets: # Dataset configs
extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
num_repeats: 1 # Repeat dataset N times for evaluation
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
settings:
runtime:
min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ datasets: # Dataset configs
prompt: text_input
accuracy_config: null # Accuracy evaluation settings
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
- name: accuracy
type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
Expand All @@ -42,6 +43,7 @@ datasets: # Dataset configs
extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
num_repeats: 1 # Repeat dataset N times for evaluation
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
settings:
runtime:
min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ datasets: # Dataset configs
prompt: text_input
accuracy_config: null # Accuracy evaluation settings
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
- name: accuracy
type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy
path: '<DATASET_PATH eg: tests/assets/datasets/ds_samples.jsonl>' # Dataset file path
Expand All @@ -42,6 +43,7 @@ datasets: # Dataset configs
extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor)
num_repeats: 1 # Repeat dataset N times for evaluation
multi_turn: null # Multi-turn conversation configuration
generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params.
settings:
runtime:
min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m)
Expand Down
148 changes: 148 additions & 0 deletions tests/unit/commands/test_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests for benchmark CLI models, config building, and command handlers."""

import asyncio
import json
import random
import tempfile
from pathlib import Path
Expand All @@ -34,6 +35,7 @@
BenchmarkContext,
ResponseCollector,
_build_phases,
_load_datasets,
_run_benchmark_async,
setup_benchmark,
)
Expand Down Expand Up @@ -1277,3 +1279,149 @@ def test_no_override_yields_none_when_model_has_no_tokenizer(
ctx = setup_benchmark(config, TestMode.PERF)

assert ctx.tokenizer_name is None


class _OverrideTestBase:
    """Common fixtures and test cases for ``generation_config_override``.

    The cases drive ``_load_datasets`` end to end. Concrete subclasses pick
    the adapter (chat vs text-completions) by setting the two class
    attributes below.
    """

    # Adapter selector passed as ``endpoint_config.api_type``; set by subclasses.
    api_type: str = ""
    # Static column name AddStaticColumns adds; set by subclasses.
    max_tokens_key: str = ""

    def _write_jsonl(self, path: Path, rows: list[dict]) -> None:
        """Write ``rows`` to ``path`` as newline-terminated JSON lines."""
        encoded = [json.dumps(row) for row in rows]
        path.write_text("\n".join(encoded) + "\n")

    def _prompt_rows(
        self, prompt: str, ground_truth: str | None = None
    ) -> list[dict]:
        """Build a one-row dataset shaped for the adapter under test.

        The chat adapter gets a ``prompt`` column. The completions adapter
        gets pre-tokenized ``input_tokens`` instead, so the Harmonize
        transform early-exits and the unit tests avoid the HF tokenizer
        dependency. ``ground_truth`` is attached only when provided.
        """
        if self.api_type == "openai_completions":
            row: dict = {"input_tokens": [1, 2, 3, 4]}
        else:
            row = {"prompt": prompt}
        if ground_truth is not None:
            row["ground_truth"] = ground_truth
        return [row]

    def _build_config(
        self,
        perf_path: Path,
        acc_path: Path,
        acc_override: dict | None,
        perf_override: dict | None = None,
    ) -> BenchmarkConfig:
        """Assemble an offline config with one perf and one accuracy dataset.

        The global ``max_new_tokens`` is 1024. A dataset entry carries
        ``generation_config_override`` only when a non-empty override is
        supplied for it.
        """
        perf_entry: dict = {
            "name": "perf",
            "type": "performance",
            "path": str(perf_path),
        }
        if perf_override:
            perf_entry["generation_config_override"] = perf_override

        acc_entry: dict = {
            "name": "acc",
            "type": "accuracy",
            "path": str(acc_path),
            "accuracy_config": {
                "eval_method": "pass_at_1",
                "ground_truth": "ground_truth",
                "extractor": "boxed_math_extractor",
            },
        }
        if acc_override:
            acc_entry["generation_config_override"] = acc_override

        return BenchmarkConfig(
            type=TestType.OFFLINE,
            model_params={"name": "test-model", "max_new_tokens": 1024},
            endpoint_config={
                "endpoints": ["http://localhost:8000"],
                "api_type": self.api_type,
            },
            datasets=[perf_entry, acc_entry],
        )

    def _write_fixture(self, tmp_path: Path) -> tuple[Path, Path]:
        """Create the perf and accuracy JSONL files and return their paths."""
        perf_path = tmp_path / "perf.jsonl"
        acc_path = tmp_path / "acc.jsonl"
        perf_rows = self._prompt_rows("perf-prompt")
        self._write_jsonl(perf_path, perf_rows)
        acc_rows = self._prompt_rows("acc-prompt", ground_truth="42")
        self._write_jsonl(acc_path, acc_rows)
        return perf_path, acc_path

    def _max_tokens(self, dataset):
        """Return the adapter's max-token column from the dataset's first sample."""
        return dataset.load_sample(0)[self.max_tokens_key]

    @pytest.mark.unit
    def test_override_propagates_to_loaded_rows(self, tmp_path):
        """An accuracy-only override reaches the accuracy rows, while the
        perf dataset keeps the global 1024."""
        perf_path, acc_path = self._write_fixture(tmp_path)
        acc_override = {"max_new_tokens": 32768}
        config = self._build_config(perf_path, acc_path, acc_override=acc_override)
        perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path)
        assert self._max_tokens(perf_ds) == 1024
        assert self._max_tokens(acc_datasets[0]) == 32768

    @pytest.mark.unit
    def test_no_override_inherits_global(self, tmp_path):
        """With no overrides, both datasets carry the global model_params value."""
        perf_path, acc_path = self._write_fixture(tmp_path)
        config = self._build_config(perf_path, acc_path, acc_override=None)
        perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path)
        assert self._max_tokens(perf_ds) == 1024
        assert self._max_tokens(acc_datasets[0]) == 1024

    @pytest.mark.unit
    def test_perf_dataset_override_also_honored(self, tmp_path):
        """Overrides on the performance entry flow through as well (relevant
        for MLPerf-style perf runs with a shorter max_new_tokens)."""
        perf_path, acc_path = self._write_fixture(tmp_path)
        config = self._build_config(
            perf_path,
            acc_path,
            acc_override={"max_new_tokens": 32768},
            perf_override={"max_new_tokens": 10240},
        )
        perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path)
        assert self._max_tokens(perf_ds) == 10240
        assert self._max_tokens(acc_datasets[0]) == 32768

    @pytest.mark.unit
    def test_invalid_override_value_raises_input_validation_error(self, tmp_path):
        """An invalid override value (here a bad ``streaming`` value) must
        surface at load time as InputValidationError rather than a generic
        SetupError, so the user gets an actionable message."""
        perf_path, acc_path = self._write_fixture(tmp_path)
        bad_override = {"streaming": "garbage"}
        config = self._build_config(perf_path, acc_path, acc_override=bad_override)
        with pytest.raises(
            InputValidationError, match="invalid generation_config_override"
        ):
            _load_datasets(config, tmp_path)


class TestLoadDatasetsGenerationConfigOverrideChat(_OverrideTestBase):
    """Run the shared override cases through the OpenAI **chat** completions
    adapter, whose rows carry the budget as ``max_completion_tokens``."""

    max_tokens_key = "max_completion_tokens"
    api_type = "openai"


class TestLoadDatasetsGenerationConfigOverrideCompletions(_OverrideTestBase):
    """Run the shared override cases through the OpenAI **text** completions
    adapter (``/v1/completions``), whose rows carry the budget as
    ``max_tokens``.

    This path is the headline target of PR #344: MLPerf-style runs use
    ``api_type: openai_completions`` with pre-tokenized inputs, so it needs
    its own integration coverage. The fixture rows already contain
    ``input_tokens``, so the adapter's ``Harmonize()`` transform early-exits
    and no HF tokenizer download is required."""

    max_tokens_key = "max_tokens"
    api_type = "openai_completions"
Loading
Loading