Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ uv run inference-endpoint probe --endpoints http://localhost:8765 --model test-m
uv run inference-endpoint benchmark offline --endpoints URL --model NAME --dataset PATH
uv run inference-endpoint benchmark online --endpoints URL --model NAME --dataset PATH --load-pattern poisson --target-qps 100
uv run inference-endpoint benchmark from-config --config config.yaml

# Shrink a large accuracy results.json (keep a few full responses + a hash of every response)
uv run inference-endpoint truncate-results results.json --keep-n 5
```

### Backward-compatible setup (pip + venv)
Expand Down Expand Up @@ -60,6 +63,9 @@ inference-endpoint probe --endpoints http://localhost:8765 --model test-model
inference-endpoint benchmark offline --endpoints URL --model NAME --dataset PATH
inference-endpoint benchmark online --endpoints URL --model NAME --dataset PATH --load-pattern poisson --target-qps 100
inference-endpoint benchmark from-config --config config.yaml

# Shrink a large accuracy results.json (keep a few full responses + a hash of every response)
inference-endpoint truncate-results results.json --keep-n 5
```

## Architecture
Expand Down Expand Up @@ -168,6 +174,7 @@ src/inference_endpoint/
│ ├── probe.py # ProbeConfig + execute_probe()
│ ├── info.py # execute_info()
│ ├── validate.py # execute_validate()
│ ├── truncate_results.py # TruncateConfig + execute_truncate() — shrink results.json (keep N full + hash rest)
│ └── init.py # execute_init()
├── core/
│ ├── types.py # APIType, Query, QueryResult, StreamChunk, QueryStatus (msgspec Structs)
Expand Down
126 changes: 126 additions & 0 deletions src/inference_endpoint/commands/truncate_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Truncate a benchmark ``results.json``.

Perf+accuracy runs store every query's full response text under
``responses``, which can reach gigabytes. ``truncate-results`` keeps the
first ``keep_n`` responses verbatim and replaces the rest with a per-sample
content hash, so the file stays small while still proving which outputs were
produced (proof of work).
"""

from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path
from typing import Annotated, Any

import cyclopts
from pydantic import BaseModel, ConfigDict, Field

from inference_endpoint.exceptions import InputValidationError

logger = logging.getLogger(__name__)

_HASH_ALGORITHM = "sha256"


def truncate_results_dict(results: dict[str, Any], keep_n: int = 5) -> dict[str, Any]:
"""Return a truncated copy of a ``results.json`` dict.

Keeps ``config``/``results``/``accuracy_scores``/``errors`` verbatim,
keeps the first ``keep_n`` ``responses`` full, and adds a ``truncation``
block holding a ``sha256`` hash of every response plus counts. A dict
without a non-empty ``responses`` section (e.g. a perf-only run) is
returned unchanged.
"""
responses = results.get("responses")
if not responses:
return dict(results)

uuids = list(responses.keys())
kept = uuids[:keep_n]
Comment on lines +52 to +57

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To make truncate_results_dict more robust and prevent potential runtime exceptions:

  1. Ensure responses is actually a dictionary before calling .keys() or .items(). If it's of an unexpected type (e.g., a list or string due to malformed input), calling .keys() would raise an AttributeError.
  2. Guard against negative values of keep_n. If keep_n is negative, Python's slice notation uuids[:keep_n] will slice from the end of the list (e.g., keep_n = -1 would keep all but the last response), which is likely unintended.
Suggested change
responses = results.get("responses")
if not responses:
return dict(results)
uuids = list(responses.keys())
kept = uuids[:keep_n]
responses = results.get("responses")
if not isinstance(responses, dict) or not responses:
return dict(results)
uuids = list(responses.keys())
kept = uuids[:max(0, keep_n)]


out = dict(results)
out["responses"] = {uuid: responses[uuid] for uuid in kept}
out["truncation"] = {
"responses_truncated": True,
"hash_algorithm": _HASH_ALGORITHM,
"n_responses_total": len(uuids),
"n_responses_kept": len(kept),
"response_hashes": {
uuid: hashlib.sha256(str(text).encode("utf-8")).hexdigest()
for uuid, text in responses.items()
},
}
return out


@cyclopts.Parameter(name="*")
class TruncateConfig(BaseModel):
"""truncate-results command config."""

model_config = ConfigDict(extra="forbid", frozen=True, str_strip_whitespace=True)

results: Path
keep_n: Annotated[
int,
cyclopts.Parameter(
alias="--keep-n", help="Number of full responses to keep verbatim"
),
] = Field(5, ge=0)
output: Annotated[
Path | None,
cyclopts.Parameter(
alias="--output", help="Output path (default: *.truncated.json)"
),
] = None
in_place: Annotated[
bool,
cyclopts.Parameter(alias="--in-place", help="Overwrite the input file"),
] = False


def execute_truncate(config: TruncateConfig) -> None:
"""Read ``config.results``, truncate it, and write the result."""
if not config.results.exists():
raise InputValidationError(f"Results file not found: {config.results}")

data = json.loads(config.results.read_text())
truncated = truncate_results_dict(data, keep_n=config.keep_n)

if config.in_place:
out_path = config.results
elif config.output is not None:
out_path = config.output
else:
out_path = config.results.with_name(config.results.stem + ".truncated.json")

out_path.write_text(json.dumps(truncated, indent=2))
Comment on lines +104 to +114

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Since this command is specifically designed to handle potentially gigabyte-sized results.json files, reading the entire file into memory as a string with read_text() and then parsing/serializing with json.loads()/json.dumps() can lead to high memory usage or Out-Of-Memory (OOM) errors.

Using json.load() and json.dump() with file objects avoids loading the entire file as a single Python string, significantly reducing the memory footprint.

Suggested change
data = json.loads(config.results.read_text())
truncated = truncate_results_dict(data, keep_n=config.keep_n)
if config.in_place:
out_path = config.results
elif config.output is not None:
out_path = config.output
else:
out_path = config.results.with_name(config.results.stem + ".truncated.json")
out_path.write_text(json.dumps(truncated, indent=2))
with open(config.results, "r", encoding="utf-8") as f:
data = json.load(f)
truncated = truncate_results_dict(data, keep_n=config.keep_n)
if config.in_place:
out_path = config.results
elif config.output is not None:
out_path = config.output
else:
out_path = config.results.with_name(config.results.stem + ".truncated.json")
with open(out_path, "w", encoding="utf-8") as f:
json.dump(truncated, f, indent=2)


meta = truncated.get("truncation")
if meta is None:
logger.info("No responses to truncate; wrote passthrough copy to %s", out_path)
else:
logger.info(
"Truncated %d responses to %d full + %d hashes; wrote %s",
meta["n_responses_total"],
meta["n_responses_kept"],
meta["n_responses_total"],
out_path,
)
10 changes: 10 additions & 0 deletions src/inference_endpoint/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@
from inference_endpoint.commands.info import execute_info
from inference_endpoint.commands.init import execute_init
from inference_endpoint.commands.probe import ProbeConfig, execute_probe
from inference_endpoint.commands.truncate_results import (
TruncateConfig,
execute_truncate,
)
from inference_endpoint.commands.validate import execute_validate
from inference_endpoint.config.utils import cli_error_formatter
from inference_endpoint.exceptions import (
Expand Down Expand Up @@ -86,6 +90,12 @@ def probe(*, config: ProbeConfig):
execute_probe(config)


@app.command(name="truncate-results")
def truncate_results(*, config: TruncateConfig):
"""Shrink a results.json: keep a few full responses + hash the rest."""
execute_truncate(config)


@app.command
def info():
"""Show system information."""
Expand Down
113 changes: 113 additions & 0 deletions tests/unit/commands/test_truncate_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for the truncate-results command."""

