Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,12 @@ known-first-party = ["submission_checker"]
"SIM117", # deeply-nested with-patch blocks are intentional for clarity
"D", # docstrings not required in test files
]

[tool.mutmut]
also_copy = ["test_submissions/"]

[dependency-groups]
dev = [
"hypothesis>=6.155.3",
"mutmut>=3.6.0",
]
111 changes: 111 additions & 0 deletions src/endpoints_submission_cli/_version_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 MLCommons
# SPDX-License-Identifier: Apache-2.0
"""Background PyPI version check with a 24-hour file cache.

Usage::

from endpoints_submission_cli._version_check import register_upgrade_notice
register_upgrade_notice() # call once at CLI entry

The check never blocks the CLI: if the cache is stale a daemon thread fetches
the latest version from PyPI while the command runs, then an atexit handler
prints a one-line notice (to stderr) if a newer release exists.
"""

from __future__ import annotations

import atexit
import json
import threading
import time
from importlib.metadata import PackageNotFoundError
from importlib.metadata import version as _pkg_version
from pathlib import Path
from urllib.request import urlopen

__all__ = ["register_upgrade_notice"]

_PACKAGE = "endpoints-submission-cli"
_PYPI_URL = f"https://pypi.org/pypi/{_PACKAGE}/json"
_CACHE_PATH = Path.home() / ".cache" / _PACKAGE / "version_check.json"
_CACHE_TTL = 24 * 3600 # seconds
_TIMEOUT = 3.0 # PyPI request timeout


def _current_version() -> str | None:
try:
return _pkg_version(_PACKAGE)
except PackageNotFoundError:
return None


def _cached_latest() -> str | None:
try:
data = json.loads(_CACHE_PATH.read_text())
if time.time() - data["ts"] < _CACHE_TTL:
return data["latest"]
except Exception:
pass
return None


def _fetch_latest() -> str | None:
try:
with urlopen(_PYPI_URL, timeout=_TIMEOUT) as resp:
latest: str = json.loads(resp.read())["info"]["version"]
try:
_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
_CACHE_PATH.write_text(json.dumps({"latest": latest, "ts": time.time()}))
except Exception:
pass
return latest
except Exception:
return None


def _is_newer(latest: str, current: str) -> bool:
def _parts(v: str) -> tuple[int, ...]:
try:
return tuple(int(x) for x in v.split(".")[:3])
except ValueError:
return (0,)

return _parts(latest) > _parts(current)


def register_upgrade_notice() -> None:
"""Register a background version check and an atexit upgrade notice.

Safe to call multiple times — subsequent calls are no-ops.
"""
current = _current_version()
if not current:
return

latest = _cached_latest()
result: list[str | None] = [latest]
thread: threading.Thread | None = None

if latest is None:

def _fetch() -> None:
result[0] = _fetch_latest()

thread = threading.Thread(target=_fetch, daemon=True)
thread.start()

def _on_exit() -> None:
if thread is not None:
thread.join(timeout=_TIMEOUT + 0.5)
fetched = result[0]
if fetched and _is_newer(fetched, current):
# Import here to avoid a top-level Rich import cost on every invocation.
from rich.console import Console

Console(stderr=True).print(
f"\n[dim]A new version of [bold]{_PACKAGE}[/bold] is available: "
f"[yellow]{current}[/yellow] → [green]{fetched}[/green] "
f"([bold]pip install --upgrade {_PACKAGE}[/bold])[/dim]"
)

atexit.register(_on_exit)
2 changes: 2 additions & 0 deletions src/endpoints_submission_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import click

from ._version_check import register_upgrade_notice
from .commands.runs import runs
from .commands.submissions import submissions

Expand All @@ -22,4 +23,5 @@ def app() -> None:

def main() -> None:
"""Entry point called by the ``endpoints-submission-cli`` script."""
register_upgrade_notice()
app()
13 changes: 10 additions & 3 deletions src/submission_checker/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
load_accuracy_result,
load_point_config,
load_result_summary,
load_run_config,
load_system_description,
)

Expand Down Expand Up @@ -289,6 +290,9 @@ def _check_model(
"#8.1",
)
)
else:
_, run_config_results = load_run_config(config_yaml_path)
results.extend(run_config_results)

