diff --git a/src/inference_endpoint/commands/benchmark/execute.py b/src/inference_endpoint/commands/benchmark/execute.py index e3c5505b9..8b9ee0499 100644 --- a/src/inference_endpoint/commands/benchmark/execute.py +++ b/src/inference_endpoint/commands/benchmark/execute.py @@ -42,6 +42,7 @@ import msgspec import msgspec.json from huggingface_hub import model_info +from pydantic import ValidationError from tqdm import tqdm from transformers import AutoTokenizer from transformers.utils import logging as transformers_logging @@ -288,9 +289,13 @@ def _load_datasets( acc_cfg.accuracy_config.extras or {}, ) ) - ds.load( - api_type=config.endpoint_config.api_type, model_params=config.model_params - ) + try: + ds_model_params = acc_cfg.effective_generation_config(config.model_params) + except (ValidationError, ValueError) as e: + raise InputValidationError( + f"Dataset '{acc_cfg.name}': invalid generation_config_override: {e}" + ) from e + ds.load(api_type=config.endpoint_config.api_type, model_params=ds_model_params) logger.info(f"Loaded {ds} - {ds.num_samples()} samples") if not accuracy_cfgs: @@ -298,16 +303,21 @@ def _load_datasets( if len(performance_cfgs) > 1: raise InputValidationError("Multiple performance datasets not supported") + perf_cfg = performance_cfgs[0] + try: + perf_model_params = perf_cfg.effective_generation_config(config.model_params) + except (ValidationError, ValueError) as e: + raise InputValidationError( + f"Dataset '{perf_cfg.name}': invalid generation_config_override: {e}" + ) from e try: - dataloader = DataLoaderFactory.create_loader(performance_cfgs[0]) + dataloader = DataLoaderFactory.create_loader(perf_cfg) dataloader.load( - api_type=config.endpoint_config.api_type, model_params=config.model_params + api_type=config.endpoint_config.api_type, model_params=perf_model_params ) logger.info(f"Loaded {dataloader.num_samples()} samples") except FileNotFoundError as e: - raise InputValidationError( - f"Dataset file not found: 
{performance_cfgs[0].path}" - ) from e + raise InputValidationError(f"Dataset file not found: {perf_cfg.path}") from e except Exception as e: raise SetupError(f"Failed to load dataset: {e}") from e diff --git a/src/inference_endpoint/config/schema.py b/src/inference_endpoint/config/schema.py index 2cfa35d73..6fc6d3afc 100644 --- a/src/inference_endpoint/config/schema.py +++ b/src/inference_endpoint/config/schema.py @@ -54,6 +54,23 @@ class SystemDefaults(BaseModel): DEFAULT_METRIC: ClassVar[metrics.Metric] = metrics.Throughput(0.0) +def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: + """Recursively merge ``override`` into ``base`` and return the result. + + For overlapping keys whose values are both dicts, recurse; otherwise the + override value wins. Mutates a *copy* — callers can safely pass model_dump() + output. Used by ``Dataset.effective_generation_config`` so a sparse nested + override (e.g. ``{osl_distribution: {max: 512}}``) preserves siblings. + """ + out = dict(base) + for k, v in override.items(): + if isinstance(v, dict) and isinstance(out.get(k), dict): + out[k] = _deep_merge(out[k], v) + else: + out[k] = v + return out + + class LoadPatternType(str, Enum): """Load pattern types.""" @@ -313,6 +330,38 @@ class Dataset(BaseModel): multi_turn: MultiTurnConfig | None = Field( None, description="Multi-turn conversation configuration" ) + # TODO(post-mortem): generation config is per-phase (perf vs. accuracy), + # not per-dataset — phases are derived from datasets and the override is + # keyed to dataset identity. This lives on Dataset as a short-term workaround + # so MLPerf-style accuracy + perf can share one fleet. The proper fix is + # a first-class GenerationConfig carried on PhaseConfig, decoupled from + # the dataset entry. Field/method names use "generation_config" to keep + # the eventual migration mechanical.
+ # + # Caveats on per-dataset overrides today: + # - `name` flows into the request `model` field but the tokenizer and + # aggregator are launched from the global `model_params.name`, so a + # per-dataset rename mismatches ISL/OSL accounting. + # - `streaming` flows into the request but the single MetricsAggregator + # is launched with the global `model_params.streaming` flag, so a + # per-dataset streaming flip will not produce TTFT/TPOT for that + # phase. Keep streaming on `model_params` (per-run) for now. + # - Nested dicts (`osl_distribution`, `chat_template_kwargs`) are + # deep-merged so sparse overrides preserve sibling defaults. + generation_config_override: dict[str, Any] | None = Field( + None, + description=( + "Per-dataset overrides for the top-level model_params (sparse — " + "only the fields you want to override). Merged on top of " + "BenchmarkConfig.model_params at dataset-load time. Useful for " + "MLPerf-style runs where accuracy and performance use different " + "output budgets in the same fleet, e.g. " + "generation_config_override: {max_new_tokens: 32768, " + "temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are " + "accepted (kwargs-style) but not honored by the single-aggregator " + "metrics path — set those on top-level model_params." + ), + ) @model_validator(mode="after") def _auto_derive_name(self) -> Self: @@ -321,6 +370,40 @@ def _auto_derive_name(self) -> Self: object.__setattr__(self, "name", Path(self.path).stem) return self + @model_validator(mode="after") + def _validate_generation_config_override(self) -> Self: + """Fail fast on unknown keys; values are validated at merge time + (see ``effective_generation_config``) because cross-field validation + needs the base ``ModelParams`` from ``BenchmarkConfig``. 
+ """ + if self.generation_config_override: + valid = set(ModelParams.model_fields) + bad = sorted(set(self.generation_config_override) - valid) + if bad: + raise ValueError( + f"Dataset '{self.name}': unknown keys in " + f"generation_config_override: {bad}. " + f"Valid keys: {sorted(valid)}" + ) + return self + + def effective_generation_config(self, base: ModelParams) -> ModelParams: + """Return base merged with this dataset's generation-config overrides. + + Nested dicts are deep-merged so a sparse nested override preserves + sibling defaults (e.g. ``{osl_distribution: {max: 512}}`` keeps the + base ``type/mean/std/min``). The merged dict is re-validated through + ``ModelParams.model_validate`` so type-invalid scalar overrides (e.g. + ``temperature: 'hot'``) are rejected. Note that this only catches + scalar invalidity — a sparse nested override whose merged result + passes default-validation will not raise (callers that need stricter + nested validation should set ``base`` to an explicit instance). + """ + if not self.generation_config_override: + return base + merged = _deep_merge(base.model_dump(), self.generation_config_override) + return ModelParams.model_validate(merged) + class AccuracyConfig(BaseModel): """Accuracy configuration. diff --git a/src/inference_endpoint/config/templates/concurrency_template_full.yaml b/src/inference_endpoint/config/templates/concurrency_template_full.yaml index 4fef4afcb..4308a860f 100644 --- a/src/inference_endpoint/config/templates/concurrency_template_full.yaml +++ b/src/inference_endpoint/config/templates/concurrency_template_full.yaml @@ -27,6 +27,7 @@ datasets: # Dataset configs prompt: text_input accuracy_config: null # Accuracy evaluation settings multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. 
Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -42,6 +43,7 @@ datasets: # Dataset configs extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. 
settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) diff --git a/src/inference_endpoint/config/templates/offline_template_full.yaml b/src/inference_endpoint/config/templates/offline_template_full.yaml index 1f61837fe..ad307f7db 100644 --- a/src/inference_endpoint/config/templates/offline_template_full.yaml +++ b/src/inference_endpoint/config/templates/offline_template_full.yaml @@ -27,6 +27,7 @@ datasets: # Dataset configs prompt: text_input accuracy_config: null # Accuracy evaluation settings multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -42,6 +43,7 @@ datasets: # Dataset configs extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. 
NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) diff --git a/src/inference_endpoint/config/templates/online_template_full.yaml b/src/inference_endpoint/config/templates/online_template_full.yaml index a212fa95b..2230eaa7a 100644 --- a/src/inference_endpoint/config/templates/online_template_full.yaml +++ b/src/inference_endpoint/config/templates/online_template_full.yaml @@ -27,6 +27,7 @@ datasets: # Dataset configs prompt: text_input accuracy_config: null # Accuracy evaluation settings multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. - name: accuracy type: accuracy # Dataset purpose: performance or accuracy | options: performance, accuracy path: '' # Dataset file path @@ -42,6 +43,7 @@ datasets: # Dataset configs extractor: boxed_math_extractor # Answer extractor (abcd_extractor, boxed_math_extractor, identity_extractor, python_code_extractor) num_repeats: 1 # Repeat dataset N times for evaluation multi_turn: null # Multi-turn conversation configuration + generation_config_override: null # Per-dataset overrides for the top-level model_params (sparse — only the fields you want to override). Merged on top of BenchmarkConfig.model_params at dataset-load time. 
Useful for MLPerf-style runs where accuracy and performance use different output budgets in the same fleet, e.g. generation_config_override: {max_new_tokens: 32768, temperature: 0.0}. NOTE: per-dataset `streaming` and `name` are accepted (kwargs-style) but not honored by the single-aggregator metrics path — set those on top-level model_params. settings: runtime: min_duration_ms: 600000 # Min duration (ms, or with suffix: 600s, 10m) diff --git a/tests/unit/commands/test_benchmark.py b/tests/unit/commands/test_benchmark.py index 1c90554fb..83c4fd9d5 100644 --- a/tests/unit/commands/test_benchmark.py +++ b/tests/unit/commands/test_benchmark.py @@ -16,6 +16,7 @@ """Tests for benchmark CLI models, config building, and command handlers.""" import asyncio +import json import random import tempfile from pathlib import Path @@ -34,6 +35,7 @@ BenchmarkContext, ResponseCollector, _build_phases, + _load_datasets, _run_benchmark_async, setup_benchmark, ) @@ -1277,3 +1279,149 @@ def test_no_override_yields_none_when_model_has_no_tokenizer( ctx = setup_benchmark(config, TestMode.PERF) assert ctx.tokenizer_name is None + + +class _OverrideTestBase: + """Shared helpers for the two end-to-end ``_load_datasets`` override classes + below (parametrized over the chat vs text-completions adapter).""" + + # Subclasses set these: + api_type: str = "" + max_tokens_key: str = "" # static column name AddStaticColumns adds + + def _write_jsonl(self, path: Path, rows: list[dict]) -> None: + path.write_text("\n".join(json.dumps(r) for r in rows) + "\n") + + def _prompt_rows(self, prompt: str, ground_truth: str | None = None) -> list[dict]: + """Adapter-shaped row. 
Chat adapter wants a 'prompt' column; the + completions adapter wants pre-tokenized 'input_tokens' (so the + Harmonize transform early-exits and we avoid the HF tokenizer + dependency in unit tests).""" + row: dict = {"prompt": prompt} + if self.api_type == "openai_completions": + row = {"input_tokens": [1, 2, 3, 4]} + if ground_truth is not None: + row["ground_truth"] = ground_truth + return [row] + + def _build_config( + self, + perf_path: Path, + acc_path: Path, + acc_override: dict | None, + perf_override: dict | None = None, + ) -> BenchmarkConfig: + return BenchmarkConfig( + type=TestType.OFFLINE, + model_params={"name": "test-model", "max_new_tokens": 1024}, + endpoint_config={ + "endpoints": ["http://localhost:8000"], + "api_type": self.api_type, + }, + datasets=[ + { + "name": "perf", + "type": "performance", + "path": str(perf_path), + **( + {"generation_config_override": perf_override} + if perf_override + else {} + ), + }, + { + "name": "acc", + "type": "accuracy", + "path": str(acc_path), + "accuracy_config": { + "eval_method": "pass_at_1", + "ground_truth": "ground_truth", + "extractor": "boxed_math_extractor", + }, + **( + {"generation_config_override": acc_override} + if acc_override + else {} + ), + }, + ], + ) + + def _write_fixture(self, tmp_path: Path) -> tuple[Path, Path]: + perf_path = tmp_path / "perf.jsonl" + acc_path = tmp_path / "acc.jsonl" + self._write_jsonl(perf_path, self._prompt_rows("perf-prompt")) + self._write_jsonl(acc_path, self._prompt_rows("acc-prompt", ground_truth="42")) + return perf_path, acc_path + + @pytest.mark.unit + def test_override_propagates_to_loaded_rows(self, tmp_path): + """Override on accuracy dataset → its rows get the overridden value; + unmodified perf dataset keeps the global 1024.""" + perf_path, acc_path = self._write_fixture(tmp_path) + config = self._build_config( + perf_path, acc_path, acc_override={"max_new_tokens": 32768} + ) + perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path) + assert 
perf_ds.load_sample(0)[self.max_tokens_key] == 1024 + assert acc_datasets[0].load_sample(0)[self.max_tokens_key] == 32768 + + @pytest.mark.unit + def test_no_override_inherits_global(self, tmp_path): + """Without overrides, both datasets use the global model_params.""" + perf_path, acc_path = self._write_fixture(tmp_path) + config = self._build_config(perf_path, acc_path, acc_override=None) + perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path) + assert perf_ds.load_sample(0)[self.max_tokens_key] == 1024 + assert acc_datasets[0].load_sample(0)[self.max_tokens_key] == 1024 + + @pytest.mark.unit + def test_perf_dataset_override_also_honored(self, tmp_path): + """Symmetric check: overrides on the performance entry also flow + through (relevant for MLPerf-style perf with shorter max_new_tokens).""" + perf_path, acc_path = self._write_fixture(tmp_path) + config = self._build_config( + perf_path, + acc_path, + acc_override={"max_new_tokens": 32768}, + perf_override={"max_new_tokens": 10240}, + ) + perf_ds, acc_datasets, _ = _load_datasets(config, tmp_path) + assert perf_ds.load_sample(0)[self.max_tokens_key] == 10240 + assert acc_datasets[0].load_sample(0)[self.max_tokens_key] == 32768 + + @pytest.mark.unit + def test_invalid_override_value_raises_input_validation_error(self, tmp_path): + """A value-level invalidity (e.g. 
bad streaming enum) is caught at + load time and surfaces as InputValidationError, not a generic + SetupError, so the user sees a clear actionable message.""" + perf_path, acc_path = self._write_fixture(tmp_path) + config = self._build_config( + perf_path, acc_path, acc_override={"streaming": "garbage"} + ) + with pytest.raises( + InputValidationError, match="invalid generation_config_override" + ): + _load_datasets(config, tmp_path) + + +class TestLoadDatasetsGenerationConfigOverrideChat(_OverrideTestBase): + """End-to-end ``_load_datasets`` check against the OpenAI **chat** + completions adapter, which emits ``max_completion_tokens``.""" + + api_type = "openai" + max_tokens_key = "max_completion_tokens" + + +class TestLoadDatasetsGenerationConfigOverrideCompletions(_OverrideTestBase): + """End-to-end ``_load_datasets`` check against the OpenAI **text** + completions adapter (``/v1/completions``), which emits ``max_tokens``. + + This is the headline target of PR #344 — MLPerf-style runs use + ``api_type: openai_completions`` for pre-tokenized inputs — so an + integration test on this code path is essential. 
Rows carry pre-baked + ``input_tokens`` so the adapter's ``Harmonize()`` transform early-exits + and the test stays free of HF tokenizer downloads.""" + + api_type = "openai_completions" + max_tokens_key = "max_tokens" diff --git a/tests/unit/config/test_schema.py b/tests/unit/config/test_schema.py index a60770121..0c99565c1 100644 --- a/tests/unit/config/test_schema.py +++ b/tests/unit/config/test_schema.py @@ -123,6 +123,113 @@ def test_auto_derive_name(self): ds = Dataset(path="datasets/my_data.jsonl") assert ds.name == "my_data" + @pytest.mark.unit + def test_generation_config_override_accepts_known_keys(self): + ds = Dataset( + name="acc", + type=DatasetType.ACCURACY, + path="acc.jsonl", + generation_config_override={"max_new_tokens": 32768, "temperature": 0.0}, + ) + assert ds.generation_config_override == { + "max_new_tokens": 32768, + "temperature": 0.0, + } + + @pytest.mark.unit + def test_generation_config_override_rejects_unknown_key(self): + with pytest.raises( + ValueError, match=r"unknown keys in generation_config_override.*bogus" + ): + Dataset( + name="acc", + path="a.jsonl", + generation_config_override={"bogus": 1}, + ) + + @pytest.mark.unit + def test_generation_config_override_none_is_noop(self): + base = ModelParams(name="m", max_new_tokens=1024, streaming=StreamingMode.ON) + ds = Dataset(name="x", path="x.jsonl") + assert ds.effective_generation_config(base) is base + + @pytest.mark.unit + def test_effective_generation_config_merges_sparse_dict(self): + base = ModelParams(name="m", temperature=0.5, top_p=0.9, max_new_tokens=1024) + ds = Dataset( + name="x", + path="x.jsonl", + generation_config_override={"max_new_tokens": 32768}, + ) + merged = ds.effective_generation_config(base) + # overridden field changes... 
+ assert merged.max_new_tokens == 32768 + # ...everything else is preserved from base + assert merged.name == "m" + assert merged.temperature == 0.5 + assert merged.top_p == 0.9 + + @pytest.mark.unit + def test_effective_generation_config_validates_value(self): + """ModelParams.model_validate is invoked on the merged dict, so a + type-invalid override is rejected (e.g. wrong type for streaming).""" + base = ModelParams(name="m") + ds = Dataset( + name="x", + path="x.jsonl", + generation_config_override={"streaming": "garbage"}, + ) + with pytest.raises(ValueError): + ds.effective_generation_config(base) + + @pytest.mark.unit + def test_effective_generation_config_deep_merges_nested_dict(self): + """Sparse overrides of nested fields (osl_distribution, + chat_template_kwargs) preserve sibling defaults from the base rather + than wholesale-replacing the nested object. Pins the deep-merge + behavior added in response to PR review feedback. + """ + base = ModelParams( + name="m", + osl_distribution=OSLDistribution( + type=OSLDistributionType.NORMAL, mean=1000, std=200, min=512, max=2048 + ), + ) + ds = Dataset( + name="x", + path="x.jsonl", + generation_config_override={"osl_distribution": {"max": 512}}, + ) + merged = ds.effective_generation_config(base) + # the explicitly overridden nested field changes... + assert merged.osl_distribution.max == 512 + # ...and the unspecified siblings are preserved from base + assert merged.osl_distribution.type == OSLDistributionType.NORMAL + assert merged.osl_distribution.mean == 1000 + assert merged.osl_distribution.std == 200 + assert merged.osl_distribution.min == 512 + + @pytest.mark.unit + def test_effective_generation_config_deep_merges_chat_template_kwargs(self): + """Deep-merge also applies to free-form nested dicts like + chat_template_kwargs; sparse overrides preserve sibling entries. 
+ """ + base = ModelParams( + name="m", chat_template_kwargs={"enable_thinking": True, "tools": []} + ) + ds = Dataset( + name="x", + path="x.jsonl", + generation_config_override={ + "chat_template_kwargs": {"enable_thinking": False} + }, + ) + merged = ds.effective_generation_config(base) + assert merged.chat_template_kwargs == { + "enable_thinking": False, + "tools": [], + } + class TestBenchmarkConfig: @pytest.mark.unit