from __future__ import annotations

import hashlib
import json

import pytest
from inference_endpoint.commands.truncate_results import (
TruncateConfig,
execute_truncate,
truncate_results_dict,
)


def _results(n: int) -> dict:
return {
"config": {"mode": "both"},
"results": {"total": n, "successful": n, "qps": float(n)},
"accuracy_scores": {"ds": {"score": 0.9}},
"responses": {f"u{i}": f"response {i}" for i in range(n)},
"errors": ["Sample u-err: boom"],
}


@pytest.mark.unit
def test_keeps_first_n_full_and_hashes_every_response():
src = _results(5)
out = truncate_results_dict(src, keep_n=2)

# First N kept verbatim, the rest dropped from `responses`.
assert out["responses"] == {"u0": "response 0", "u1": "response 1"}
# Every original response is provably accounted for via its sha256.
assert out["truncation"]["response_hashes"] == {
uid: hashlib.sha256(text.encode()).hexdigest()
for uid, text in src["responses"].items()
}
assert out["truncation"] == {
"responses_truncated": True,
"hash_algorithm": "sha256",
"n_responses_total": 5,
"n_responses_kept": 2,
"response_hashes": out["truncation"]["response_hashes"],
}


@pytest.mark.unit
def test_preserves_non_response_sections():
src = _results(5)
out = truncate_results_dict(src, keep_n=2)
for key in ("config", "results", "accuracy_scores", "errors"):
assert out[key] == src[key]


@pytest.mark.unit
def test_does_not_mutate_input():
src = _results(5)
truncate_results_dict(src, keep_n=2)
assert len(src["responses"]) == 5


@pytest.mark.unit
def test_keep_n_exceeding_total_keeps_all():
out = truncate_results_dict(_results(3), keep_n=10)
assert len(out["responses"]) == 3
assert out["truncation"]["n_responses_kept"] == 3


@pytest.mark.unit
def test_passthrough_when_no_responses():
perf_only = {"config": {"mode": "offline"}, "results": {"qps": 50.0}}
out = truncate_results_dict(perf_only, keep_n=5)
assert out == perf_only
assert "truncation" not in out


@pytest.mark.unit
def test_execute_writes_truncated_copy_leaving_original(tmp_path):
src = tmp_path / "results.json"
src.write_text(json.dumps(_results(4)))

execute_truncate(TruncateConfig(results=src, keep_n=1))

out = json.loads((tmp_path / "results.truncated.json").read_text())
assert len(out["responses"]) == 1
assert out["truncation"]["n_responses_total"] == 4
assert len(json.loads(src.read_text())["responses"]) == 4 # original intact


@pytest.mark.unit
def test_execute_in_place(tmp_path):
src = tmp_path / "results.json"
src.write_text(json.dumps(_results(4)))

execute_truncate(TruncateConfig(results=src, keep_n=1, in_place=True))

assert len(json.loads(src.read_text())["responses"]) == 1
assert not (tmp_path / "results.truncated.json").exists()
Loading