From df85833e34145075ab62c21f8aa2da9f176cc9f3 Mon Sep 17 00:00:00 2001 From: kschlt Date: Mon, 23 Mar 2026 00:00:23 +0100 Subject: [PATCH 1/2] feat(enforce): add staged enforcement with hooks, classifier, and validator MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements the STG task: ADR policy checks now run automatically at the right workflow stage rather than only at CI time. A classification model maps policy types to stages (imports/python/patterns → commit, architecture → push, required_structure/config → ci), StagedValidator fetches the appropriate file set via git and runs offline grep-based checks, and HookGenerator writes idempotent sentinel-marked sections into .git/hooks/pre-commit and .git/hooks/pre-push. Approval workflow auto-calls generate() so hooks stay in sync as rules are added. Three new CLI commands (enforce, setup-enforcement, enforce-status) plus --with-enforcement on init. Architecture and config checks are classified but not yet executed — reserved for the ENF task. --- adr_kit/cli.py | 177 +++++++++ adr_kit/enforce/hooks.py | 173 +++++++++ adr_kit/enforce/stages.py | 188 +++++++++ adr_kit/enforce/validator.py | 308 +++++++++++++++ adr_kit/workflows/approval.py | 37 +- tests/unit/test_hook_generator.py | 167 ++++++++ tests/unit/test_staged_enforcement.py | 526 ++++++++++++++++++++++++++ 7 files changed, 1575 insertions(+), 1 deletion(-) create mode 100644 adr_kit/enforce/hooks.py create mode 100644 adr_kit/enforce/stages.py create mode 100644 adr_kit/enforce/validator.py create mode 100644 tests/unit/test_hook_generator.py create mode 100644 tests/unit/test_staged_enforcement.py diff --git a/adr_kit/cli.py b/adr_kit/cli.py index 3a35a77..f72dd00 100644 --- a/adr_kit/cli.py +++ b/adr_kit/cli.py @@ -124,6 +124,9 @@ def init( skip_setup: bool = typer.Option( False, "--skip-setup", help="Skip interactive AI agent setup" ), + with_enforcement: bool = typer.Option( + False, "--with-enforcement", help="Set up git hooks for staged enforcement" + ), ) -> None: """Initialize ADR structure in repository.""" try: @@ -145,6 +148,10 @@ def init( except Exception as e: console.print(f"⚠️ Could not generate initial JSON index: {e}") + # Optional: set up git hooks for enforcement + if with_enforcement: + _setup_enforcement_hooks() + # Interactive setup prompt (skip if --skip-setup flag is provided) if not skip_setup: console.print("\n🤖 [bold]Setup AI Agent Integration?[/bold]") @@ -641,6 +648,33 @@ def info() -> None: # Keep only essential manual commands + + +def _setup_enforcement_hooks() -> None: + """Set up git hooks for staged ADR enforcement (called from init --with-enforcement).""" + from .enforce.hooks import HookGenerator + + gen = HookGenerator() + results = gen.generate() + + actions = { + "pre-commit": results.get("pre-commit", "skipped"), + "pre-push": results.get("pre-push", "skipped"), + } + + for hook_name, action in actions.items(): + if "skipped" in action: + console.print(f" ⚠️ {hook_name}: skipped ({action})") + elif action == "unchanged": + console.print(f" ✅ {hook_name}: already configured") + else: + console.print(f" ✅ {hook_name}: {action}") + + console.print( + " 💡 Run 'adr-kit enforce commit' to test pre-commit checks manually" + ) + + def _setup_cursor_impl() -> None: """Implementation for Cursor setup that can be called from commands or init.""" import json @@ -698,6 +732,71 @@ def _setup_cursor_impl() -> None: ) +@app.command() +def setup_enforcement( + project_root: Path = typer.Option( + Path("."), "--root", help="Project root (git repository)" + ), +) -> None: + """Set up git hooks for staged ADR enforcement. + + Writes ADR-Kit managed sections into .git/hooks/pre-commit and + .git/hooks/pre-push. Safe on existing hooks — appends only. + Re-running is idempotent. + """ + from .enforce.hooks import HookGenerator + + try: + gen = HookGenerator() + results = gen.generate(project_root=project_root) + + console.print("🔧 Setting up enforcement hooks...") + for hook_name, action in results.items(): + if "skipped" in action: + console.print(f" ⚠️ {hook_name}: {action}") + elif action == "unchanged": + console.print(f" ✅ {hook_name}: already configured") + else: + console.print(f" ✅ {hook_name}: {action}") + + console.print( + "\n💡 Use 'adr-kit enforce commit' to run pre-commit checks manually" + ) + console.print("💡 Use 'adr-kit enforce push' to run pre-push checks manually") + except Exception as e: + console.print(f"❌ Failed to set up enforcement hooks: {e}") + raise typer.Exit(code=1) from e + + +@app.command() +def enforce_status( + project_root: Path = typer.Option( + Path("."), "--root", help="Project root (git repository)" + ), +) -> None: + """Show status of ADR enforcement hooks.""" + from .enforce.hooks import HookGenerator + + try: + gen = HookGenerator() + status = gen.status(project_root=project_root) + + console.print("🔍 ADR Enforcement Hook Status") + for hook_name, active in status.items(): + icon = "✅" if active else "❌" + console.print( + f" {icon} {hook_name}: {'active' if active else 'not configured'}" + ) + + if not any(status.values()): + console.print( + "\n💡 Run 'adr-kit setup-enforcement' to enable automatic enforcement" + ) + except Exception as e: + console.print(f"❌ Failed to get enforcement status: {e}") + raise typer.Exit(code=1) from e + + @app.command() def setup_cursor() -> None: """Set up ADR Kit MCP server for Cursor IDE.""" @@ -1144,6 +1243,84 @@ def legacy() -> None: console.print() +@app.command() +def enforce( + level: str = typer.Argument( + ..., + help="Enforcement level: commit (staged files), push (changed files), ci (all files)", + ), + adr_dir: Path = typer.Option(Path("docs/adr"), "--adr-dir", help="ADR directory"), + project_root: Path = typer.Option( + Path("."), "--root", help="Project root directory" + ), +) -> None: + """Run ADR policy enforcement checks at the given workflow stage. + + Reads accepted ADRs, classifies their policies by stage, and checks the + appropriate files for violations. + + \\b + Levels: + commit Check staged files only (<5s). Run as pre-commit hook. + push Check changed files (<15s). Run as pre-push hook. + ci Check entire codebase (<2min). Run in CI pipelines. + + Exit codes: 0 = pass, 1 = violations found, 2 = warnings only, 3 = error + """ + from .enforce.stages import EnforcementLevel + from .enforce.validator import StagedValidator + + try: + try: + enforcement_level = EnforcementLevel(level.lower()) + except ValueError: + stderr_console.print( + f"❌ Unknown level '{level}'. Valid levels: commit, push, ci" + ) + raise typer.Exit(code=3) from None + + validator = StagedValidator(adr_dir=adr_dir) + result = validator.validate(enforcement_level, project_root=project_root) + + level_labels = { + EnforcementLevel.COMMIT: "pre-commit (staged files)", + EnforcementLevel.PUSH: "pre-push (changed files)", + EnforcementLevel.CI: "ci (full codebase)", + } + console.print(f"🔍 ADR enforcement — {level_labels[enforcement_level]}") + console.print(f" {result.checks_run} checks · {result.files_checked} files") + + if not result.violations: + console.print("✅ All checks passed") + raise typer.Exit(code=0) + + # Print violations grouped by ADR + for violation in result.violations: + icon = "❌" if violation.severity == "error" else "⚠️ " + location = ( + f"{violation.file}:{violation.line}" + if violation.line + else violation.file + ) + console.print(f"{icon} {location}") + console.print(f" {violation.message}") + + console.print( + f"\n{'❌' if result.error_count else '⚠️ '} " + f"{result.error_count} error(s), {result.warning_count} warning(s)" + ) + + if result.passed: + raise typer.Exit(code=2) # warnings only + raise typer.Exit(code=1) # errors found + + except typer.Exit: + raise + except Exception as e: + stderr_console.print(f"❌ Enforcement check failed: {e}") + raise typer.Exit(code=3) from e + + if __name__ == "__main__": import sys diff --git a/adr_kit/enforce/hooks.py b/adr_kit/enforce/hooks.py new file mode 100644 index 0000000..adf1bce --- /dev/null +++ b/adr_kit/enforce/hooks.py @@ -0,0 +1,173 @@ +"""Git hook generator for staged ADR enforcement. + +Writes a managed section into .git/hooks/pre-commit and .git/hooks/pre-push +so that ADR policy checks run automatically at the right workflow stage. + +Design: +- Non-interfering: appends a managed section to existing hooks, never overwrites. +- Idempotent: re-running updates the managed section in-place. +- Clearly marked: ADR-KIT markers make ownership obvious. +- First-run bootstraps: creates hook file if it doesn't exist. +""" + +import stat +from pathlib import Path + +# Sentinel markers — must be unique and stable across versions +MANAGED_START = "# >>> ADR-KIT MANAGED - DO NOT EDIT >>>" +MANAGED_END = "# <<< ADR-KIT MANAGED <<<" + +_HOOK_HEADER = "#!/bin/sh" + +# Per-hook managed content +_COMMIT_SECTION = f"""\ +{MANAGED_START} +adr-kit enforce commit +{MANAGED_END}""" + +_PUSH_SECTION = f"""\ +{MANAGED_START} +adr-kit enforce push +{MANAGED_END}""" + + +def _make_executable(path: Path) -> None: + """Ensure the hook file has executable permission.""" + current = path.stat().st_mode + path.chmod(current | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH) + + +def _apply_managed_section(hook_path: Path, managed_content: str) -> str: + """Insert or replace the ADR-Kit managed section in a hook file. + + If the hook doesn't exist, creates it with a shebang + managed section. + Returns a string describing what changed: "created" | "updated" | "unchanged". + """ + if not hook_path.exists(): + hook_path.write_text(f"{_HOOK_HEADER}\n\n{managed_content}\n") + _make_executable(hook_path) + return "created" + + existing = hook_path.read_text() + + # Replace existing managed section + if MANAGED_START in existing and MANAGED_END in existing: + start_idx = existing.index(MANAGED_START) + end_idx = existing.index(MANAGED_END) + len(MANAGED_END) + new_section = ( + existing[:start_idx].rstrip("\n") + + "\n\n" + + managed_content + + "\n" + + existing[end_idx:].lstrip("\n") + ) + if new_section == existing: + return "unchanged" + hook_path.write_text(new_section) + _make_executable(hook_path) + return "updated" + + # No managed section yet — append + separator = "\n\n" if existing.rstrip() else "" + hook_path.write_text(existing.rstrip() + separator + managed_content + "\n") + _make_executable(hook_path) + return "appended" + + +class HookGenerator: + """Generates and updates git hooks for staged ADR enforcement. + + Writes ADR-Kit managed sections into .git/hooks/pre-commit and + .git/hooks/pre-push. Safe to call repeatedly — idempotent. + """ + + def generate(self, project_root: Path | None = None) -> dict[str, str]: + """Write managed sections into pre-commit and pre-push hooks. + + Args: + project_root: Root of the git repository. Defaults to cwd. + + Returns: + Dict mapping hook name → action taken ("created"|"updated"|"appended"|"unchanged"|"skipped"). + """ + project_root = project_root or Path.cwd() + hooks_dir = project_root / ".git" / "hooks" + + if not hooks_dir.exists(): + # Not a git repo or hooks dir missing — skip silently + return { + "pre-commit": "skipped (no .git/hooks directory)", + "pre-push": "skipped (no .git/hooks directory)", + } + + results: dict[str, str] = {} + + results["pre-commit"] = _apply_managed_section( + hooks_dir / "pre-commit", _COMMIT_SECTION + ) + results["pre-push"] = _apply_managed_section( + hooks_dir / "pre-push", _PUSH_SECTION + ) + + return results + + def remove(self, project_root: Path | None = None) -> dict[str, str]: + """Remove ADR-Kit managed sections from git hooks. + + Useful when uninstalling or disabling enforcement. + + Returns: + Dict mapping hook name → action taken ("removed"|"not_found"|"skipped"). + """ + project_root = project_root or Path.cwd() + hooks_dir = project_root / ".git" / "hooks" + + if not hooks_dir.exists(): + return { + "pre-commit": "skipped (no .git/hooks directory)", + "pre-push": "skipped (no .git/hooks directory)", + } + + results: dict[str, str] = {} + for hook_name in ("pre-commit", "pre-push"): + hook_path = hooks_dir / hook_name + if not hook_path.exists(): + results[hook_name] = "not_found" + continue + + content = hook_path.read_text() + if MANAGED_START not in content: + results[hook_name] = "not_found" + continue + + start_idx = content.index(MANAGED_START) + end_idx = content.index(MANAGED_END) + len(MANAGED_END) + # Strip surrounding blank lines added when appending + cleaned = content[:start_idx].rstrip("\n") + content[end_idx:].lstrip("\n") + if not cleaned.strip(): + # Hook only contained our section — remove the file + hook_path.unlink() + else: + hook_path.write_text(cleaned) + results[hook_name] = "removed" + + return results + + def status(self, project_root: Path | None = None) -> dict[str, bool]: + """Check whether ADR-Kit managed sections are present in hooks. + + Returns: + Dict mapping hook name → True if managed section is present. + """ + project_root = project_root or Path.cwd() + hooks_dir = project_root / ".git" / "hooks" + + result: dict[str, bool] = {} + for hook_name in ("pre-commit", "pre-push"): + hook_path = hooks_dir / hook_name + if not hook_path.exists(): + result[hook_name] = False + continue + result[hook_name] = MANAGED_START in hook_path.read_text() + + return result diff --git a/adr_kit/enforce/stages.py b/adr_kit/enforce/stages.py new file mode 100644 index 0000000..ebdf797 --- /dev/null +++ b/adr_kit/enforce/stages.py @@ -0,0 +1,188 @@ +"""Enforcement stage classification model. + +Maps ADR policy types to workflow stages (commit/push/ci) based on: +- Speed: how fast the check runs +- Scope: what files and context it needs + +Stage semantics: +- commit (<5s): staged files only, fast grep — first checkpoint +- push (<15s): changed files, broader context +- ci (<2min): full codebase, all checks — safety net + +A check assigned to level X also runs at all higher levels +(commit checks run at push and ci too). +""" + +from dataclasses import dataclass, field +from enum import Enum + + +class EnforcementLevel(str, Enum): + """Workflow stage at which enforcement checks run.""" + + COMMIT = "commit" + PUSH = "push" + CI = "ci" + + +# Ordered levels for inclusion logic (lower index = earlier stage) +_LEVEL_ORDER: dict[EnforcementLevel, int] = { + EnforcementLevel.COMMIT: 0, + EnforcementLevel.PUSH: 1, + EnforcementLevel.CI: 2, +} + +# Policy type → minimum enforcement level +# A policy type at level X also runs at all higher levels. +POLICY_LEVEL_MAP: dict[str, EnforcementLevel] = { + "imports": EnforcementLevel.COMMIT, # fast grep — always first + "python": EnforcementLevel.COMMIT, # fast grep — always first + "patterns": EnforcementLevel.COMMIT, # fast regex — always first + "architecture": EnforcementLevel.PUSH, # needs broader file context + "required_structure": EnforcementLevel.CI, # full codebase check + "config_enforcement": EnforcementLevel.CI, # config deep check +} + + +@dataclass +class StagedCheck: + """A single enforceable check classified to an enforcement level.""" + + adr_id: str + adr_title: str + check_type: str # "import" | "python_import" | "pattern" | "architecture" | "required_structure" | "config" + level: EnforcementLevel + pattern: str # what to grep/check for + message: str # human-readable violation message + file_glob: str | None = None # file extension filter + severity: str = "error" + metadata: dict = field(default_factory=dict) # extra context for complex checks + + +def classify_adr_checks(adrs: list) -> list[StagedCheck]: + """Extract and classify all enforceable checks from a list of accepted ADRs. + + Returns one StagedCheck per enforceable rule across all policy types. + Architecture and config checks are classified but not yet executed + (reserved for ENF task — reported here for transparency). + """ + checks: list[StagedCheck] = [] + + for adr in adrs: + if not adr.policy: + continue + + policy = adr.policy + adr_id = adr.id + adr_title = adr.title + + # imports: disallowed JS/TS imports — COMMIT level + if policy.imports and policy.imports.disallow: + for lib in policy.imports.disallow: + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="import", + level=EnforcementLevel.COMMIT, + pattern=lib, + message=f"Import of '{lib}' is disallowed — see {adr_id}: {adr_title}", + ) + ) + + # python: disallowed Python imports — COMMIT level + if policy.python and policy.python.disallow_imports: + for lib in policy.python.disallow_imports: + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="python_import", + level=EnforcementLevel.COMMIT, + pattern=lib, + message=f"Python import of '{lib}' is disallowed — see {adr_id}: {adr_title}", + file_glob="*.py", + ) + ) + + # patterns: regex code pattern rules — COMMIT level (fast grep) + if policy.patterns and policy.patterns.patterns: + for name, rule in policy.patterns.patterns.items(): + if isinstance(rule.rule, str): # only handle regex patterns + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="pattern", + level=EnforcementLevel.COMMIT, + pattern=rule.rule, + message=f"Pattern '{name}': {rule.description} — see {adr_id}", + file_glob=f"*.{rule.language}" if rule.language else None, + severity=rule.severity, + ) + ) + + # architecture: layer boundaries — PUSH level + if policy.architecture and policy.architecture.layer_boundaries: + for boundary in policy.architecture.layer_boundaries: + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="architecture", + level=EnforcementLevel.PUSH, + pattern=boundary.rule, + message=boundary.message + or f"Architecture violation: {boundary.rule} — see {adr_id}", + severity="error" if boundary.action == "block" else "warning", + metadata={"rule": boundary.rule, "check": boundary.check}, + ) + ) + + # required_structure: file/dir existence — CI level + if policy.architecture and policy.architecture.required_structure: + for required in policy.architecture.required_structure: + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="required_structure", + level=EnforcementLevel.CI, + pattern=required.path, + message=required.description + or f"Required path missing: {required.path} — see {adr_id}", + ) + ) + + # config_enforcement — CI level + if policy.config_enforcement: + checks.append( + StagedCheck( + adr_id=adr_id, + adr_title=adr_title, + check_type="config", + level=EnforcementLevel.CI, + pattern="config_check", + message=f"Configuration requirements from {adr_id}: {adr_title}", + metadata={ + "policy": policy.config_enforcement.model_dump( + exclude_none=True + ) + }, + ) + ) + + return checks + + +def checks_for_level( + checks: list[StagedCheck], level: EnforcementLevel +) -> list[StagedCheck]: + """Return checks that should run at the given level (inclusive of lower levels). + + commit → runs commit checks only + push → runs commit + push checks + ci → runs all checks + """ + target_order = _LEVEL_ORDER[level] + return [c for c in checks if _LEVEL_ORDER[c.level] <= target_order] diff --git a/adr_kit/enforce/validator.py b/adr_kit/enforce/validator.py new file mode 100644 index 0000000..f8f6f6c --- /dev/null +++ b/adr_kit/enforce/validator.py @@ -0,0 +1,308 @@ +"""Staged validation runner. + +Executes ADR policy checks against files based on enforcement level: +- commit: staged files only (git diff --cached) — fast grep, <5s +- push: changed files (git diff @{upstream}..HEAD) — broader, <15s +- ci: all project files — comprehensive safety net, <2min + +Architecture and config checks are classified but not yet executed +(reserved for ENF task). They appear in the check count but produce +no violations today — this is intentional and documented. +""" + +import re +import subprocess +from dataclasses import dataclass, field +from pathlib import Path + +from ..core.model import ADRStatus +from ..core.parse import ParseError, find_adr_files, parse_adr_file +from .stages import EnforcementLevel, StagedCheck, checks_for_level, classify_adr_checks + +# Source file extensions scanned during CI full-codebase pass +_SOURCE_EXTENSIONS = {".py", ".js", ".ts", ".jsx", ".tsx", ".java", ".go", ".rs", ".kt"} + +# Directories never scanned — generated/installed content +_EXCLUDE_DIRS = { + ".git", + ".venv", + "venv", + "node_modules", + "__pycache__", + ".pytest_cache", + ".mypy_cache", + ".ruff_cache", + "dist", + "build", + ".adr-kit", + ".project-index", +} + + +@dataclass +class Violation: + """A single policy violation found during validation.""" + + file: str + adr_id: str + message: str + level: EnforcementLevel + severity: str = "error" + line: int | None = None + + +@dataclass +class ValidationResult: + """Result of a staged validation run.""" + + level: EnforcementLevel + files_checked: int + checks_run: int + violations: list[Violation] = field(default_factory=list) + errors: list[str] = field(default_factory=list) + + @property + def passed(self) -> bool: + """True when no error-severity violations exist.""" + return not any(v.severity == "error" for v in self.violations) + + @property + def has_warnings(self) -> bool: + return any(v.severity == "warning" for v in self.violations) + + @property + def error_count(self) -> int: + return sum(1 for v in self.violations if v.severity == "error") + + @property + def warning_count(self) -> int: + return sum(1 for v in self.violations if v.severity == "warning") + + +class StagedValidator: + """Runs ADR policy checks classified by enforcement level.""" + + def __init__(self, adr_dir: str | Path = "docs/adr"): + self.adr_dir = Path(adr_dir) + + def validate( + self, + level: EnforcementLevel, + project_root: Path | None = None, + ) -> ValidationResult: + """Run all checks active at the given level. + + Args: + level: Enforcement level to run (commit/push/ci). + project_root: Root directory for file resolution. Defaults to cwd. + + Returns: + ValidationResult with all violations and metadata. + """ + project_root = project_root or Path.cwd() + + adrs = self._load_accepted_adrs() + all_checks = classify_adr_checks(adrs) + active_checks = checks_for_level(all_checks, level) + files = self._get_files(level, project_root) + + result = ValidationResult( + level=level, + files_checked=len(files), + checks_run=len(active_checks), + ) + + for check in active_checks: + violations = self._run_check(check, files, project_root) + result.violations.extend(violations) + + return result + + # --- ADR loading --- + + def _load_accepted_adrs(self) -> list: + adrs = [] + if not self.adr_dir.exists(): + return adrs + for file_path in find_adr_files(self.adr_dir): + try: + adr = parse_adr_file(file_path, strict=False) + if adr and adr.front_matter.status == ADRStatus.ACCEPTED: + adrs.append(adr) + except ParseError: + continue + return adrs + + # --- File collection --- + + def _get_files(self, level: EnforcementLevel, project_root: Path) -> list[Path]: + if level == EnforcementLevel.COMMIT: + return self._get_staged_files(project_root) + elif level == EnforcementLevel.PUSH: + files = self._get_changed_files(project_root) + # Fall back to staged if no upstream info available + return files or self._get_staged_files(project_root) + else: # CI + return self._get_all_files(project_root) + + def _get_staged_files(self, project_root: Path) -> list[Path]: + try: + result = subprocess.run( + ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM"], + capture_output=True, + text=True, + cwd=project_root, + ) + if result.returncode != 0: + return [] + files = [project_root / f for f in result.stdout.strip().splitlines() if f] + return [f for f in files if f.exists()] + except Exception: + return [] + + def _get_changed_files(self, project_root: Path) -> list[Path]: + """Files changed since last push. Falls back gracefully if no upstream.""" + for cmd in [ + ["git", "diff", "--name-only", "@{upstream}..HEAD"], + ["git", "diff", "--name-only", "HEAD~1..HEAD"], + ]: + try: + result = subprocess.run( + cmd, capture_output=True, text=True, cwd=project_root + ) + if result.returncode == 0 and result.stdout.strip(): + files = [ + project_root / f + for f in result.stdout.strip().splitlines() + if f + ] + return [f for f in files if f.exists()] + except Exception: + continue + return [] + + def _get_all_files(self, project_root: Path) -> list[Path]: + files = [] + for f in project_root.rglob("*"): + if not f.is_file(): + continue + if f.suffix not in _SOURCE_EXTENSIONS: + continue + # Skip excluded directories + if any(part in _EXCLUDE_DIRS for part in f.parts): + continue + files.append(f) + return files + + # --- Check dispatch --- + + def _run_check( + self, check: StagedCheck, files: list[Path], project_root: Path + ) -> list[Violation]: + if check.check_type in ("import", "python_import"): + return self._run_import_check(check, files, project_root) + elif check.check_type == "pattern": + return self._run_pattern_check(check, files, project_root) + elif check.check_type == "required_structure": + return self._run_structure_check(check, project_root) + # architecture and config: classified but not yet executed (ENF task) + return [] + + def _filter_files_for_check( + self, files: list[Path], check: StagedCheck + ) -> list[Path]: + """Filter file list to those relevant for the check type.""" + if check.check_type == "python_import": + return [f for f in files if f.suffix == ".py"] + if check.check_type == "import": + return [f for f in files if f.suffix in {".js", ".ts", ".jsx", ".tsx"}] + if check.file_glob and check.file_glob.startswith("*."): + ext = check.file_glob[1:] # "*.py" → ".py" + return [f for f in files if f.name.endswith(ext)] + return files + + def _run_import_check( + self, check: StagedCheck, files: list[Path], project_root: Path + ) -> list[Violation]: + target_files = self._filter_files_for_check(files, check) + violations = [] + escaped = re.escape(check.pattern) + + # Matches: import 'lib', from 'lib', require('lib') — with or without path prefix + import_patterns = [ + re.compile(rf"""(import|from)\s+['"]([^'"]*?/)?{escaped}['"]"""), + re.compile(rf"""require\s*\(\s*['"]([^'"]*?/)?{escaped}['"]\s*\)"""), + re.compile(rf"""(import|from)\s+{escaped}(\s|$|;)"""), # Python style + ] + + for file_path in target_files: + try: + content = file_path.read_text(encoding="utf-8", errors="ignore") + for line_num, line in enumerate(content.splitlines(), 1): + for pattern in import_patterns: + if pattern.search(line): + violations.append( + Violation( + file=str(file_path.relative_to(project_root)), + adr_id=check.adr_id, + message=check.message, + level=check.level, + severity=check.severity, + line=line_num, + ) + ) + break # one violation per line + except Exception: + continue + + return violations + + def _run_pattern_check( + self, check: StagedCheck, files: list[Path], project_root: Path + ) -> list[Violation]: + target_files = self._filter_files_for_check(files, check) + violations = [] + + try: + compiled = re.compile(check.pattern) + except re.error: + return [] # invalid regex in ADR policy — skip silently + + for file_path in target_files: + try: + content = file_path.read_text(encoding="utf-8", errors="ignore") + for line_num, line in enumerate(content.splitlines(), 1): + if compiled.search(line): + violations.append( + Violation( + file=str(file_path.relative_to(project_root)), + adr_id=check.adr_id, + message=check.message, + level=check.level, + severity=check.severity, + line=line_num, + ) + ) + except Exception: + continue + + return violations + + def _run_structure_check( + self, check: StagedCheck, project_root: Path + ) -> list[Violation]: + """Check that a required path (glob pattern) exists in the project.""" + import glob + + matches = list(glob.glob(check.pattern, root_dir=str(project_root))) + if not matches: + return [ + Violation( + file=check.pattern, + adr_id=check.adr_id, + message=check.message, + level=check.level, + severity=check.severity, + ) + ] + return [] diff --git a/adr_kit/workflows/approval.py b/adr_kit/workflows/approval.py index 02d9e78..9d6f5b0 100644 --- a/adr_kit/workflows/approval.py +++ b/adr_kit/workflows/approval.py @@ -341,7 +341,7 @@ def _apply_guardrails(self, adr: ADR) -> dict[str, Any]: } def _generate_enforcement_rules(self, adr: ADR) -> dict[str, Any]: - """Generate enforcement rules (ESLint, Ruff, etc.) from ADR policies.""" + """Generate enforcement rules (ESLint, Ruff, git hooks) from ADR policies.""" results = {} try: @@ -355,6 +355,10 @@ def _generate_enforcement_rules(self, adr: ADR) -> dict[str, Any]: ruff_result = self._generate_ruff_rules(adr) results["ruff"] = ruff_result + # Always update git hooks so staged enforcement reflects new rules + hooks_result = self._update_git_hooks() + results["hooks"] = hooks_result + return { "success": True, "rule_generators": list(results.keys()), @@ -369,6 +373,37 @@ def _generate_enforcement_rules(self, adr: ADR) -> dict[str, Any]: "message": "Failed to generate enforcement rules", } + def _update_git_hooks(self) -> dict[str, Any]: + """Update git hooks to run staged enforcement checks.""" + try: + from ..enforce.hooks import HookGenerator + + generator = HookGenerator() + hook_results = generator.generate(project_root=Path.cwd()) + + updated = [ + name + for name, action in hook_results.items() + if action not in ("unchanged", "skipped") + ] + skipped = [ + name for name, action in hook_results.items() if "skipped" in action + ] + + return { + "success": True, + "hooks_updated": updated, + "hooks_skipped": skipped, + "details": hook_results, + "message": f"Git hooks updated: {', '.join(updated) if updated else 'all unchanged'}", + } + except Exception as e: + return { + "success": False, + "error": str(e), + "message": "Failed to update git hooks (non-blocking)", + } + def _has_javascript_policies(self, adr: ADR) -> bool: """Check if ADR has JavaScript/TypeScript related policies.""" if not adr.policy: diff --git a/tests/unit/test_hook_generator.py b/tests/unit/test_hook_generator.py new file mode 100644 index 0000000..95d8158 --- /dev/null +++ b/tests/unit/test_hook_generator.py @@ -0,0 +1,167 @@ +"""Unit tests for the git hook generator.""" + +import stat +import tempfile +from pathlib import Path + +import pytest + +from adr_kit.enforce.hooks import ( + MANAGED_END, + MANAGED_START, + HookGenerator, + _apply_managed_section, +) + +# --------------------------------------------------------------------------- +# _apply_managed_section +# --------------------------------------------------------------------------- + + +class TestApplyManagedSection: + CONTENT = f"{MANAGED_START}\nadr-kit enforce commit\n{MANAGED_END}" + + def test_creates_hook_when_file_missing(self, tmp_path): + hook = tmp_path / "pre-commit" + action = _apply_managed_section(hook, self.CONTENT) + assert action == "created" + assert hook.exists() + assert MANAGED_START in hook.read_text() + assert "adr-kit enforce commit" in hook.read_text() + + def test_created_hook_has_shebang(self, tmp_path): + hook = tmp_path / "pre-commit" + _apply_managed_section(hook, self.CONTENT) + assert hook.read_text().startswith("#!/bin/sh") + + def test_created_hook_is_executable(self, tmp_path): + hook = tmp_path / "pre-commit" + _apply_managed_section(hook, self.CONTENT) + mode = hook.stat().st_mode + assert mode & stat.S_IXUSR + + def test_appends_to_existing_hook(self, tmp_path): + hook = tmp_path / "pre-commit" + hook.write_text("#!/bin/sh\nnpm test\n") + action = _apply_managed_section(hook, self.CONTENT) + assert action == "appended" + text = hook.read_text() + assert "npm test" in text + assert MANAGED_START in text + + def test_updates_existing_managed_section(self, tmp_path): + hook = tmp_path / "pre-commit" + old_content = f"#!/bin/sh\n{MANAGED_START}\nold command\n{MANAGED_END}\n" + hook.write_text(old_content) + new_section = f"{MANAGED_START}\nadr-kit enforce commit\n{MANAGED_END}" + action = _apply_managed_section(hook, new_section) + assert action == "updated" + text = hook.read_text() + assert "old command" not in text + assert "adr-kit enforce commit" in text + + def test_unchanged_when_content_identical(self, tmp_path): + hook = tmp_path / "pre-commit" + content = f"#!/bin/sh\n\n{self.CONTENT}\n" + hook.write_text(content) + action = _apply_managed_section(hook, self.CONTENT) + assert action == "unchanged" + + def test_user_content_preserved_on_update(self, tmp_path): + hook = tmp_path / "pre-commit" + hook.write_text( + f"#!/bin/sh\nnpm test\n\n{MANAGED_START}\nold\n{MANAGED_END}\n\necho done\n" + ) + _apply_managed_section(hook, self.CONTENT) + text = hook.read_text() + assert "npm test" in text + assert "echo done" in text + assert "adr-kit enforce commit" in text + + def test_only_one_managed_section_after_multiple_calls(self, tmp_path): + hook = tmp_path / "pre-commit" + _apply_managed_section(hook, self.CONTENT) + _apply_managed_section(hook, self.CONTENT) + text = hook.read_text() + assert text.count(MANAGED_START) == 1 + + +# --------------------------------------------------------------------------- +# HookGenerator +# --------------------------------------------------------------------------- + + +class TestHookGenerator: + def _make_git_repo(self, tmp_path: Path) -> Path: + """Create a minimal git repo structure.""" + hooks_dir = tmp_path / ".git" / "hooks" + hooks_dir.mkdir(parents=True) + return tmp_path + + def test_generate_creates_both_hooks(self, tmp_path): + root = self._make_git_repo(tmp_path) + gen = HookGenerator() + results = gen.generate(project_root=root) + assert "pre-commit" in results + assert "pre-push" in results + assert (root / ".git" / "hooks" / "pre-commit").exists() + assert (root / ".git" / "hooks" / "pre-push").exists() + + def test_pre_commit_calls_enforce_commit(self, tmp_path): + root = self._make_git_repo(tmp_path) + HookGenerator().generate(project_root=root) + content = (root / ".git" / "hooks" / "pre-commit").read_text() + assert "adr-kit enforce commit" in content + + def test_pre_push_calls_enforce_push(self, tmp_path): + root = self._make_git_repo(tmp_path) + HookGenerator().generate(project_root=root) + content = (root / ".git" / "hooks" / "pre-push").read_text() + assert "adr-kit enforce push" in content + + def test_generate_is_idempotent(self, tmp_path): + root = self._make_git_repo(tmp_path) + gen = HookGenerator() + gen.generate(project_root=root) + gen.generate(project_root=root) + content = (root / ".git" / "hooks" / "pre-commit").read_text() + assert content.count(MANAGED_START) == 1 + + def test_generate_skips_when_no_git_dir(self, tmp_path): + gen = HookGenerator() + results = gen.generate(project_root=tmp_path) + assert all("skipped" in v for v in results.values()) + + def test_status_false_before_generate(self, tmp_path): + root = self._make_git_repo(tmp_path) + status = HookGenerator().status(project_root=root) + assert status["pre-commit"] is False + assert status["pre-push"] is False + + def test_status_true_after_generate(self, tmp_path): + root = self._make_git_repo(tmp_path) + HookGenerator().generate(project_root=root) + status = HookGenerator().status(project_root=root) + assert status["pre-commit"] is True + assert status["pre-push"] is True + + def test_remove_cleans_managed_section(self, tmp_path): + root = self._make_git_repo(tmp_path) + HookGenerator().generate(project_root=root) + HookGenerator().remove(project_root=root) + hook = root / ".git" / "hooks" / "pre-commit" + assert not hook.exists() or MANAGED_START not in hook.read_text() + + def test_remove_preserves_user_content(self, tmp_path): + root = self._make_git_repo(tmp_path) + hook = root / ".git" / "hooks" / "pre-commit" + hook.write_text("#!/bin/sh\nnpm test\n") + HookGenerator().generate(project_root=root) + HookGenerator().remove(project_root=root) + assert "npm test" in hook.read_text() + + def test_remove_returns_not_found_when_no_section(self, tmp_path): + root = self._make_git_repo(tmp_path) + results = HookGenerator().remove(project_root=root) + assert results["pre-commit"] == "not_found" + assert results["pre-push"] == "not_found" diff --git a/tests/unit/test_staged_enforcement.py b/tests/unit/test_staged_enforcement.py new file mode 100644 index 0000000..bb975d8 --- /dev/null +++ b/tests/unit/test_staged_enforcement.py @@ -0,0 +1,526 @@ +"""Unit tests for staged enforcement: stage classification model and validator.""" + +import re +import tempfile +from pathlib import Path + +import pytest + +from adr_kit.enforce.stages import ( + EnforcementLevel, + StagedCheck, + checks_for_level, + classify_adr_checks, +) +from adr_kit.enforce.validator import StagedValidator, ValidationResult + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_adr( + adr_id: str = "ADR-0001", + title: str = "Test ADR", + imports_disallow: list[str] | None = None, + python_disallow: list[str] | None = None, + patterns: dict | None = None, + architecture_boundaries: list[dict] | None = None, + required_structure: list[dict] | None = None, + config_enforcement: dict | None = None, +): + """Create a minimal ADR-like object with given policies (no file I/O).""" + from unittest.mock import MagicMock + + from adr_kit.core.model import ( + ArchitecturePolicy, + ConfigEnforcementPolicy, + ImportPolicy, + LayerBoundaryRule, + PatternPolicy, + PatternRule, + PolicyModel, + PythonPolicy, + RequiredStructure, + ) + + policy_kwargs: dict = {} + + if imports_disallow is not None: + policy_kwargs["imports"] = ImportPolicy(disallow=imports_disallow) + + if python_disallow is not None: + policy_kwargs["python"] = PythonPolicy(disallow_imports=python_disallow) + + if patterns is not None: + rules = { + name: PatternRule( + description=data["description"], + rule=data["rule"], + language=data.get("language"), + severity=data.get("severity", "error"), + ) + for name, data in patterns.items() + } + policy_kwargs["patterns"] = PatternPolicy(patterns=rules) + + if architecture_boundaries is not None: + boundary_rules = [ + LayerBoundaryRule( + rule=b["rule"], + check=b.get("check"), + action=b.get("action", "block"), + message=b.get("message"), + ) + for b in architecture_boundaries + ] + arch = policy_kwargs.get("architecture") or ArchitecturePolicy() + policy_kwargs["architecture"] = ArchitecturePolicy( + layer_boundaries=boundary_rules, + required_structure=arch.required_structure, + ) + + if required_structure is not None: + structs = [ + RequiredStructure(path=s["path"], description=s.get("description")) + for s in required_structure + ] + arch = policy_kwargs.get("architecture") or ArchitecturePolicy() + policy_kwargs["architecture"] = ArchitecturePolicy( + layer_boundaries=arch.layer_boundaries, + required_structure=structs, + ) + + if config_enforcement is not None: + policy_kwargs["config_enforcement"] = ConfigEnforcementPolicy( + **config_enforcement + ) + + policy = PolicyModel(**policy_kwargs) if policy_kwargs else None + + adr = MagicMock() + adr.id = adr_id + adr.title = title + adr.policy = policy + return adr + + +# --------------------------------------------------------------------------- +# classify_adr_checks +# --------------------------------------------------------------------------- + + +class TestClassifyAdrChecks: + def test_no_policy_produces_no_checks(self): + adr = _make_adr() + assert classify_adr_checks([adr]) == [] + + def test_imports_disallow_produces_commit_checks(self): + adr = _make_adr(imports_disallow=["flask", "django"]) + checks = classify_adr_checks([adr]) + assert len(checks) == 2 + for c in checks: + assert c.check_type == "import" + assert c.level == EnforcementLevel.COMMIT + assert c.adr_id == "ADR-0001" + + def test_python_disallow_produces_commit_checks(self): + adr = _make_adr(python_disallow=["requests"]) + checks = classify_adr_checks([adr]) + assert len(checks) == 1 + assert checks[0].check_type == "python_import" + assert checks[0].level == EnforcementLevel.COMMIT + assert checks[0].file_glob == "*.py" + + def test_pattern_rule_produces_commit_check(self): + adr = _make_adr( + patterns={ + "no_console": { + "description": "No console.log", + "rule": r"console\.log", + "language": "typescript", + } + } + ) + checks = classify_adr_checks([adr]) + assert len(checks) == 1 + assert checks[0].check_type == "pattern" + assert checks[0].level == EnforcementLevel.COMMIT + assert checks[0].file_glob == "*.typescript" + + def test_architecture_boundaries_produce_push_checks(self): + adr = _make_adr( + architecture_boundaries=[{"rule": "ui -> database", "action": "block"}] + ) + checks = classify_adr_checks([adr]) + arch_checks = [c for c in checks if c.check_type == "architecture"] + assert len(arch_checks) == 1 + assert arch_checks[0].level == EnforcementLevel.PUSH + + def test_required_structure_produces_ci_checks(self): + adr = _make_adr( + required_structure=[{"path": "docs/adr", "description": "ADR dir"}] + ) + checks = classify_adr_checks([adr]) + struct_checks = [c for c in checks if c.check_type == "required_structure"] + assert len(struct_checks) == 1 + assert struct_checks[0].level == EnforcementLevel.CI + + def test_config_enforcement_produces_ci_check(self): + adr = _make_adr(config_enforcement={}) + checks = classify_adr_checks([adr]) + config_checks = [c for c in checks if c.check_type == "config"] + assert len(config_checks) == 1 + assert config_checks[0].level == EnforcementLevel.CI + + def test_violation_message_includes_adr_id(self): + adr = _make_adr(imports_disallow=["flask"]) + checks = classify_adr_checks([adr]) + assert "ADR-0001" in checks[0].message + + def test_multiple_adrs_classified_independently(self): + adr1 = _make_adr("ADR-0001", imports_disallow=["flask"]) + adr2 = _make_adr("ADR-0002", python_disallow=["requests"]) + checks = classify_adr_checks([adr1, adr2]) + assert len(checks) == 2 + assert {c.adr_id for c in checks} == {"ADR-0001", "ADR-0002"} + + def test_architecture_block_action_maps_to_error_severity(self): + adr = _make_adr( + architecture_boundaries=[{"rule": "ui -> db", "action": "block"}] + ) + checks = classify_adr_checks([adr]) + assert checks[0].severity == "error" + + def test_architecture_warn_action_maps_to_warning_severity(self): + adr = _make_adr( + architecture_boundaries=[{"rule": "ui -> db", "action": "warn"}] + ) + checks = classify_adr_checks([adr]) + assert checks[0].severity == "warning" + + +# --------------------------------------------------------------------------- +# checks_for_level +# --------------------------------------------------------------------------- + + +class TestChecksForLevel: + def _make_check(self, level: EnforcementLevel) -> StagedCheck: + return StagedCheck( + adr_id="ADR-0001", + adr_title="T", + check_type="import", + level=level, + pattern="flask", + message="msg", + ) + + def test_commit_level_includes_only_commit_checks(self): + checks = [ + self._make_check(EnforcementLevel.COMMIT), + self._make_check(EnforcementLevel.PUSH), + self._make_check(EnforcementLevel.CI), + ] + result = checks_for_level(checks, EnforcementLevel.COMMIT) + assert len(result) == 1 + assert result[0].level == EnforcementLevel.COMMIT + + def test_push_level_includes_commit_and_push(self): + checks = [ + self._make_check(EnforcementLevel.COMMIT), + self._make_check(EnforcementLevel.PUSH), + self._make_check(EnforcementLevel.CI), + ] + result = checks_for_level(checks, EnforcementLevel.PUSH) + assert len(result) == 2 + levels = {c.level for c in result} + assert EnforcementLevel.CI not in levels + + def test_ci_level_includes_all_checks(self): + checks = [ + self._make_check(EnforcementLevel.COMMIT), + self._make_check(EnforcementLevel.PUSH), + self._make_check(EnforcementLevel.CI), + ] + result = checks_for_level(checks, EnforcementLevel.CI) + assert len(result) == 3 + + def test_empty_checks_returns_empty(self): + assert checks_for_level([], EnforcementLevel.COMMIT) == [] + + +# --------------------------------------------------------------------------- +# StagedValidator — file filtering and import detection +# --------------------------------------------------------------------------- + + +class TestStagedValidatorImportCheck: + def _run_ci_validate( + self, files: dict[str, str], imports_disallow: list[str] + ) -> ValidationResult: + """Helper: write files to temp dir, run CI-level validation.""" + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + adr_dir = root / "docs" / "adr" + adr_dir.mkdir(parents=True) + + for name, content in files.items(): + target = root / name + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content) + + # Write a minimal accepted ADR with given policy + adr_content = f"""--- +id: ADR-0001 +title: Test +status: accepted +date: 2026-01-01 +policy: + imports: + disallow: {imports_disallow} +--- + +## Context +test +""" + (adr_dir / "ADR-0001-test.md").write_text(adr_content) + + validator = StagedValidator(adr_dir=adr_dir) + return validator.validate(EnforcementLevel.CI, project_root=root) + + def test_detects_js_import_violation(self): + result = self._run_ci_validate( + {"src/index.ts": "import { app } from 'flask'\nconsole.log(app)"}, + ["flask"], + ) + assert not result.passed + assert result.error_count == 1 + assert "flask" in result.violations[0].message + + def test_detects_require_violation(self): + result = self._run_ci_validate( + {"src/index.js": "const flask = require('flask')"}, + ["flask"], + ) + assert not result.passed + + def test_no_violation_when_import_absent(self): + result = self._run_ci_validate( + { + "src/index.ts": "import { something } from 'fastapi'\nconsole.log(something)" + }, + ["flask"], + ) + assert result.passed + assert result.error_count == 0 + + def test_python_import_only_checks_py_files(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + adr_dir = root / "docs" / "adr" + adr_dir.mkdir(parents=True) + + # TS file with "requests" — should NOT trigger python_import check + (root / "index.ts").write_text("import requests from 'requests'") + # PY file with "requests" — SHOULD trigger + (root / "app.py").write_text( + "import requests\nrequests.get('http://example.com')" + ) + + adr_content = """--- +id: ADR-0001 +title: No requests +status: accepted +date: 2026-01-01 +policy: + python: + disallow_imports: [requests] +--- + +## Context +Use httpx instead. +""" + (adr_dir / "ADR-0001-no-requests.md").write_text(adr_content) + + validator = StagedValidator(adr_dir=adr_dir) + result = validator.validate(EnforcementLevel.CI, project_root=root) + + # Only app.py should be flagged + assert result.error_count == 1 + assert "app.py" in result.violations[0].file + + def test_no_adr_dir_returns_clean_result(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + validator = StagedValidator(adr_dir=root / "nonexistent") + result = validator.validate(EnforcementLevel.CI, project_root=root) + assert result.passed + assert result.checks_run == 0 + + def test_result_metadata(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + adr_dir = root / "docs" / "adr" + adr_dir.mkdir(parents=True) + (root / "app.py").write_text("import flask") + + adr_content = """--- +id: ADR-0001 +title: No flask +status: accepted +date: 2026-01-01 +policy: + python: + disallow_imports: [flask] +--- + +## Context +Use FastAPI. +""" + (adr_dir / "ADR-0001.md").write_text(adr_content) + validator = StagedValidator(adr_dir=adr_dir) + result = validator.validate(EnforcementLevel.CI, project_root=root) + + assert result.level == EnforcementLevel.CI + assert result.checks_run >= 1 + assert result.files_checked >= 1 + + def test_violation_includes_line_number(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + adr_dir = root / "docs" / "adr" + adr_dir.mkdir(parents=True) + (root / "app.py").write_text("# line 1\nimport flask\n# line 3") + + adr_content = """--- +id: ADR-0001 +title: No flask +status: accepted +date: 2026-01-01 +policy: + python: + disallow_imports: [flask] +--- +""" + (adr_dir / "ADR-0001.md").write_text(adr_content) + validator = StagedValidator(adr_dir=adr_dir) + result = validator.validate(EnforcementLevel.CI, project_root=root) + + assert result.violations[0].line == 2 + + +class TestStagedValidatorPatternCheck: + def test_detects_pattern_violation(self): + with tempfile.TemporaryDirectory() as tmp: + root = Path(tmp) + adr_dir = root / "docs" / "adr" + adr_dir.mkdir(parents=True) + (root / "app.py").write_text("x = eval('dangerous')") + + adr_content = """--- +id: ADR-0001 +title: No eval +status: accepted +date: 2026-01-01 +policy: + patterns: + patterns: + no_eval: + description: No eval usage + rule: "\\\\beval\\\\(" + severity: error +--- +""" + (adr_dir / "ADR-0001.md").write_text(adr_content) + validator = StagedValidator(adr_dir=adr_dir) + result = validator.validate(EnforcementLevel.CI, project_root=root) + + assert result.error_count >= 1 + + def test_invalid_regex_pattern_skipped_gracefully(self): + """An invalid regex in an ADR policy should not crash the validator.""" + from adr_kit.enforce.stages import StagedCheck + from adr_kit.enforce.validator import StagedValidator + + check = StagedCheck( + adr_id="ADR-0001", + adr_title="T", + check_type="pattern", + level=EnforcementLevel.CI, + pattern="[invalid regex", + message="bad", + ) + validator = StagedValidator() + result = validator._run_pattern_check(check, [], Path(".")) + assert result == [] + + +class TestValidationResult: + def test_passed_true_when_no_violations(self): + result = ValidationResult( + level=EnforcementLevel.COMMIT, files_checked=5, checks_run=3 + ) + assert result.passed + + def test_passed_false_when_error_violation_present(self): + from adr_kit.enforce.validator import Violation + + result = ValidationResult( + level=EnforcementLevel.COMMIT, files_checked=5, checks_run=3 + ) + result.violations.append( + Violation( + file="x.py", + adr_id="ADR-0001", + message="m", + level=EnforcementLevel.COMMIT, + severity="error", + ) + ) + assert not result.passed + + def test_passed_true_with_only_warnings(self): + from adr_kit.enforce.validator import Violation + + result = ValidationResult( + level=EnforcementLevel.COMMIT, files_checked=5, checks_run=3 + ) + result.violations.append( + Violation( + file="x.py", + adr_id="ADR-0001", + message="m", + level=EnforcementLevel.COMMIT, + severity="warning", + ) + ) + assert result.passed + assert result.has_warnings + + def test_error_and_warning_counts(self): + from adr_kit.enforce.validator import Violation + + result = ValidationResult( + level=EnforcementLevel.CI, files_checked=10, checks_run=5 + ) + result.violations.append( + Violation( + file="a.py", + adr_id="ADR-0001", + message="e", + level=EnforcementLevel.CI, + severity="error", + ) + ) + result.violations.append( + Violation( + file="b.py", + adr_id="ADR-0001", + message="w", + level=EnforcementLevel.CI, + severity="warning", + ) + ) + assert result.error_count == 1 + assert result.warning_count == 1 From 7d38dba8627787e36436a07548266d31fa19dea1 Mon Sep 17 00:00:00 2001 From: kschlt Date: Mon, 23 Mar 2026 00:08:41 +0100 Subject: [PATCH 2/2] fix(types): resolve mypy errors in cli and validator Add threading import at module level for Thread return type annotation, and annotate adrs list with ADR type in _load_accepted_adrs. --- adr_kit/cli.py | 4 ++-- adr_kit/enforce/validator.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/adr_kit/cli.py b/adr_kit/cli.py index f72dd00..6deec11 100644 --- a/adr_kit/cli.py +++ b/adr_kit/cli.py @@ -8,6 +8,7 @@ """ import sys +import threading from pathlib import Path from typing import Annotated @@ -28,12 +29,11 @@ stderr_console = Console(stderr=True) -def check_for_updates_async() -> object: +def check_for_updates_async() -> threading.Thread: """Check for updates in the background and show notification if available. Returns the background thread so callers can join it if needed. """ - import threading def _check() -> None: try: diff --git a/adr_kit/enforce/validator.py b/adr_kit/enforce/validator.py index f8f6f6c..5917884 100644 --- a/adr_kit/enforce/validator.py +++ b/adr_kit/enforce/validator.py @@ -15,7 +15,7 @@ from dataclasses import dataclass, field from pathlib import Path -from ..core.model import ADRStatus +from ..core.model import ADR, ADRStatus from ..core.parse import ParseError, find_adr_files, parse_adr_file from .stages import EnforcementLevel, StagedCheck, checks_for_level, classify_adr_checks @@ -120,8 +120,8 @@ def validate( # --- ADR loading --- - def _load_accepted_adrs(self) -> list: - adrs = [] + def _load_accepted_adrs(self) -> list[ADR]: + adrs: list[ADR] = [] if not self.adr_dir.exists(): return adrs for file_path in find_adr_files(self.adr_dir):