-
Notifications
You must be signed in to change notification settings - Fork 22
feat: add truncate-results command #354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/v0.5
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,126 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # SPDX-License-Identifier: Apache-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||||||||||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| """Truncate a benchmark ``results.json``. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Perf+accuracy runs store every query's full response text under | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ``responses``, which can reach gigabytes. ``truncate-results`` keeps the | ||||||||||||||||||||||||||||||||||||||||||||||||||
| first ``keep_n`` responses verbatim and replaces the rest with a per-sample | ||||||||||||||||||||||||||||||||||||||||||||||||||
| content hash, so the file stays small while still proving which outputs were | ||||||||||||||||||||||||||||||||||||||||||||||||||
| produced (proof of work). | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from __future__ import annotations | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import hashlib | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import json | ||||||||||||||||||||||||||||||||||||||||||||||||||
| import logging | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from pathlib import Path | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Annotated, Any | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| import cyclopts | ||||||||||||||||||||||||||||||||||||||||||||||||||
| from pydantic import BaseModel, ConfigDict, Field | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| from inference_endpoint.exceptions import InputValidationError | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| _HASH_ALGORITHM = "sha256" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def truncate_results_dict(results: dict[str, Any], keep_n: int = 5) -> dict[str, Any]: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Return a truncated copy of a ``results.json`` dict. | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| Keeps ``config``/``results``/``accuracy_scores``/``errors`` verbatim, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| keeps the first ``keep_n`` ``responses`` full, and adds a ``truncation`` | ||||||||||||||||||||||||||||||||||||||||||||||||||
| block holding a ``sha256`` hash of every response plus counts. A dict | ||||||||||||||||||||||||||||||||||||||||||||||||||
| without a non-empty ``responses`` section (e.g. a perf-only run) is | ||||||||||||||||||||||||||||||||||||||||||||||||||
| returned unchanged. | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| responses = results.get("responses") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not responses: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return dict(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| uuids = list(responses.keys()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| kept = uuids[:keep_n] | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| out = dict(results) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out["responses"] = {uuid: responses[uuid] for uuid in kept} | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out["truncation"] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "responses_truncated": True, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "hash_algorithm": _HASH_ALGORITHM, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "n_responses_total": len(uuids), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "n_responses_kept": len(kept), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "response_hashes": { | ||||||||||||||||||||||||||||||||||||||||||||||||||
| uuid: hashlib.sha256(str(text).encode("utf-8")).hexdigest() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| for uuid, text in responses.items() | ||||||||||||||||||||||||||||||||||||||||||||||||||
| }, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||
| return out | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| @cyclopts.Parameter(name="*") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| class TruncateConfig(BaseModel): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """truncate-results command config.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| results: Path | ||||||||||||||||||||||||||||||||||||||||||||||||||
| keep_n: Annotated[ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| int, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cyclopts.Parameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| alias="--keep-n", help="Number of full responses to keep verbatim" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ] = Field(5, ge=0) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| output: Annotated[ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Path | None, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cyclopts.Parameter( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| alias="--output", help="Output path (default: *.truncated.json)" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ] = None | ||||||||||||||||||||||||||||||||||||||||||||||||||
| in_place: Annotated[ | ||||||||||||||||||||||||||||||||||||||||||||||||||
| bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| cyclopts.Parameter(alias="--in-place", help="Overwrite the input file"), | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ] = False | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| def execute_truncate(config: TruncateConfig) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| """Read ``config.results``, truncate it, and write the result.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if not config.results.exists(): | ||||||||||||||||||||||||||||||||||||||||||||||||||
| raise InputValidationError(f"Results file not found: {config.results}") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| data = json.loads(config.results.read_text()) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| truncated = truncate_results_dict(data, keep_n=config.keep_n) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| if config.in_place: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out_path = config.results | ||||||||||||||||||||||||||||||||||||||||||||||||||
| elif config.output is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out_path = config.output | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out_path = config.results.with_name(config.results.stem + ".truncated.json") | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| out_path.write_text(json.dumps(truncated, indent=2)) | ||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+104
to
+114
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since this command is specifically designed to handle potentially gigabyte-sized Using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||
| meta = truncated.get("truncation") | ||||||||||||||||||||||||||||||||||||||||||||||||||
| if meta is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("No responses to truncate; wrote passthrough copy to %s", out_path) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info( | ||||||||||||||||||||||||||||||||||||||||||||||||||
| "Truncated %d responses to %d full + %d hashes; wrote %s", | ||||||||||||||||||||||||||||||||||||||||||||||||||
| meta["n_responses_total"], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| meta["n_responses_kept"], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| meta["n_responses_total"], | ||||||||||||||||||||||||||||||||||||||||||||||||||
| out_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,113 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Unit tests for the truncate-results command.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import hashlib | ||
| import json | ||
|
|
||
| import pytest | ||
| from inference_endpoint.commands.truncate_results import ( | ||
| TruncateConfig, | ||
| execute_truncate, | ||
| truncate_results_dict, | ||
| ) | ||
|
|
||
|
|
||
| def _results(n: int) -> dict: | ||
| return { | ||
| "config": {"mode": "both"}, | ||
| "results": {"total": n, "successful": n, "qps": float(n)}, | ||
| "accuracy_scores": {"ds": {"score": 0.9}}, | ||
| "responses": {f"u{i}": f"response {i}" for i in range(n)}, | ||
| "errors": ["Sample u-err: boom"], | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_keeps_first_n_full_and_hashes_every_response(): | ||
| src = _results(5) | ||
| out = truncate_results_dict(src, keep_n=2) | ||
|
|
||
| # First N kept verbatim, the rest dropped from `responses`. | ||
| assert out["responses"] == {"u0": "response 0", "u1": "response 1"} | ||
| # Every original response is provably accounted for via its sha256. | ||
| assert out["truncation"]["response_hashes"] == { | ||
| uid: hashlib.sha256(text.encode()).hexdigest() | ||
| for uid, text in src["responses"].items() | ||
| } | ||
| assert out["truncation"] == { | ||
| "responses_truncated": True, | ||
| "hash_algorithm": "sha256", | ||
| "n_responses_total": 5, | ||
| "n_responses_kept": 2, | ||
| "response_hashes": out["truncation"]["response_hashes"], | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_preserves_non_response_sections(): | ||
| src = _results(5) | ||
| out = truncate_results_dict(src, keep_n=2) | ||
| for key in ("config", "results", "accuracy_scores", "errors"): | ||
| assert out[key] == src[key] | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_does_not_mutate_input(): | ||
| src = _results(5) | ||
| truncate_results_dict(src, keep_n=2) | ||
| assert len(src["responses"]) == 5 | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_keep_n_exceeding_total_keeps_all(): | ||
| out = truncate_results_dict(_results(3), keep_n=10) | ||
| assert len(out["responses"]) == 3 | ||
| assert out["truncation"]["n_responses_kept"] == 3 | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_passthrough_when_no_responses(): | ||
| perf_only = {"config": {"mode": "offline"}, "results": {"qps": 50.0}} | ||
| out = truncate_results_dict(perf_only, keep_n=5) | ||
| assert out == perf_only | ||
| assert "truncation" not in out | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_execute_writes_truncated_copy_leaving_original(tmp_path): | ||
| src = tmp_path / "results.json" | ||
| src.write_text(json.dumps(_results(4))) | ||
|
|
||
| execute_truncate(TruncateConfig(results=src, keep_n=1)) | ||
|
|
||
| out = json.loads((tmp_path / "results.truncated.json").read_text()) | ||
| assert len(out["responses"]) == 1 | ||
| assert out["truncation"]["n_responses_total"] == 4 | ||
| assert len(json.loads(src.read_text())["responses"]) == 4 # original intact | ||
|
|
||
|
|
||
| @pytest.mark.unit | ||
| def test_execute_in_place(tmp_path): | ||
| src = tmp_path / "results.json" | ||
| src.write_text(json.dumps(_results(4))) | ||
|
|
||
| execute_truncate(TruncateConfig(results=src, keep_n=1, in_place=True)) | ||
|
|
||
| assert len(json.loads(src.read_text())["responses"]) == 1 | ||
| assert not (tmp_path / "results.truncated.json").exists() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To make
truncate_results_dictmore robust and prevent potential runtime exceptions:responsesis actually a dictionary before calling.keys()or.items(). If it's of an unexpected type (e.g., a list or string due to malformed input), calling.keys()would raise anAttributeError.keep_n. Ifkeep_nis negative, Python's slice notationuuids[:keep_n]will slice from the end of the list (e.g.,keep_n = -1would keep all but the last response), which is likely unintended.