Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions scripts/verify_route_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

import argparse
from pathlib import Path

import yaml


def _resolved_target_model(route_id: str, route_spec: dict) -> str:
target_model = route_spec.get("target_model")
if target_model is None:
return route_id
return str(target_model)


def verify_route_contract(routes_path: Path, eval_cases_path: Path | None = None) -> list[str]:
raw = yaml.safe_load(routes_path.read_text(encoding="utf-8")) or {}
routes = raw.get("routes") or {}
route_ids = set(routes)
route_model = raw.get("route_model") or raw.get("entry_model") or "semantic-router"

errors: list[str] = []

if route_model in route_ids:
errors.append(f"route_model/entry_model '{route_model}' must not be a route_id")

fallback_route_id = raw.get("fallback_route_id") or raw.get("default_route")
if fallback_route_id not in route_ids:
errors.append(
f"fallback_route_id/default_route '{fallback_route_id}' must exist in routes"
)

hard_rules = raw.get("hard_rules") or []
for index, hard_rule in enumerate(hard_rules, start=1):
hard_rule_route_id = hard_rule.get("route_id")
if hard_rule_route_id not in route_ids:
errors.append(
f"hard_rules[{index}] route_id '{hard_rule_route_id}' must exist in routes"
)

target_models = set()
for route_id, route_spec in routes.items():
resolved_target_model = _resolved_target_model(route_id, route_spec)
if not resolved_target_model:
errors.append(f"routes.{route_id} must resolve a non-empty target_model")
continue
if resolved_target_model == route_model:
errors.append(
f"routes.{route_id} resolved target_model '{resolved_target_model}' must not equal entry model"
)
target_models.add(resolved_target_model)

if eval_cases_path is not None:
eval_cases_raw = yaml.safe_load(eval_cases_path.read_text(encoding="utf-8")) or {}
for index, case in enumerate(eval_cases_raw.get("cases") or [], start=1):
expect = case.get("expect")
if expect in target_models and expect not in route_ids:
errors.append(
f"eval case {index} expect '{expect}' matches a target_model; use route_id instead"
)
elif expect not in route_ids:
errors.append(f"eval case {index} expect '{expect}' must be a configured route_id")

return errors


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--routes", default="config/routes.yaml")
parser.add_argument("--eval-cases", default="config/eval_cases.yaml")
args = parser.parse_args()

errors = verify_route_contract(Path(args.routes), Path(args.eval_cases))
if errors:
print("Route contract verification failed:")
for error in errors:
print(f"- {error}")
raise SystemExit(1)

print(
"Route contract verification passed "
f"for {args.routes} and {args.eval_cases}."
)


if __name__ == "__main__":
main()
84 changes: 84 additions & 0 deletions tests/test_verify_route_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from pathlib import Path

from scripts.verify_route_contract import verify_route_contract


def test_verify_route_contract_passes_for_valid_config(tmp_path: Path):
routes_path = tmp_path / "routes.yaml"
eval_cases_path = tmp_path / "eval_cases.yaml"

routes_path.write_text(
"""
entry_model: semantic-router
fallback_route_id: fast
routes:
fast:
target_model: cheap-router
description: low risk
utterances: [hi]
strong:
target_model: pro-router
description: high risk
utterances: [debug]
hard_rules:
- route_id: strong
keywords: [PR]
""",
encoding="utf-8",
)
eval_cases_path.write_text(
"""
cases:
- text: hello
expect: fast
- text: review this patch
expect: strong
""",
encoding="utf-8",
)

assert verify_route_contract(routes_path, eval_cases_path) == []


def test_verify_route_contract_reports_drift_and_missing_references(tmp_path: Path):
routes_path = tmp_path / "routes.yaml"
eval_cases_path = tmp_path / "eval_cases.yaml"

routes_path.write_text(
"""
route_model: semantic-router
fallback_route_id: missing
routes:
semantic-router:
target_model: semantic-router
description: loop
utterances: [loop]
fast:
target_model: cheap-router
description: low risk
utterances: [hi]
hard_rules:
- route_id: strong
keywords: [PR]
""",
encoding="utf-8",
)
eval_cases_path.write_text(
"""
cases:
- text: hello
expect: cheap-router
- text: unknown
expect: ghost
""",
encoding="utf-8",
)

errors = verify_route_contract(routes_path, eval_cases_path)

assert any("must not be a route_id" in error for error in errors)
assert any("must exist in routes" in error for error in errors)
assert any("matches a target_model" in error for error in errors)
assert any("must be a configured route_id" in error for error in errors)
Loading