summary, load_results = load_result_summary(summary_path)
results.extend(load_results)
Expand All @@ -315,9 +319,12 @@ def _check_model(
json_p = pd / "results.json"
if not json_p.exists():
results.append(
_err("accuracy-file",
f"Missing results.json in point_{config.concurrency}/accuracy/",
json_p, "#15")
_err(
"accuracy-file",
f"Missing results.json in point_{config.concurrency}/accuracy/",
json_p,
"#15",
)
)
elif accuracy_result is None:
accuracy_result, acc_results = load_accuracy_result(json_p)
Expand Down
9 changes: 9 additions & 0 deletions src/submission_checker/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@
from .checker import SubmissionChecker
from .models import Severity, compute_regions

try:
from endpoints_submission_cli._version_check import (
register_upgrade_notice as _register_upgrade_notice,
)
except ImportError:
_register_upgrade_notice = None

console = Console()

_SEVERITY_STYLE: dict[Severity, str] = {
Expand All @@ -22,6 +29,8 @@
@click.version_option(package_name="endpoints-submission-cli")
def main() -> None:
"""MLPerf Endpoints submission checker — validate a submission directory."""
if _register_upgrade_notice is not None:
_register_upgrade_notice()


@main.command()
Expand Down
2 changes: 2 additions & 0 deletions src/submission_checker/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PercentileStats,
PointConfig,
PointSummary,
RunConfig,
RuntimeSettings,
SystemAvailabilityStatus,
SystemDescription,
Expand All @@ -30,6 +31,7 @@
"PointResult",
"PointSummary",
"RegionBounds",
"RunConfig",
"Regions",
"Report",
"RuntimeSettings",
Expand Down
2 changes: 2 additions & 0 deletions src/submission_checker/models/file/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .accuracy import AccuracyResult
from .point_config import PointConfig, RuntimeSettings
from .point_summary import PercentileStats, PointSummary
from .run_config import RunConfig
from .system import Division, NodeType, SystemAvailabilityStatus, SystemDescription

__all__ = [
Expand All @@ -13,6 +14,7 @@
"PercentileStats",
"PointConfig",
"PointSummary",
"RunConfig",
"RuntimeSettings",
"SystemDescription",
]
59 changes: 59 additions & 0 deletions src/submission_checker/models/file/run_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Run configuration model — ``results/point_<N>/config.yaml`` (§8.4).

Only the fields needed for compliance checks are parsed; all other fields are
accepted via ``extra="allow"`` so that new endpoint tool versions don't break
older checker versions.
"""

from __future__ import annotations

from pathlib import Path

__all__ = ["RunConfig"]

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationInfo, model_validator

from ..results import CheckResult, err, ok


class _WarmupConfig(BaseModel):
model_config = ConfigDict(extra="allow")

enabled: bool = False
salt: bool = False


class _RunSettings(BaseModel):
model_config = ConfigDict(extra="allow")

warmup: _WarmupConfig = Field(default_factory=_WarmupConfig)


class RunConfig(BaseModel):
"""Parsed ``results/point_<N>/config.yaml`` — checked for warmup salt compliance."""

model_config = ConfigDict(extra="allow")
_check_results: list[CheckResult] = PrivateAttr(default_factory=list)

settings: _RunSettings = Field(default_factory=_RunSettings)

@model_validator(mode="after")
def _check_warmup_salted(self, info: ValidationInfo) -> RunConfig:
"""§6.3: warmup prompts must be salted to prevent KV-cache priming of the perf run."""
path: Path | None = (info.context or {}).get("config_path")
warmup = self.settings.warmup
if not warmup.enabled:
return self
if not warmup.salt:
self._check_results.append(
err(
"warmup-salt",
"Warmup is enabled but salt=false; unsalted prompts may prime the KV cache"
" before the performance phase",
path,
"#6.3",
)
)
else:
self._check_results.append(ok("warmup-salt", "Warmup salt enabled", path, "#6.3"))
return self
57 changes: 48 additions & 9 deletions src/submission_checker/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"load_accuracy_result",
"load_point_config",
"load_result_summary",
"load_run_config",
"load_system_description",
]

Expand All @@ -17,7 +18,7 @@
import yaml
from pydantic import ValidationError

from .file import AccuracyResult, PointConfig, PointSummary, SystemDescription
from .file import AccuracyResult, PointConfig, PointSummary, RunConfig, SystemDescription
from .results import CheckResult, Severity


Expand Down Expand Up @@ -70,8 +71,14 @@ def load_system_description(
"""
data, load_err = _load_json(path)
if load_err:
return None, [CheckResult(rule="system-description-valid", message=load_err,
severity=Severity.ERROR, path=path)]
return None, [
CheckResult(
rule="system-description-valid",
message=load_err,
severity=Severity.ERROR,
path=path,
)
]
try:
return SystemDescription.model_validate(data), []
except ValidationError as exc:
Expand All @@ -91,8 +98,11 @@ def load_point_config(
"""
data, load_err = _load_yaml(path)
if load_err:
return None, [CheckResult(rule="point-config-valid", message=load_err,
severity=Severity.ERROR, path=path)]
return None, [
CheckResult(
rule="point-config-valid", message=load_err, severity=Severity.ERROR, path=path
)
]
try:
instance = PointConfig.model_validate(data, context=context or {})
return instance, list(instance._check_results)
Expand All @@ -110,14 +120,42 @@ def load_result_summary(path: Path) -> tuple[PointSummary | None, list[CheckResu
"""
data, load_err = _load_json(path)
if load_err:
return None, [CheckResult(rule="result-file-valid", message=load_err,
severity=Severity.ERROR, path=path)]
return None, [
CheckResult(
rule="result-file-valid", message=load_err, severity=Severity.ERROR, path=path
)
]
try:
return PointSummary.model_validate(data), []
except ValidationError as exc:
return None, _validation_errors(exc, "result-file-valid", path)


def load_run_config(
path: Path,
) -> tuple[RunConfig | None, list[CheckResult]]:
"""Load and validate ``results/point_<N>/config.yaml``.

Returns:
A ``(model, check_results)`` pair. On success the model is not None and
check_results contains the validator-produced CheckResult entries.
On failure the model is None and check_results contains one entry per
validation error.
"""
data, load_err = _load_yaml(path)
if load_err:
return None, [
CheckResult(
rule="run-config-valid", message=load_err, severity=Severity.ERROR, path=path
)
]
try:
instance = RunConfig.model_validate(data, context={"config_path": path})
return instance, list(instance._check_results)
except ValidationError as exc:
return None, _validation_errors(exc, "run-config-valid", path)


def load_accuracy_result(
path: Path,
) -> tuple[AccuracyResult | None, list[CheckResult]]:
Expand All @@ -131,8 +169,9 @@ def load_accuracy_result(
"""
data, load_err = _load_json(path)
if load_err:
return None, [CheckResult(rule="accuracy-valid", message=load_err,
severity=Severity.ERROR, path=path)]
return None, [
CheckResult(rule="accuracy-valid", message=load_err, severity=Severity.ERROR, path=path)
]
try:
instance = AccuracyResult.model_validate(data, context={"json_path": path})
return instance, list(instance._check_results)
Expand Down
Loading
Loading